mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-22 08:03:30 +01:00
[management] Restrict accessible peers to user-owned peers for non-admins (#2618)
* Restrict accessible peers to user-owned peers for non-admin users Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * add tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * add service user test Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * reuse account from token Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * return error when peer not found Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> --------- Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
parent
fc4b37f7bc
commit
35c892aea3
@ -7,8 +7,6 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
@ -16,6 +14,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// PeersHandler is a handler that returns peers of the account
|
||||
@ -215,7 +214,7 @@ func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approv
|
||||
// GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network.
|
||||
func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, _, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -228,6 +227,21 @@ func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request
|
||||
return
|
||||
}
|
||||
|
||||
// If the user is regular user and does not own the peer
|
||||
// with the given peerID return an empty list
|
||||
if !user.HasAdminPower() && !user.IsServiceUser {
|
||||
peer, ok := account.Peers[peerID]
|
||||
if !ok {
|
||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "peer not found"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if peer.UserID != user.Id {
|
||||
util.WriteJSONObject(r.Context(), w, []api.AccessiblePeer{})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
dnsDomain := h.accountManager.GetDNSDomain()
|
||||
|
||||
validPeers, err := h.accountManager.GetValidatedPeers(account)
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
@ -12,20 +13,30 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
|
||||
"github.com/magiconair/properties/assert"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
)
|
||||
|
||||
const testPeerID = "test_peer"
|
||||
const noUpdateChannelTestPeerID = "no-update-channel"
|
||||
type ctxKey string
|
||||
|
||||
const (
|
||||
testPeerID = "test_peer"
|
||||
noUpdateChannelTestPeerID = "no-update-channel"
|
||||
|
||||
adminUser = "admin_user"
|
||||
regularUser = "regular_user"
|
||||
serviceUser = "service_user"
|
||||
userIDKey ctxKey = "user_id"
|
||||
)
|
||||
|
||||
func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
||||
return &PeersHandler{
|
||||
@ -60,21 +71,57 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
||||
return "netbird.selfhosted"
|
||||
},
|
||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
user := server.NewAdminUser("test_user")
|
||||
return &server.Account{
|
||||
peersMap := make(map[string]*nbpeer.Peer)
|
||||
for _, peer := range peers {
|
||||
peersMap[peer.ID] = peer.Copy()
|
||||
}
|
||||
|
||||
policy := &server.Policy{
|
||||
ID: "policy",
|
||||
AccountID: claims.AccountId,
|
||||
Name: "policy",
|
||||
Enabled: true,
|
||||
Rules: []*server.PolicyRule{
|
||||
{
|
||||
ID: "rule",
|
||||
Name: "rule",
|
||||
Enabled: true,
|
||||
Action: "accept",
|
||||
Destinations: []string{"group1"},
|
||||
Sources: []string{"group1"},
|
||||
Bidirectional: true,
|
||||
Protocol: "all",
|
||||
Ports: []string{"80"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
srvUser := server.NewRegularUser(serviceUser)
|
||||
srvUser.IsServiceUser = true
|
||||
|
||||
account := &server.Account{
|
||||
Id: claims.AccountId,
|
||||
Domain: "hotmail.com",
|
||||
Peers: map[string]*nbpeer.Peer{
|
||||
peers[0].ID: peers[0],
|
||||
peers[1].ID: peers[1],
|
||||
},
|
||||
Peers: peersMap,
|
||||
Users: map[string]*server.User{
|
||||
"test_user": user,
|
||||
adminUser: server.NewAdminUser(adminUser),
|
||||
regularUser: server.NewRegularUser(regularUser),
|
||||
serviceUser: srvUser,
|
||||
},
|
||||
Groups: map[string]*nbgroup.Group{
|
||||
"group1": {
|
||||
ID: "group1",
|
||||
AccountID: claims.AccountId,
|
||||
Name: "group1",
|
||||
Issued: "api",
|
||||
Peers: maps.Keys(peersMap),
|
||||
},
|
||||
},
|
||||
Settings: &server.Settings{
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerLoginExpiration: time.Hour,
|
||||
},
|
||||
Policies: []*server.Policy{policy},
|
||||
Network: &server.Network{
|
||||
Identifier: "ciclqisab2ss43jdn8q0",
|
||||
Net: net.IPNet{
|
||||
@ -83,7 +130,9 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
||||
},
|
||||
Serial: 51,
|
||||
},
|
||||
}, user, nil
|
||||
}
|
||||
|
||||
return account, account.Users[claims.UserId], nil
|
||||
},
|
||||
HasConnectedChannelFunc: func(peerID string) bool {
|
||||
statuses := make(map[string]struct{})
|
||||
@ -99,8 +148,9 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
||||
},
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||
userID := r.Context().Value(userIDKey).(string)
|
||||
return jwtclaims.AuthorizationClaims{
|
||||
UserId: "test_user",
|
||||
UserId: userID,
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
}
|
||||
@ -197,6 +247,8 @@ func TestGetPeers(t *testing.T) {
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
ctx := context.WithValue(context.Background(), userIDKey, "admin_user")
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/peers/", p.GetAllPeers).Methods("GET")
|
||||
@ -251,3 +303,119 @@ func TestGetPeers(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAccessiblePeers(t *testing.T) {
|
||||
peer1 := &nbpeer.Peer{
|
||||
ID: "peer1",
|
||||
Key: "key1",
|
||||
IP: net.ParseIP("100.64.0.1"),
|
||||
Status: &nbpeer.PeerStatus{Connected: true},
|
||||
Name: "peer1",
|
||||
LoginExpirationEnabled: false,
|
||||
UserID: regularUser,
|
||||
}
|
||||
|
||||
peer2 := &nbpeer.Peer{
|
||||
ID: "peer2",
|
||||
Key: "key2",
|
||||
IP: net.ParseIP("100.64.0.2"),
|
||||
Status: &nbpeer.PeerStatus{Connected: true},
|
||||
Name: "peer2",
|
||||
LoginExpirationEnabled: false,
|
||||
UserID: adminUser,
|
||||
}
|
||||
|
||||
peer3 := &nbpeer.Peer{
|
||||
ID: "peer3",
|
||||
Key: "key3",
|
||||
IP: net.ParseIP("100.64.0.3"),
|
||||
Status: &nbpeer.PeerStatus{Connected: true},
|
||||
Name: "peer3",
|
||||
LoginExpirationEnabled: false,
|
||||
UserID: regularUser,
|
||||
}
|
||||
|
||||
p := initTestMetaData(peer1, peer2, peer3)
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
peerID string
|
||||
callerUserID string
|
||||
expectedStatus int
|
||||
expectedPeers []string
|
||||
}{
|
||||
{
|
||||
name: "non admin user can access owned peer",
|
||||
peerID: "peer1",
|
||||
callerUserID: regularUser,
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedPeers: []string{"peer2", "peer3"},
|
||||
},
|
||||
{
|
||||
name: "non admin user can't access unowned peer",
|
||||
peerID: "peer2",
|
||||
callerUserID: regularUser,
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedPeers: []string{},
|
||||
},
|
||||
{
|
||||
name: "admin user can access owned peer",
|
||||
peerID: "peer2",
|
||||
callerUserID: adminUser,
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedPeers: []string{"peer1", "peer3"},
|
||||
},
|
||||
{
|
||||
name: "admin user can access unowned peer",
|
||||
peerID: "peer3",
|
||||
callerUserID: adminUser,
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedPeers: []string{"peer1", "peer2"},
|
||||
},
|
||||
{
|
||||
name: "service user can access unowned peer",
|
||||
peerID: "peer3",
|
||||
callerUserID: serviceUser,
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedPeers: []string{"peer1", "peer2"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/peers/%s/accessible-peers", tc.peerID), nil)
|
||||
ctx := context.WithValue(context.Background(), userIDKey, tc.callerUserID)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/peers/{peerId}/accessible-peers", p.GetAccessiblePeers).Methods("GET")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
if res.StatusCode != tc.expectedStatus {
|
||||
t.Fatalf("handler returned wrong status code: got %v want %v", res.StatusCode, tc.expectedStatus)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read response body: %v", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
var accessiblePeers []api.AccessiblePeer
|
||||
err = json.Unmarshal(body, &accessiblePeers)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
peerIDs := make([]string, len(accessiblePeers))
|
||||
for i, peer := range accessiblePeers {
|
||||
peerIDs[i] = peer.Id
|
||||
}
|
||||
|
||||
assert.ElementsMatch(t, peerIDs, tc.expectedPeers)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user