1
0
mirror of https://github.com/netbirdio/netbird.git synced 2025-03-10 04:38:21 +01:00

wip: refactor get account in peers

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga 2024-10-28 17:47:54 +03:00
parent 313e158e20
commit 9bf0bf4843
No known key found for this signature in database
GPG Key ID: 511EED5C928AD547
7 changed files with 176 additions and 143 deletions

View File

@ -94,7 +94,8 @@ type AccountManager interface {
GetUserByID(ctx context.Context, id string) (*User, error) GetUserByID(ctx context.Context, id string) (*User, error)
GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error)
ListUsers(ctx context.Context, accountID string) ([]*User, error) ListUsers(ctx context.Context, accountID string) ([]*User, error)
GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) GetUserPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
ListPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error
DeletePeer(ctx context.Context, accountID, peerID, userID string) error DeletePeer(ctx context.Context, accountID, peerID, userID string) error
UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
@ -105,7 +106,6 @@ type AccountManager interface {
DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error
GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error)
GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error)
UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error
GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*UserInfo, error) GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*UserInfo, error)
GetGroup(ctx context.Context, accountId, groupID, userID string) (*nbgroup.Group, error) GetGroup(ctx context.Context, accountId, groupID, userID string) (*nbgroup.Group, error)
GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error)
@ -116,6 +116,7 @@ type AccountManager interface {
DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*nbgroup.Group, error)
GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error)
SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error
DeletePolicy(ctx context.Context, accountID, policyID, userID string) error DeletePolicy(ctx context.Context, accountID, policyID, userID string) error
@ -149,7 +150,7 @@ type AccountManager interface {
GetIdpManager() idp.Manager GetIdpManager() idp.Manager
UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
GetValidatedPeers(account *Account) (map[string]struct{}, error) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error)
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error)
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error

View File

