wip: refactor get account in peers

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga 2024-10-28 17:47:54 +03:00
parent 313e158e20
commit 9bf0bf4843
No known key found for this signature in database
GPG Key ID: 511EED5C928AD547
7 changed files with 176 additions and 143 deletions

View File

@ -94,7 +94,8 @@ type AccountManager interface {
GetUserByID(ctx context.Context, id string) (*User, error)
GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error)
ListUsers(ctx context.Context, accountID string) ([]*User, error)
GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
GetUserPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
ListPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
@ -105,7 +106,6 @@ type AccountManager interface {
DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error
GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error)
GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error)
UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error
GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*UserInfo, error)
GetGroup(ctx context.Context, accountId, groupID, userID string) (*nbgroup.Group, error)
GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error)
@ -116,6 +116,7 @@ type AccountManager interface {
DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*nbgroup.Group, error)
GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error)
SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error
DeletePolicy(ctx context.Context, accountID, policyID, userID string) error
@ -149,7 +150,7 @@ type AccountManager interface {
GetIdpManager() idp.Manager
UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
GetValidatedPeers(account *Account) (map[string]struct{}, error)
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error)
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error)
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error

View File

@ -49,7 +49,7 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
return
}
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
accountPeers, err := h.accountManager.GetUserPeers(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@ -132,7 +132,7 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
return
}
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
accountPeers, err := h.accountManager.GetUserPeers(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@ -180,7 +180,7 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
return
}
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
accountPeers, err := h.accountManager.GetUserPeers(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@ -238,7 +238,7 @@ func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) {
return
}
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
accountPeers, err := h.accountManager.GetUserPeers(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@ -48,8 +48,8 @@ func (h *PeersHandler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error)
return peerToReturn, nil
}
func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, peerID, userID string, w http.ResponseWriter) {
peer, err := h.accountManager.GetPeer(ctx, account.Id, peerID, userID)
func (h *PeersHandler) getPeer(ctx context.Context, accountID, peerID, userID string, w http.ResponseWriter) {
peer, err := h.accountManager.GetPeer(ctx, accountID, peerID, userID)
if err != nil {
util.WriteError(ctx, err, w)
return
@ -62,11 +62,16 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee
}
dnsDomain := h.accountManager.GetDNSDomain()
groupsInfo := toGroupsInfo(account.Groups, peer.ID)
validPeers, err := h.accountManager.GetValidatedPeers(account)
peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID)
if err != nil {
log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err)
util.WriteError(ctx, err, w)
return
}
groupsInfo := toGroupsInfo(peerGroups)
validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to list approved peers: %v", err)
util.WriteError(ctx, fmt.Errorf("internal error"), w)
return
}
@ -75,7 +80,7 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid))
}
func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) {
func (h *PeersHandler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) {
req := &api.PeerRequest{}
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
@ -99,16 +104,21 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account,
}
}
peer, err := h.accountManager.UpdatePeer(ctx, account.Id, userID, update)
peer, err := h.accountManager.UpdatePeer(ctx, accountID, userID, update)
if err != nil {
util.WriteError(ctx, err, w)
return
}
dnsDomain := h.accountManager.GetDNSDomain()
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID)
peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID)
if err != nil {
util.WriteError(ctx, err, w)
return
}
groupMinimumInfo := toGroupsInfo(peerGroups)
validPeers, err := h.accountManager.GetValidatedPeers(account)
validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err)
util.WriteError(ctx, fmt.Errorf("internal error"), w)
@ -149,18 +159,11 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
case http.MethodDelete:
h.deletePeer(r.Context(), accountID, userID, peerID, w)
return
case http.MethodGet, http.MethodPut:
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
if r.Method == http.MethodGet {
h.getPeer(r.Context(), account, peerID, userID, w)
} else {
h.updatePeer(r.Context(), account, userID, peerID, w, r)
}
case http.MethodGet:
h.getPeer(r.Context(), accountID, peerID, userID, w)
return
case http.MethodPut:
h.updatePeer(r.Context(), accountID, userID, peerID, w, r)
return
default:
util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
@ -176,7 +179,7 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
return
}
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID)
peers, err := h.accountManager.ListPeers(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@ -184,19 +187,25 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
dnsDomain := h.accountManager.GetDNSDomain()
respBody := make([]*api.PeerBatch, 0, len(account.Peers))
for _, peer := range account.Peers {
respBody := make([]*api.PeerBatch, 0, len(peers))
for _, peer := range peers {
peerToReturn, err := h.checkPeerStatus(peer)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID)
peerGroups, err := h.accountManager.GetPeerGroups(r.Context(), accountID, peer.ID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
groupMinimumInfo := toGroupsInfo(peerGroups)
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0))
}
validPeersMap, err := h.accountManager.GetValidatedPeers(account)
validPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
if err != nil {
log.WithContext(r.Context()).Errorf("failed to list appreoved peers: %v", err)
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
@ -259,16 +268,16 @@ func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request
}
}
dnsDomain := h.accountManager.GetDNSDomain()
validPeers, err := h.accountManager.GetValidatedPeers(account)
validPeers, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
if err != nil {
log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err)
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
return
}
customZone := account.GetPeersCustomZone(r.Context(), h.accountManager.GetDNSDomain())
dnsDomain := h.accountManager.GetDNSDomain()
customZone := account.GetPeersCustomZone(r.Context(), dnsDomain)
netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, nil)
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
@ -303,26 +312,14 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee
}
}
func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum {
var groupsInfo []api.GroupMinimum
groupsChecked := make(map[string]struct{})
func toGroupsInfo(groups []*nbgroup.Group) []api.GroupMinimum {
groupsInfo := make([]api.GroupMinimum, 0, len(groups))
for _, group := range groups {
_, ok := groupsChecked[group.ID]
if ok {
continue
}
groupsChecked[group.ID] = struct{}{}
for _, pk := range group.Peers {
if pk == peerID {
info := api.GroupMinimum{
Id: group.ID,
Name: group.Name,
PeersCount: len(group.Peers),
}
groupsInfo = append(groupsInfo, info)
break
}
}
groupsInfo = append(groupsInfo, api.GroupMinimum{
Id: group.ID,
Name: group.Name,
PeersCount: len(group.Peers),
})
}
return groupsInfo
}

View File

@ -4,6 +4,8 @@ import (
"context"
"errors"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/account"
@ -78,6 +80,31 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountId
return true, nil
}
func (am *DefaultAccountManager) GetValidatedPeers(account *Account) (map[string]struct{}, error) {
return am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra)
func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) {
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
groupsMap := make(map[string]*nbgroup.Group, len(groups))
for _, group := range groups {
groupsMap[group.ID] = group
}
peers, err := am.Store.GetAccountPeers(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
peersMap := make(map[string]*nbpeer.Peer, len(peers))
for _, peer := range peers {
peersMap[peer.ID] = peer
}
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
return am.integratedPeerValidator.GetValidatedPeers(accountID, groupsMap, peersMap, settings.Extra)
}

View File

@ -31,7 +31,8 @@ type MockAccountManager struct {
GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error)
GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error)
GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
GetUserPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
ListPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
@ -47,6 +48,7 @@ type MockAccountManager struct {
DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error
GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
GetPeerGroupsFunc func(ctx context.Context, accountID, peerID string) ([]*group.Group, error)
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error)
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error
@ -56,7 +58,6 @@ type MockAccountManager struct {
GetAccountFromPATFunc func(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
MarkPATUsedFunc func(ctx context.Context, pat string) error
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
UpdatePeerSSHKeyFunc func(ctx context.Context, peerID string, sshKey string) error
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
@ -123,7 +124,12 @@ func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID st
panic("implement me")
}
func (am *MockAccountManager) GetValidatedPeers(account *server.Account) (map[string]struct{}, error) {
func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) {
account, err := am.GetAccountFunc(ctx, accountID)
if err != nil {
return nil, err
}
approvedPeers := make(map[string]struct{})
for id := range account.Peers {
approvedPeers[id] = struct{}{}
@ -425,14 +431,6 @@ func (am *MockAccountManager) ListUsers(ctx context.Context, accountID string) (
return nil, status.Errorf(codes.Unimplemented, "method ListUsers is not implemented")
}
// UpdatePeerSSHKey mocks UpdatePeerSSHKey function of the account manager
func (am *MockAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error {
if am.UpdatePeerSSHKeyFunc != nil {
return am.UpdatePeerSSHKeyFunc(ctx, peerID, sshKey)
}
return status.Errorf(codes.Unimplemented, "method UpdatePeerSSHKey is not implemented")
}
// UpdatePeer mocks UpdatePeerFunc function of the account manager
func (am *MockAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) {
if am.UpdatePeerFunc != nil {
@ -618,12 +616,12 @@ func (am *MockAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, cl
return status.Errorf(codes.Unimplemented, "method CheckUserAccessByJWTGroups is not implemented")
}
// GetPeers mocks GetPeers of the AccountManager interface
func (am *MockAccountManager) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
if am.GetPeersFunc != nil {
return am.GetPeersFunc(ctx, accountID, userID)
// GetUserPeers mocks GetUserPeers of the AccountManager interface
func (am *MockAccountManager) GetUserPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
if am.GetUserPeersFunc != nil {
return am.GetUserPeersFunc(ctx, accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetPeers is not implemented")
return nil, status.Errorf(codes.Unimplemented, "method GetUserPeers is not implemented")
}
// GetDNSDomain mocks GetDNSDomain of the AccountManager interface
@ -832,3 +830,19 @@ func (am *MockAccountManager) GetAccount(ctx context.Context, accountID string)
}
return nil, status.Errorf(codes.Unimplemented, "method GetAccount is not implemented")
}
// GetPeerGroups mocks GetPeerGroups of the AccountManager interface
func (am *MockAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*group.Group, error) {
if am.GetPeerGroupsFunc != nil {
return am.GetPeerGroupsFunc(ctx, accountID, peerID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetPeerGroups is not implemented")
}
// ListPeers mocks ListPeers of the AccountManager interface
func (am *MockAccountManager) ListPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
if am.ListPeersFunc != nil {
return am.ListPeersFunc(ctx, accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method ListPeers is not implemented")
}

View File

@ -4,10 +4,12 @@ import (
"context"
"fmt"
"net"
"slices"
"strings"
"sync"
"time"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
@ -47,9 +49,23 @@ type PeerLogin struct {
ConnectionIP net.IP
}
// GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if
// ListPeers returns a list of peers under the given account.
func (am *DefaultAccountManager) ListPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
if user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data")
}
return am.Store.GetAccountPeers(ctx, LockingStrengthShare, accountID)
}
// GetUserPeers returns a list of peers under the given account filtering out peers that do not belong to a user if
// the current user is not an admin.
func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
func (am *DefaultAccountManager) GetUserPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
@ -60,7 +76,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
return nil, err
}
approvedPeersMap, err := am.GetValidatedPeers(account)
approvedPeersMap, err := am.GetValidatedPeers(ctx, accountID)
if err != nil {
return nil, err
}
@ -585,7 +601,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
am.updateAccountPeers(ctx, account)
approvedPeersMap, err := am.GetValidatedPeers(account)
approvedPeersMap, err := am.GetValidatedPeers(ctx, accountID)
if err != nil {
return nil, nil, nil, err
}
@ -672,7 +688,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
am.updateAccountPeers(ctx, account)
}
validPeersMap, err := am.GetValidatedPeers(account)
validPeersMap, err := am.GetValidatedPeers(ctx, account.Id)
if err != nil {
return nil, nil, nil, err
}
@ -847,7 +863,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
return peer, emptyMap, nil, nil
}
approvedPeersMap, err := am.GetValidatedPeers(account)
approvedPeersMap, err := am.GetValidatedPeers(ctx, account.Id)
if err != nil {
return nil, nil, nil, err
}
@ -914,92 +930,53 @@ func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *Settings
return false
}
// UpdatePeerSSHKey updates peer's public SSH key
func (am *DefaultAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error {
if sshKey == "" {
log.WithContext(ctx).Debugf("empty SSH key provided for peer %s, skipping update", peerID)
return nil
}
account, err := am.Store.GetAccountByPeerID(ctx, peerID)
if err != nil {
return err
}
unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id)
defer unlock()
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
account, err = am.Store.GetAccount(ctx, account.Id)
if err != nil {
return err
}
peer := account.GetPeer(peerID)
if peer == nil {
return status.Errorf(status.NotFound, "peer with ID %s not found", peerID)
}
if peer.SSHKey == sshKey {
log.WithContext(ctx).Debugf("same SSH key provided for peer %s, skipping update", peerID)
return nil
}
peer.SSHKey = sshKey
account.UpdatePeer(peer)
err = am.Store.SaveAccount(ctx, account)
if err != nil {
return err
}
// trigger network map update
am.updateAccountPeers(ctx, account)
return nil
}
// GetPeer for a given accountID, peerID and userID error if not found.
func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
user, err := account.FindUser(userID)
if user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg)
}
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked {
if !user.IsAdminOrServiceUser() && settings.RegularUsersViewBlocked {
return nil, status.Errorf(status.Internal, "user %s has no access to his own peer %s under account %s", userID, peerID, accountID)
}
peer := account.GetPeer(peerID)
if peer == nil {
return nil, status.Errorf(status.NotFound, "peer with %s not found under account %s", peerID, accountID)
peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
if err != nil {
return nil, err
}
// if admin or user owns this peer, return peer
if user.HasAdminPower() || user.IsServiceUser || peer.UserID == userID {
if user.IsAdminOrServiceUser() || peer.UserID == userID {
return peer, nil
}
// it is also possible that user doesn't own the peer but some of his peers have access to it,
// this is a valid case, show the peer as well.
userPeers, err := account.FindUserPeers(userID)
userPeers, err := am.Store.GetUserPeers(ctx, LockingStrengthShare, accountID, userID)
if err != nil {
return nil, err
}
approvedPeersMap, err := am.GetValidatedPeers(account)
approvedPeersMap, err := am.GetValidatedPeers(ctx, accountID)
if err != nil {
return nil, err
}
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, fmt.Errorf(errGetAccountFmt, err)
}
for _, p := range userPeers {
aclPeers, _ := account.getPeerConnectionResources(ctx, p.ID, approvedPeersMap)
for _, aclPeer := range aclPeers {
@ -1024,7 +1001,7 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
peers := account.GetPeers()
approvedPeersMap, err := am.GetValidatedPeers(account)
approvedPeersMap, err := am.GetValidatedPeers(ctx, account.Id)
if err != nil {
log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to validate peer: %v", err)
return
@ -1196,6 +1173,23 @@ func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID
return peers, nil
}
// GetPeerGroups returns groups that the peer is part of.
func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*nbgroup.Group, error) {
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
peerGroups := make([]*nbgroup.Group, 0)
for _, group := range groups {
if slices.Contains(group.Peers, peerID) {
peerGroups = append(peerGroups, group)
}
}
return peerGroups, nil
}
func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
labelMap := make(map[string]struct{}, len(existingLabels))
for _, label := range existingLabels {

View File

@ -561,7 +561,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
assert.NotNil(t, peer)
}
func TestDefaultAccountManager_GetPeers(t *testing.T) {
func TestDefaultAccountManager_GetUserPeers(t *testing.T) {
testCases := []struct {
name string
role UserRole
@ -697,7 +697,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
return
}
peers, err := manager.GetPeers(context.Background(), accountID, someUser)
peers, err := manager.GetUserPeers(context.Background(), accountID, someUser)
if err != nil {
t.Fatal(err)
return
@ -822,9 +822,9 @@ func BenchmarkGetPeers(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := manager.GetPeers(context.Background(), accountID, userID)
_, err := manager.GetUserPeers(context.Background(), accountID, userID)
if err != nil {
b.Fatalf("GetPeers failed: %v", err)
b.Fatalf("GetUserPeers failed: %v", err)
}
}
})