mirror of
https://github.com/netbirdio/netbird.git
synced 2025-02-02 03:19:34 +01:00
wip: refactor get account in peers
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
parent
313e158e20
commit
9bf0bf4843
@ -94,7 +94,8 @@ type AccountManager interface {
|
||||
GetUserByID(ctx context.Context, id string) (*User, error)
|
||||
GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error)
|
||||
ListUsers(ctx context.Context, accountID string) ([]*User, error)
|
||||
GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
|
||||
GetUserPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
|
||||
ListPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error
|
||||
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
|
||||
UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
@ -105,7 +106,6 @@ type AccountManager interface {
|
||||
DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error
|
||||
GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error)
|
||||
GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error)
|
||||
UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error
|
||||
GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*UserInfo, error)
|
||||
GetGroup(ctx context.Context, accountId, groupID, userID string) (*nbgroup.Group, error)
|
||||
GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error)
|
||||
@ -116,6 +116,7 @@ type AccountManager interface {
|
||||
DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error
|
||||
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
|
||||
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
|
||||
GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*nbgroup.Group, error)
|
||||
GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error)
|
||||
SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error
|
||||
DeletePolicy(ctx context.Context, accountID, policyID, userID string) error
|
||||
@ -149,7 +150,7 @@ type AccountManager interface {
|
||||
GetIdpManager() idp.Manager
|
||||
UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error
|
||||
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
|
||||
GetValidatedPeers(account *Account) (map[string]struct{}, error)
|
||||
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error)
|
||||
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error)
|
||||
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
|
||||
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
||||
|
@ -49,7 +49,7 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
|
||||
accountPeers, err := h.accountManager.GetUserPeers(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -132,7 +132,7 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
|
||||
accountPeers, err := h.accountManager.GetUserPeers(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -180,7 +180,7 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
|
||||
accountPeers, err := h.accountManager.GetUserPeers(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -238,7 +238,7 @@ func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
|
||||
accountPeers, err := h.accountManager.GetUserPeers(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
@ -48,8 +48,8 @@ func (h *PeersHandler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
return peerToReturn, nil
|
||||
}
|
||||
|
||||
func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, peerID, userID string, w http.ResponseWriter) {
|
||||
peer, err := h.accountManager.GetPeer(ctx, account.Id, peerID, userID)
|
||||
func (h *PeersHandler) getPeer(ctx context.Context, accountID, peerID, userID string, w http.ResponseWriter) {
|
||||
peer, err := h.accountManager.GetPeer(ctx, accountID, peerID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
@ -62,11 +62,16 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee
|
||||
}
|
||||
dnsDomain := h.accountManager.GetDNSDomain()
|
||||
|
||||
groupsInfo := toGroupsInfo(account.Groups, peer.ID)
|
||||
|
||||
validPeers, err := h.accountManager.GetValidatedPeers(account)
|
||||
peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err)
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
groupsInfo := toGroupsInfo(peerGroups)
|
||||
|
||||
validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to list approved peers: %v", err)
|
||||
util.WriteError(ctx, fmt.Errorf("internal error"), w)
|
||||
return
|
||||
}
|
||||
@ -75,7 +80,7 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee
|
||||
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid))
|
||||
}
|
||||
|
||||
func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) {
|
||||
func (h *PeersHandler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) {
|
||||
req := &api.PeerRequest{}
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
@ -99,16 +104,21 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account,
|
||||
}
|
||||
}
|
||||
|
||||
peer, err := h.accountManager.UpdatePeer(ctx, account.Id, userID, update)
|
||||
peer, err := h.accountManager.UpdatePeer(ctx, accountID, userID, update)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
dnsDomain := h.accountManager.GetDNSDomain()
|
||||
|
||||
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID)
|
||||
peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
groupMinimumInfo := toGroupsInfo(peerGroups)
|
||||
|
||||
validPeers, err := h.accountManager.GetValidatedPeers(account)
|
||||
validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err)
|
||||
util.WriteError(ctx, fmt.Errorf("internal error"), w)
|
||||
@ -149,18 +159,11 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
||||
case http.MethodDelete:
|
||||
h.deletePeer(r.Context(), accountID, userID, peerID, w)
|
||||
return
|
||||
case http.MethodGet, http.MethodPut:
|
||||
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
if r.Method == http.MethodGet {
|
||||
h.getPeer(r.Context(), account, peerID, userID, w)
|
||||
} else {
|
||||
h.updatePeer(r.Context(), account, userID, peerID, w, r)
|
||||
}
|
||||
case http.MethodGet:
|
||||
h.getPeer(r.Context(), accountID, peerID, userID, w)
|
||||
return
|
||||
case http.MethodPut:
|
||||
h.updatePeer(r.Context(), accountID, userID, peerID, w, r)
|
||||
return
|
||||
default:
|
||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
|
||||
@ -176,7 +179,7 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID)
|
||||
peers, err := h.accountManager.ListPeers(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -184,19 +187,25 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
dnsDomain := h.accountManager.GetDNSDomain()
|
||||
|
||||
respBody := make([]*api.PeerBatch, 0, len(account.Peers))
|
||||
for _, peer := range account.Peers {
|
||||
respBody := make([]*api.PeerBatch, 0, len(peers))
|
||||
for _, peer := range peers {
|
||||
peerToReturn, err := h.checkPeerStatus(peer)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID)
|
||||
|
||||
peerGroups, err := h.accountManager.GetPeerGroups(r.Context(), accountID, peer.ID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
groupMinimumInfo := toGroupsInfo(peerGroups)
|
||||
|
||||
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0))
|
||||
}
|
||||
|
||||
validPeersMap, err := h.accountManager.GetValidatedPeers(account)
|
||||
validPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed to list appreoved peers: %v", err)
|
||||
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
|
||||
@ -259,16 +268,16 @@ func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
}
|
||||
|
||||
dnsDomain := h.accountManager.GetDNSDomain()
|
||||
|
||||
validPeers, err := h.accountManager.GetValidatedPeers(account)
|
||||
validPeers, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err)
|
||||
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
|
||||
return
|
||||
}
|
||||
|
||||
customZone := account.GetPeersCustomZone(r.Context(), h.accountManager.GetDNSDomain())
|
||||
dnsDomain := h.accountManager.GetDNSDomain()
|
||||
|
||||
customZone := account.GetPeersCustomZone(r.Context(), dnsDomain)
|
||||
netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, nil)
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
|
||||
@ -303,26 +312,14 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee
|
||||
}
|
||||
}
|
||||
|
||||
func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum {
|
||||
var groupsInfo []api.GroupMinimum
|
||||
groupsChecked := make(map[string]struct{})
|
||||
func toGroupsInfo(groups []*nbgroup.Group) []api.GroupMinimum {
|
||||
groupsInfo := make([]api.GroupMinimum, 0, len(groups))
|
||||
for _, group := range groups {
|
||||
_, ok := groupsChecked[group.ID]
|
||||
if ok {
|
||||
continue
|
||||
}
|
||||
groupsChecked[group.ID] = struct{}{}
|
||||
for _, pk := range group.Peers {
|
||||
if pk == peerID {
|
||||
info := api.GroupMinimum{
|
||||
Id: group.ID,
|
||||
Name: group.Name,
|
||||
PeersCount: len(group.Peers),
|
||||
}
|
||||
groupsInfo = append(groupsInfo, info)
|
||||
break
|
||||
}
|
||||
}
|
||||
groupsInfo = append(groupsInfo, api.GroupMinimum{
|
||||
Id: group.ID,
|
||||
Name: group.Name,
|
||||
PeersCount: len(group.Peers),
|
||||
})
|
||||
}
|
||||
return groupsInfo
|
||||
}
|
||||
|
@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
@ -78,6 +80,31 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountId
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) GetValidatedPeers(account *Account) (map[string]struct{}, error) {
|
||||
return am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra)
|
||||
func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) {
|
||||
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
groupsMap := make(map[string]*nbgroup.Group, len(groups))
|
||||
for _, group := range groups {
|
||||
groupsMap[group.ID] = group
|
||||
}
|
||||
|
||||
peers, err := am.Store.GetAccountPeers(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peersMap := make(map[string]*nbpeer.Peer, len(peers))
|
||||
for _, peer := range peers {
|
||||
peersMap[peer.ID] = peer
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return am.integratedPeerValidator.GetValidatedPeers(accountID, groupsMap, peersMap, settings.Extra)
|
||||
}
|
||||
|
@ -31,7 +31,8 @@ type MockAccountManager struct {
|
||||
GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error)
|
||||
GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
||||
ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error)
|
||||
GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
|
||||
GetUserPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
|
||||
ListPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error
|
||||
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
|
||||
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
|
||||
@ -47,6 +48,7 @@ type MockAccountManager struct {
|
||||
DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error
|
||||
GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
||||
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
||||
GetPeerGroupsFunc func(ctx context.Context, accountID, peerID string) ([]*group.Group, error)
|
||||
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
|
||||
GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error)
|
||||
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error
|
||||
@ -56,7 +58,6 @@ type MockAccountManager struct {
|
||||
GetAccountFromPATFunc func(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
|
||||
MarkPATUsedFunc func(ctx context.Context, pat string) error
|
||||
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
|
||||
UpdatePeerSSHKeyFunc func(ctx context.Context, peerID string, sshKey string) error
|
||||
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
|
||||
GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
|
||||
@ -123,7 +124,12 @@ func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID st
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) GetValidatedPeers(account *server.Account) (map[string]struct{}, error) {
|
||||
func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) {
|
||||
account, err := am.GetAccountFunc(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
approvedPeers := make(map[string]struct{})
|
||||
for id := range account.Peers {
|
||||
approvedPeers[id] = struct{}{}
|
||||
@ -425,14 +431,6 @@ func (am *MockAccountManager) ListUsers(ctx context.Context, accountID string) (
|
||||
return nil, status.Errorf(codes.Unimplemented, "method ListUsers is not implemented")
|
||||
}
|
||||
|
||||
// UpdatePeerSSHKey mocks UpdatePeerSSHKey function of the account manager
|
||||
func (am *MockAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error {
|
||||
if am.UpdatePeerSSHKeyFunc != nil {
|
||||
return am.UpdatePeerSSHKeyFunc(ctx, peerID, sshKey)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method UpdatePeerSSHKey is not implemented")
|
||||
}
|
||||
|
||||
// UpdatePeer mocks UpdatePeerFunc function of the account manager
|
||||
func (am *MockAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) {
|
||||
if am.UpdatePeerFunc != nil {
|
||||
@ -618,12 +616,12 @@ func (am *MockAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, cl
|
||||
return status.Errorf(codes.Unimplemented, "method CheckUserAccessByJWTGroups is not implemented")
|
||||
}
|
||||
|
||||
// GetPeers mocks GetPeers of the AccountManager interface
|
||||
func (am *MockAccountManager) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
|
||||
if am.GetPeersFunc != nil {
|
||||
return am.GetPeersFunc(ctx, accountID, userID)
|
||||
// GetUserPeers mocks GetUserPeers of the AccountManager interface
|
||||
func (am *MockAccountManager) GetUserPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
|
||||
if am.GetUserPeersFunc != nil {
|
||||
return am.GetUserPeersFunc(ctx, accountID, userID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetPeers is not implemented")
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetUserPeers is not implemented")
|
||||
}
|
||||
|
||||
// GetDNSDomain mocks GetDNSDomain of the AccountManager interface
|
||||
@ -832,3 +830,19 @@ func (am *MockAccountManager) GetAccount(ctx context.Context, accountID string)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetAccount is not implemented")
|
||||
}
|
||||
|
||||
// GetPeerGroups mocks GetPeerGroups of the AccountManager interface
|
||||
func (am *MockAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*group.Group, error) {
|
||||
if am.GetPeerGroupsFunc != nil {
|
||||
return am.GetPeerGroupsFunc(ctx, accountID, peerID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetPeerGroups is not implemented")
|
||||
}
|
||||
|
||||
// ListPeers mocks ListPeers of the AccountManager interface
|
||||
func (am *MockAccountManager) ListPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
|
||||
if am.ListPeersFunc != nil {
|
||||
return am.ListPeersFunc(ctx, accountID, userID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method ListPeers is not implemented")
|
||||
}
|
||||
|
@ -4,10 +4,12 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@ -47,9 +49,23 @@ type PeerLogin struct {
|
||||
ConnectionIP net.IP
|
||||
}
|
||||
|
||||
// GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if
|
||||
// ListPeers returns a list of peers under the given account.
|
||||
func (am *DefaultAccountManager) ListPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data")
|
||||
}
|
||||
|
||||
return am.Store.GetAccountPeers(ctx, LockingStrengthShare, accountID)
|
||||
}
|
||||
|
||||
// GetUserPeers returns a list of peers under the given account filtering out peers that do not belong to a user if
|
||||
// the current user is not an admin.
|
||||
func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
|
||||
func (am *DefaultAccountManager) GetUserPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -60,7 +76,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
|
||||
return nil, err
|
||||
}
|
||||
|
||||
approvedPeersMap, err := am.GetValidatedPeers(account)
|
||||
approvedPeersMap, err := am.GetValidatedPeers(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -585,7 +601,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
approvedPeersMap, err := am.GetValidatedPeers(account)
|
||||
approvedPeersMap, err := am.GetValidatedPeers(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
@ -672,7 +688,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
validPeersMap, err := am.GetValidatedPeers(account)
|
||||
validPeersMap, err := am.GetValidatedPeers(ctx, account.Id)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
@ -847,7 +863,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
|
||||
return peer, emptyMap, nil, nil
|
||||
}
|
||||
|
||||
approvedPeersMap, err := am.GetValidatedPeers(account)
|
||||
approvedPeersMap, err := am.GetValidatedPeers(ctx, account.Id)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
@ -914,92 +930,53 @@ func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *Settings
|
||||
return false
|
||||
}
|
||||
|
||||
// UpdatePeerSSHKey updates peer's public SSH key
|
||||
func (am *DefaultAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error {
|
||||
if sshKey == "" {
|
||||
log.WithContext(ctx).Debugf("empty SSH key provided for peer %s, skipping update", peerID)
|
||||
return nil
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccountByPeerID(ctx, peerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id)
|
||||
defer unlock()
|
||||
|
||||
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
|
||||
account, err = am.Store.GetAccount(ctx, account.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
peer := account.GetPeer(peerID)
|
||||
if peer == nil {
|
||||
return status.Errorf(status.NotFound, "peer with ID %s not found", peerID)
|
||||
}
|
||||
|
||||
if peer.SSHKey == sshKey {
|
||||
log.WithContext(ctx).Debugf("same SSH key provided for peer %s, skipping update", peerID)
|
||||
return nil
|
||||
}
|
||||
|
||||
peer.SSHKey = sshKey
|
||||
account.UpdatePeer(peer)
|
||||
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// trigger network map update
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPeer for a given accountID, peerID and userID error if not found.
|
||||
func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := account.FindUser(userID)
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg)
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked {
|
||||
if !user.IsAdminOrServiceUser() && settings.RegularUsersViewBlocked {
|
||||
return nil, status.Errorf(status.Internal, "user %s has no access to his own peer %s under account %s", userID, peerID, accountID)
|
||||
}
|
||||
|
||||
peer := account.GetPeer(peerID)
|
||||
if peer == nil {
|
||||
return nil, status.Errorf(status.NotFound, "peer with %s not found under account %s", peerID, accountID)
|
||||
peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// if admin or user owns this peer, return peer
|
||||
if user.HasAdminPower() || user.IsServiceUser || peer.UserID == userID {
|
||||
if user.IsAdminOrServiceUser() || peer.UserID == userID {
|
||||
return peer, nil
|
||||
}
|
||||
|
||||
// it is also possible that user doesn't own the peer but some of his peers have access to it,
|
||||
// this is a valid case, show the peer as well.
|
||||
userPeers, err := account.FindUserPeers(userID)
|
||||
userPeers, err := am.Store.GetUserPeers(ctx, LockingStrengthShare, accountID, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
approvedPeersMap, err := am.GetValidatedPeers(account)
|
||||
approvedPeersMap, err := am.GetValidatedPeers(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(errGetAccountFmt, err)
|
||||
}
|
||||
|
||||
for _, p := range userPeers {
|
||||
aclPeers, _ := account.getPeerConnectionResources(ctx, p.ID, approvedPeersMap)
|
||||
for _, aclPeer := range aclPeers {
|
||||
@ -1024,7 +1001,7 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
|
||||
|
||||
peers := account.GetPeers()
|
||||
|
||||
approvedPeersMap, err := am.GetValidatedPeers(account)
|
||||
approvedPeersMap, err := am.GetValidatedPeers(ctx, account.Id)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to validate peer: %v", err)
|
||||
return
|
||||
@ -1196,6 +1173,23 @@ func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
// GetPeerGroups returns groups that the peer is part of.
|
||||
func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*nbgroup.Group, error) {
|
||||
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peerGroups := make([]*nbgroup.Group, 0)
|
||||
for _, group := range groups {
|
||||
if slices.Contains(group.Peers, peerID) {
|
||||
peerGroups = append(peerGroups, group)
|
||||
}
|
||||
}
|
||||
|
||||
return peerGroups, nil
|
||||
}
|
||||
|
||||
func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
|
||||
labelMap := make(map[string]struct{}, len(existingLabels))
|
||||
for _, label := range existingLabels {
|
||||
|
@ -561,7 +561,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
|
||||
assert.NotNil(t, peer)
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_GetPeers(t *testing.T) {
|
||||
func TestDefaultAccountManager_GetUserPeers(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
role UserRole
|
||||
@ -697,7 +697,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
peers, err := manager.GetPeers(context.Background(), accountID, someUser)
|
||||
peers, err := manager.GetUserPeers(context.Background(), accountID, someUser)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@ -822,9 +822,9 @@ func BenchmarkGetPeers(b *testing.B) {
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := manager.GetPeers(context.Background(), accountID, userID)
|
||||
_, err := manager.GetUserPeers(context.Background(), accountID, userID)
|
||||
if err != nil {
|
||||
b.Fatalf("GetPeers failed: %v", err)
|
||||
b.Fatalf("GetUserPeers failed: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
Loading…
Reference in New Issue
Block a user