@ -49,7 +49,7 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
return return
} }
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID) accountPeers, err := h.accountManager.GetUserPeers(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -132,7 +132,7 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
return return
} }
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID) accountPeers, err := h.accountManager.GetUserPeers(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -180,7 +180,7 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
return return
} }
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID) accountPeers, err := h.accountManager.GetUserPeers(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -238,7 +238,7 @@ func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) {
return return
} }
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID) accountPeers, err := h.accountManager.GetUserPeers(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return

View File

@ -48,8 +48,8 @@ func (h *PeersHandler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error)
return peerToReturn, nil return peerToReturn, nil
} }
func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, peerID, userID string, w http.ResponseWriter) { func (h *PeersHandler) getPeer(ctx context.Context, accountID, peerID, userID string, w http.ResponseWriter) {
peer, err := h.accountManager.GetPeer(ctx, account.Id, peerID, userID) peer, err := h.accountManager.GetPeer(ctx, accountID, peerID, userID)
if err != nil { if err != nil {
util.WriteError(ctx, err, w) util.WriteError(ctx, err, w)
return return
@ -62,11 +62,16 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee
} }
dnsDomain := h.accountManager.GetDNSDomain() dnsDomain := h.accountManager.GetDNSDomain()
groupsInfo := toGroupsInfo(account.Groups, peer.ID) peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID)
validPeers, err := h.accountManager.GetValidatedPeers(account)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err) util.WriteError(ctx, err, w)
return
}
groupsInfo := toGroupsInfo(peerGroups)
validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to list approved peers: %v", err)
util.WriteError(ctx, fmt.Errorf("internal error"), w) util.WriteError(ctx, fmt.Errorf("internal error"), w)
return return
} }
@ -75,7 +80,7 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid)) util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid))
} }
func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) { func (h *PeersHandler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) {
req := &api.PeerRequest{} req := &api.PeerRequest{}
err := json.NewDecoder(r.Body).Decode(&req) err := json.NewDecoder(r.Body).Decode(&req)
if err != nil { if err != nil {
@ -99,16 +104,21 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account,
} }
} }
peer, err := h.accountManager.UpdatePeer(ctx, account.Id, userID, update) peer, err := h.accountManager.UpdatePeer(ctx, accountID, userID, update)
if err != nil { if err != nil {
util.WriteError(ctx, err, w) util.WriteError(ctx, err, w)
return return
} }
dnsDomain := h.accountManager.GetDNSDomain() dnsDomain := h.accountManager.GetDNSDomain()
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID) peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID)
if err != nil {
util.WriteError(ctx, err, w)
return
}
groupMinimumInfo := toGroupsInfo(peerGroups)
validPeers, err := h.accountManager.GetValidatedPeers(account) validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err) log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err)
util.WriteError(ctx, fmt.Errorf("internal error"), w) util.WriteError(ctx, fmt.Errorf("internal error"), w)
@ -149,18 +159,11 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
case http.MethodDelete: case http.MethodDelete:
h.deletePeer(r.Context(), accountID, userID, peerID, w) h.deletePeer(r.Context(), accountID, userID, peerID, w)
return return
case http.MethodGet, http.MethodPut: case http.MethodGet:
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID) h.getPeer(r.Context(), accountID, peerID, userID, w)
if err != nil {
util.WriteError(r.Context(), err, w)
return return
} case http.MethodPut:
h.updatePeer(r.Context(), accountID, userID, peerID, w, r)
if r.Method == http.MethodGet {
h.getPeer(r.Context(), account, peerID, userID, w)
} else {
h.updatePeer(r.Context(), account, userID, peerID, w, r)
}
return return
default: default:
util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w) util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
@ -176,7 +179,7 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
return return
} }
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID) peers, err := h.accountManager.ListPeers(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -184,19 +187,25 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
dnsDomain := h.accountManager.GetDNSDomain() dnsDomain := h.accountManager.GetDNSDomain()
respBody := make([]*api.PeerBatch, 0, len(account.Peers)) respBody := make([]*api.PeerBatch, 0, len(peers))
for _, peer := range account.Peers { for _, peer := range peers {
peerToReturn, err := h.checkPeerStatus(peer) peerToReturn, err := h.checkPeerStatus(peer)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID)
peerGroups, err := h.accountManager.GetPeerGroups(r.Context(), accountID, peer.ID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
groupMinimumInfo := toGroupsInfo(peerGroups)
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0)) respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0))
} }
validPeersMap, err := h.accountManager.GetValidatedPeers(account) validPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
if err != nil { if err != nil {
log.WithContext(r.Context()).Errorf("failed to list appreoved peers: %v", err) log.WithContext(r.Context()).Errorf("failed to list appreoved peers: %v", err)
util.WriteError(r.Context(), fmt.Errorf("internal error"), w) util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
@ -259,16 +268,16 @@ func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request
} }
} }
dnsDomain := h.accountManager.GetDNSDomain() validPeers, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
validPeers, err := h.accountManager.GetValidatedPeers(account)
if err != nil { if err != nil {
log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err) log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err)
util.WriteError(r.Context(), fmt.Errorf("internal error"), w) util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
return return
} }
customZone := account.GetPeersCustomZone(r.Context(), h.accountManager.GetDNSDomain()) dnsDomain := h.accountManager.GetDNSDomain()
customZone := account.GetPeersCustomZone(r.Context(), dnsDomain)
netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, nil) netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, nil)
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain)) util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
@ -303,26 +312,14 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee
} }
} }
func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum { func toGroupsInfo(groups []*nbgroup.Group) []api.GroupMinimum {
var groupsInfo []api.GroupMinimum groupsInfo := make([]api.GroupMinimum, 0, len(groups))
groupsChecked := make(map[string]struct{})
for _, group := range groups { for _, group := range groups {
_, ok := groupsChecked[group.ID] groupsInfo = append(groupsInfo, api.GroupMinimum{
if ok {
continue
}
groupsChecked[group.ID] = struct{}{}
for _, pk := range group.Peers {
if pk == peerID {
info := api.GroupMinimum{
Id: group.ID, Id: group.ID,
Name: group.Name, Name: group.Name,
PeersCount: len(group.Peers), PeersCount: len(group.Peers),
} })
groupsInfo = append(groupsInfo, info)
break
}
}
} }
return groupsInfo return groupsInfo
} }

