diff --git a/management/server/http/peers_handler.go b/management/server/http/peers_handler.go index a5856a0e4..235e744b3 100644 --- a/management/server/http/peers_handler.go +++ b/management/server/http/peers_handler.go @@ -48,8 +48,8 @@ func (h *PeersHandler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) return peerToReturn, nil } -func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, peerID, userID string, w http.ResponseWriter) { - peer, err := h.accountManager.GetPeer(ctx, account.Id, peerID, userID) +func (h *PeersHandler) getPeer(ctx context.Context, accountID, peerID, userID string, w http.ResponseWriter) { + peer, err := h.accountManager.GetPeer(ctx, accountID, peerID, userID) if err != nil { util.WriteError(ctx, err, w) return @@ -62,11 +62,16 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee } dnsDomain := h.accountManager.GetDNSDomain() - groupsInfo := toGroupsInfo(account.Groups, peer.ID) - - validPeers, err := h.accountManager.GetValidatedPeers(account) + peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID) if err != nil { - log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err) + util.WriteError(ctx, err, w) + return + } + groupsInfo := toGroupsInfo(peerGroups) + + validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to list approved peers: %v", err) util.WriteError(ctx, fmt.Errorf("internal error"), w) return } @@ -75,7 +80,7 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid)) } -func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) { +func (h *PeersHandler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) { req := &api.PeerRequest{} err := json.NewDecoder(r.Body).Decode(&req) if err != nil { @@ -99,16 +104,21 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, } } - peer, err := h.accountManager.UpdatePeer(ctx, account.Id, userID, update) + peer, err := h.accountManager.UpdatePeer(ctx, accountID, userID, update) if err != nil { util.WriteError(ctx, err, w) return } dnsDomain := h.accountManager.GetDNSDomain() - groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID) + peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID) + if err != nil { + util.WriteError(ctx, err, w) + return + } + groupMinimumInfo := toGroupsInfo(peerGroups) - validPeers, err := h.accountManager.GetValidatedPeers(account) + validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err) util.WriteError(ctx, fmt.Errorf("internal error"), w) @@ -149,18 +159,11 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) { case http.MethodDelete: h.deletePeer(r.Context(), accountID, userID, peerID, w) return - case http.MethodGet, http.MethodPut: - account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - if r.Method == http.MethodGet { - h.getPeer(r.Context(), account, peerID, userID, w) - } else { - h.updatePeer(r.Context(), account, userID, peerID, w, r) - } + case http.MethodGet: + h.getPeer(r.Context(), accountID, peerID, userID, w) + return + case http.MethodPut: + h.updatePeer(r.Context(), accountID, userID, peerID, w, r) return default: util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w) @@ -176,7 +179,7 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { return } - account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID) + peers, err := h.accountManager.GetPeers(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -184,19 +187,25 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { dnsDomain := h.accountManager.GetDNSDomain() - respBody := make([]*api.PeerBatch, 0, len(account.Peers)) - for _, peer := range account.Peers { + respBody := make([]*api.PeerBatch, 0, len(peers)) + for _, peer := range peers { peerToReturn, err := h.checkPeerStatus(peer) if err != nil { util.WriteError(r.Context(), err, w) return } - groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID) + + peerGroups, err := h.accountManager.GetPeerGroups(r.Context(), accountID, peer.ID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + groupMinimumInfo := toGroupsInfo(peerGroups) respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0)) } - validPeersMap, err := h.accountManager.GetValidatedPeers(account) + validPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) if err != nil { log.WithContext(r.Context()).Errorf("failed to list appreoved peers: %v", err) util.WriteError(r.Context(), fmt.Errorf("internal error"), w) @@ -259,16 +268,16 @@ func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request } } - dnsDomain := h.accountManager.GetDNSDomain() - - validPeers, err := h.accountManager.GetValidatedPeers(account) + validPeers, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) if err != nil { log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err) util.WriteError(r.Context(), fmt.Errorf("internal error"), w) return } - customZone := account.GetPeersCustomZone(r.Context(), h.accountManager.GetDNSDomain()) + dnsDomain := h.accountManager.GetDNSDomain() + + customZone := account.GetPeersCustomZone(r.Context(), dnsDomain) netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, nil) util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain)) @@ -303,26 +312,14 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee } } -func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum { - var groupsInfo []api.GroupMinimum - groupsChecked := make(map[string]struct{}) +func toGroupsInfo(groups []*nbgroup.Group) []api.GroupMinimum { + groupsInfo := make([]api.GroupMinimum, 0, len(groups)) for _, group := range groups { - _, ok := groupsChecked[group.ID] - if ok { - continue - } - groupsChecked[group.ID] = struct{}{} - for _, pk := range group.Peers { - if pk == peerID { - info := api.GroupMinimum{ - Id: group.ID, - Name: group.Name, - PeersCount: len(group.Peers), - } - groupsInfo = append(groupsInfo, info) - break - } - } + groupsInfo = append(groupsInfo, api.GroupMinimum{ + Id: group.ID, + Name: group.Name, + PeersCount: len(group.Peers), + }) } return groupsInfo } diff --git a/management/server/http/peers_handler_test.go b/management/server/http/peers_handler_test.go index dd49c03b8..9279fc536 100644 --- a/management/server/http/peers_handler_test.go +++ b/management/server/http/peers_handler_test.go @@ -39,6 +39,68 @@ const ( ) func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { + + peersMap := make(map[string]*nbpeer.Peer) + for _, peer := range peers { + peersMap[peer.ID] = peer.Copy() + } + + policy := &server.Policy{ + ID: "policy", + AccountID: "test_id", + Name: "policy", + Enabled: true, + Rules: []*server.PolicyRule{ + { + ID: "rule", + Name: "rule", + Enabled: true, + Action: "accept", + Destinations: []string{"group1"}, + Sources: []string{"group1"}, + Bidirectional: true, + Protocol: "all", + Ports: []string{"80"}, + }, + }, + } + + srvUser := server.NewRegularUser(serviceUser) + srvUser.IsServiceUser = true + + account := &server.Account{ + Id: "test_id", + Domain: "hotmail.com", + Peers: peersMap, + Users: map[string]*server.User{ + adminUser: server.NewAdminUser(adminUser), + regularUser: server.NewRegularUser(regularUser), + serviceUser: srvUser, + }, + Groups: map[string]*nbgroup.Group{ + "group1": { + ID: "group1", + AccountID: "test_id", + Name: "group1", + Issued: "api", + Peers: maps.Keys(peersMap), + }, + }, + Settings: &server.Settings{ + PeerLoginExpirationEnabled: true, + PeerLoginExpiration: time.Hour, + }, + Policies: []*server.Policy{policy}, + Network: &server.Network{ + Identifier: "ciclqisab2ss43jdn8q0", + Net: net.IPNet{ + IP: net.ParseIP("100.67.0.0"), + Mask: net.IPv4Mask(255, 255, 0, 0), + }, + Serial: 51, + }, + } + return &PeersHandler{ accountManager: &mock_server.MockAccountManager{ UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { @@ -67,74 +129,31 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { GetPeersFunc: func(_ context.Context, accountID, userID string) ([]*nbpeer.Peer, error) { return peers, nil }, + GetPeerGroupsFunc: func(ctx context.Context, accountID, peerID string) ([]*nbgroup.Group, error) { + peersID := make([]string, len(peers)) + for _, peer := range peers { + peersID = append(peersID, peer.ID) + } + return []*nbgroup.Group{ + { + ID: "group1", + AccountID: accountID, + Name: "group1", + Issued: "api", + Peers: peersID, + }, + }, nil + }, GetDNSDomainFunc: func() string { return "netbird.selfhosted" }, GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return claims.AccountId, claims.UserId, nil }, + GetAccountFunc: func(ctx context.Context, accountID string) (*server.Account, error) { + return account, nil + }, GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) { - peersMap := make(map[string]*nbpeer.Peer) - for _, peer := range peers { - peersMap[peer.ID] = peer.Copy() - } - - policy := &server.Policy{ - ID: "policy", - AccountID: accountID, - Name: "policy", - Enabled: true, - Rules: []*server.PolicyRule{ - { - ID: "rule", - Name: "rule", - Enabled: true, - Action: "accept", - Destinations: []string{"group1"}, - Sources: []string{"group1"}, - Bidirectional: true, - Protocol: "all", - Ports: []string{"80"}, - }, - }, - } - - srvUser := server.NewRegularUser(serviceUser) - srvUser.IsServiceUser = true - - account := &server.Account{ - Id: accountID, - Domain: "hotmail.com", - Peers: peersMap, - Users: map[string]*server.User{ - adminUser: server.NewAdminUser(adminUser), - regularUser: server.NewRegularUser(regularUser), - serviceUser: srvUser, - }, - Groups: map[string]*nbgroup.Group{ - "group1": { - ID: "group1", - AccountID: accountID, - Name: "group1", - Issued: "api", - Peers: maps.Keys(peersMap), - }, - }, - Settings: &server.Settings{ - PeerLoginExpirationEnabled: true, - PeerLoginExpiration: time.Hour, - }, - Policies: []*server.Policy{policy}, - Network: &server.Network{ - Identifier: "ciclqisab2ss43jdn8q0", - Net: net.IPNet{ - IP: net.ParseIP("100.67.0.0"), - Mask: net.IPv4Mask(255, 255, 0, 0), - }, - Serial: 51, - }, - } - return account, nil }, HasConnectedChannelFunc: func(peerID string) bool {