View File

@ -4,6 +4,8 @@ import (
"context" "context"
"errors" "errors"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
@ -78,6 +80,31 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountId
return true, nil return true, nil
} }
func (am *DefaultAccountManager) GetValidatedPeers(account *Account) (map[string]struct{}, error) { func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) {
return am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra) groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
groupsMap := make(map[string]*nbgroup.Group, len(groups))
for _, group := range groups {
groupsMap[group.ID] = group
}
peers, err := am.Store.GetAccountPeers(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
peersMap := make(map[string]*nbpeer.Peer, len(peers))
for _, peer := range peers {
peersMap[peer.ID] = peer
}
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
return am.integratedPeerValidator.GetValidatedPeers(accountID, groupsMap, peersMap, settings.Extra)
} }

View File

@ -31,7 +31,8 @@ type MockAccountManager struct {
GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error) GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error)
GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error) ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error)
GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) GetUserPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
ListPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
@ -47,6 +48,7 @@ type MockAccountManager struct {
DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error
GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
GetPeerGroupsFunc func(ctx context.Context, accountID, peerID string) ([]*group.Group, error)
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error) GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error)
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error
@ -56,7 +58,6 @@ type MockAccountManager struct {
GetAccountFromPATFunc func(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) GetAccountFromPATFunc func(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
MarkPATUsedFunc func(ctx context.Context, pat string) error MarkPATUsedFunc func(ctx context.Context, pat string) error
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
UpdatePeerSSHKeyFunc func(ctx context.Context, peerID string, sshKey string) error
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
@ -123,7 +124,12 @@ func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID st
panic("implement me") panic("implement me")
} }
func (am *MockAccountManager) GetValidatedPeers(account *server.Account) (map[string]struct{}, error) { func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) {
account, err := am.GetAccountFunc(ctx, accountID)
if err != nil {
return nil, err
}
approvedPeers := make(map[string]struct{}) approvedPeers := make(map[string]struct{})
for id := range account.Peers { for id := range account.Peers {
approvedPeers[id] = struct{}{} approvedPeers[id] = struct{}{}
@ -425,14 +431,6 @@ func (am *MockAccountManager) ListUsers(ctx context.Context, accountID string) (
return nil, status.Errorf(codes.Unimplemented, "method ListUsers is not implemented") return nil, status.Errorf(codes.Unimplemented, "method ListUsers is not implemented")
} }
// UpdatePeerSSHKey mocks UpdatePeerSSHKey function of the account manager
func (am *MockAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error {
if am.UpdatePeerSSHKeyFunc != nil {
return am.UpdatePeerSSHKeyFunc(ctx, peerID, sshKey)
}
return status.Errorf(codes.Unimplemented, "method UpdatePeerSSHKey is not implemented")
}
// UpdatePeer mocks UpdatePeerFunc function of the account manager // UpdatePeer mocks UpdatePeerFunc function of the account manager
func (am *MockAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) { func (am *MockAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) {
if am.UpdatePeerFunc != nil { if am.UpdatePeerFunc != nil {
@ -618,12 +616,12 @@ func (am *MockAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, cl
return status.Errorf(codes.Unimplemented, "method CheckUserAccessByJWTGroups is not implemented") return status.Errorf(codes.Unimplemented, "method CheckUserAccessByJWTGroups is not implemented")
} }
// GetPeers mocks GetPeers of the AccountManager interface // GetUserPeers mocks GetUserPeers of the AccountManager interface
func (am *MockAccountManager) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) { func (am *MockAccountManager) GetUserPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
if am.GetPeersFunc != nil { if am.GetUserPeersFunc != nil {
return am.GetPeersFunc(ctx, accountID, userID) return am.GetUserPeersFunc(ctx, accountID, userID)
} }
return nil, status.Errorf(codes.Unimplemented, "method GetPeers is not implemented") return nil, status.Errorf(codes.Unimplemented, "method GetUserPeers is not implemented")
} }
// GetDNSDomain mocks GetDNSDomain of the AccountManager interface // GetDNSDomain mocks GetDNSDomain of the AccountManager interface
@ -832,3 +830,19 @@ func (am *MockAccountManager) GetAccount(ctx context.Context, accountID string)
} }
return nil, status.Errorf(codes.Unimplemented, "method GetAccount is not implemented") return nil, status.Errorf(codes.Unimplemented, "method GetAccount is not implemented")
} }
// GetPeerGroups mocks GetPeerGroups of the AccountManager interface
func (am *MockAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*group.Group, error) {
if am.GetPeerGroupsFunc != nil {
return am.GetPeerGroupsFunc(ctx, accountID, peerID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetPeerGroups is not implemented")
}
// ListPeers mocks ListPeers of the AccountManager interface
func (am *MockAccountManager) ListPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
if am.ListPeersFunc != nil {
return am.ListPeersFunc(ctx, accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method ListPeers is not implemented")
}

View File

@ -4,10 +4,12 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"slices"
"strings" "strings"
"sync" "sync"
"time" "time"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/rs/xid" "github.com/rs/xid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -47,9 +49,23 @@ type PeerLogin struct {
ConnectionIP net.IP ConnectionIP net.IP
} }
// GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if // ListPeers returns a list of peers under the given account.
func (am *DefaultAccountManager) ListPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
if user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data")
}
return am.Store.GetAccountPeers(ctx, LockingStrengthShare, accountID)
}
// GetUserPeers returns a list of peers under the given account filtering out peers that do not belong to a user if
// the current user is not an admin. // the current user is not an admin.
func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) { func (am *DefaultAccountManager) GetUserPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
@ -60,7 +76,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
return nil, err return nil, err
} }
approvedPeersMap, err := am.GetValidatedPeers(account) approvedPeersMap, err := am.GetValidatedPeers(ctx, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -585,7 +601,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account)
approvedPeersMap, err := am.GetValidatedPeers(account) approvedPeersMap, err := am.GetValidatedPeers(ctx, accountID)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
@ -672,7 +688,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account)
} }
validPeersMap, err := am.GetValidatedPeers(account) validPeersMap, err := am.GetValidatedPeers(ctx, account.Id)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
@ -847,7 +863,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
return peer, emptyMap, nil, nil return peer, emptyMap, nil, nil
} }
approvedPeersMap, err := am.GetValidatedPeers(account) approvedPeersMap, err := am.GetValidatedPeers(ctx, account.Id)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
@ -914,92 +930,53 @@ func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *Settings
return false return false
} }
// UpdatePeerSSHKey updates peer's public SSH key
func (am *DefaultAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error {
if sshKey == "" {
log.WithContext(ctx).Debugf("empty SSH key provided for peer %s, skipping update", peerID)
return nil
}
account, err := am.Store.GetAccountByPeerID(ctx, peerID)
if err != nil {
return err
}
unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id)
defer unlock()
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
account, err = am.Store.GetAccount(ctx, account.Id)
if err != nil {
return err
}
peer := account.GetPeer(peerID)
if peer == nil {
return status.Errorf(status.NotFound, "peer with ID %s not found", peerID)
}
if peer.SSHKey == sshKey {
log.WithContext(ctx).Debugf("same SSH key provided for peer %s, skipping update", peerID)
return nil
}
peer.SSHKey = sshKey
account.UpdatePeer(peer)
err = am.Store.SaveAccount(ctx, account)
if err != nil {
return err
}
// trigger network map update
am.updateAccountPeers(ctx, account)
return nil
}
// GetPeer for a given accountID, peerID and userID error if not found. // GetPeer for a given accountID, peerID and userID error if not found.
func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) { func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
user, err := account.FindUser(userID) if user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg)
}
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked { if !user.IsAdminOrServiceUser() && settings.RegularUsersViewBlocked {
return nil, status.Errorf(status.Internal, "user %s has no access to his own peer %s under account %s", userID, peerID, accountID) return nil, status.Errorf(status.Internal, "user %s has no access to his own peer %s under account %s", userID, peerID, accountID)
} }
peer := account.GetPeer(peerID) peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
if peer == nil { if err != nil {
return nil, status.Errorf(status.NotFound, "peer with %s not found under account %s", peerID, accountID) return nil, err
} }
// if admin or user owns this peer, return peer // if admin or user owns this peer, return peer
if user.HasAdminPower() || user.IsServiceUser || peer.UserID == userID { if user.IsAdminOrServiceUser() || peer.UserID == userID {
return peer, nil return peer, nil
} }
// it is also possible that user doesn't own the peer but some of his peers have access to it, // it is also possible that user doesn't own the peer but some of his peers have access to it,
// this is a valid case, show the peer as well. // this is a valid case, show the peer as well.
userPeers, err := account.FindUserPeers(userID) userPeers, err := am.Store.GetUserPeers(ctx, LockingStrengthShare, accountID, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
approvedPeersMap, err := am.GetValidatedPeers(account) approvedPeersMap, err := am.GetValidatedPeers(ctx, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, fmt.Errorf(errGetAccountFmt, err)
}
for _, p := range userPeers { for _, p := range userPeers {
aclPeers, _ := account.getPeerConnectionResources(ctx, p.ID, approvedPeersMap) aclPeers, _ := account.getPeerConnectionResources(ctx, p.ID, approvedPeersMap)
for _, aclPeer := range aclPeers { for _, aclPeer := range aclPeers {
@ -1024,7 +1001,7 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
peers := account.GetPeers() peers := account.GetPeers()
approvedPeersMap, err := am.GetValidatedPeers(account) approvedPeersMap, err := am.GetValidatedPeers(ctx, account.Id)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to validate peer: %v", err) log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to validate peer: %v", err)
return return
@ -1196,6 +1173,23 @@ func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID
return peers, nil return peers, nil
} }
// GetPeerGroups returns groups that the peer is part of.
func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*nbgroup.Group, error) {
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
peerGroups := make([]*nbgroup.Group, 0)
for _, group := range groups {
if slices.Contains(group.Peers, peerID) {
peerGroups = append(peerGroups, group)
}
}
return peerGroups, nil
}
func ConvertSliceToMap(existingLabels []string) map[string]struct{} { func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
labelMap := make(map[string]struct{}, len(existingLabels)) labelMap := make(map[string]struct{}, len(existingLabels))
for _, label := range existingLabels { for _, label := range existingLabels {

View File

@ -561,7 +561,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
assert.NotNil(t, peer) assert.NotNil(t, peer)
} }
func TestDefaultAccountManager_GetPeers(t *testing.T) { func TestDefaultAccountManager_GetUserPeers(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
role UserRole role UserRole
@ -697,7 +697,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
return return
} }
peers, err := manager.GetPeers(context.Background(), accountID, someUser) peers, err := manager.GetUserPeers(context.Background(), accountID, someUser)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@ -822,9 +822,9 @@ func BenchmarkGetPeers(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
_, err := manager.GetPeers(context.Background(), accountID, userID) _, err := manager.GetUserPeers(context.Background(), accountID, userID)
if err != nil { if err != nil {
b.Fatalf("GetPeers failed: %v", err) b.Fatalf("GetUserPeers failed: %v", err)
} }
} }
}) })