Hide content based on user role (#541)

This commit is contained in:
Misha Bragin 2022-11-05 10:24:50 +01:00 committed by GitHub
parent e8d82c1bd3
commit 4321b71984
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 305 additions and 142 deletions

View File

@ -47,15 +47,16 @@ type AccountManager interface {
) (*SetupKey, error)
SaveSetupKey(accountID string, key *SetupKey) (*SetupKey, error)
CreateUser(accountID string, key *UserInfo) (*UserInfo, error)
ListSetupKeys(accountID string) ([]*SetupKey, error)
ListSetupKeys(accountID, userID string) ([]*SetupKey, error)
SaveUser(accountID string, key *User) (*UserInfo, error)
GetSetupKey(accountID, keyID string) (*SetupKey, error)
GetSetupKey(accountID, userID, keyID string) (*SetupKey, error)
GetAccountById(accountId string) (*Account, error)
GetAccountByUserOrAccountId(userId, accountId, domain string) (*Account, error)
GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, error)
IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error)
AccountExists(accountId string) (*bool, error)
GetPeer(peerKey string) (*Peer, error)
GetPeers(accountID, userID string) ([]*Peer, error)
MarkPeerConnected(peerKey string, connected bool) error
RenamePeer(accountId string, peerKey string, newName string) (*Peer, error)
DeletePeer(accountId string, peerKey string) (*Peer, error)
@ -66,7 +67,7 @@ type AccountManager interface {
AddPeer(setupKey string, userId string, peer *Peer) (*Peer, error)
UpdatePeerMeta(peerKey string, meta PeerSystemMeta) error
UpdatePeerSSHKey(peerKey string, sshKey string) error
GetUsersFromAccount(accountId string) ([]*UserInfo, error)
GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error)
GetGroup(accountId, groupID string) (*Group, error)
SaveGroup(accountId string, group *Group) error
UpdateGroup(accountID string, groupID string, operations []GroupUpdateOperation) (*Group, error)
@ -75,17 +76,17 @@ type AccountManager interface {
GroupAddPeer(accountId, groupID, peerKey string) error
GroupDeletePeer(accountId, groupID, peerKey string) error
GroupListPeers(accountId, groupID string) ([]*Peer, error)
GetRule(accountId, ruleID string) (*Rule, error)
GetRule(accountID, ruleID, userID string) (*Rule, error)
SaveRule(accountID string, rule *Rule) error
UpdateRule(accountID string, ruleID string, operations []RuleUpdateOperation) (*Rule, error)
DeleteRule(accountId, ruleID string) error
ListRules(accountId string) ([]*Rule, error)
GetRoute(accountID, routeID string) (*route.Route, error)
ListRules(accountID, userID string) ([]*Rule, error)
GetRoute(accountID, routeID, userID string) (*route.Route, error)
CreateRoute(accountID string, prefix, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error)
SaveRoute(accountID string, route *route.Route) error
UpdateRoute(accountID string, routeID string, operations []RouteUpdateOperation) (*route.Route, error)
DeleteRoute(accountID, routeID string) error
ListRoutes(accountID string) ([]*route.Route, error)
ListRoutes(accountID, userID string) ([]*route.Route, error)
GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error)
CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool) (*nbdns.NameServerGroup, error)
SaveNameServerGroup(accountID string, nsGroupToSave *nbdns.NameServerGroup) error
@ -142,6 +143,16 @@ type UserInfo struct {
Status string `json:"-"`
}
// FindUser looks for a given user in the Account or returns error if user wasn't found.
func (a *Account) FindUser(userID string) (*User, error) {
user := a.Users[userID]
if user == nil {
return nil, Errorf(UserNotFound, "user %s not found", userID)
}
return user, nil
}
func (a *Account) Copy() *Account {
peers := map[string]*Peer{}
for id, peer := range a.Peers {

View File

@ -874,7 +874,7 @@ func TestGetUsersFromAccount(t *testing.T) {
account.Users[user.Id] = user
}
userInfos, err := manager.GetUsersFromAccount(accountId)
userInfos, err := manager.GetUsersFromAccount(accountId, "1")
if err != nil {
t.Fatal(err)
}

View File

@ -11,6 +11,12 @@ const (
AccountNotFound ErrorType = iota
// PreconditionFailed indicates that some pre-condition for the operation hasn't been fulfilled
PreconditionFailed ErrorType = iota
// UserNotFound indicates that user wasn't found in the system (or under a given Account)
UserNotFound ErrorType = iota
// PermissionDenied indicates that user has no permissions to view data
PermissionDenied ErrorType = iota
)
// ErrorType is a type of the Error

View File

@ -33,7 +33,7 @@ func NewGroups(accountManager server.AccountManager, authAudience string) *Group
// GetAllGroupsHandler list for the account
func (h *Groups) GetAllGroupsHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
@ -50,7 +50,7 @@ func (h *Groups) GetAllGroupsHandler(w http.ResponseWriter, r *http.Request) {
// UpdateGroupHandler handles update to a group identified by a given ID
func (h *Groups) UpdateGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@ -111,7 +111,7 @@ func (h *Groups) UpdateGroupHandler(w http.ResponseWriter, r *http.Request) {
// PatchGroupHandler handles patch updates to a group identified by a given ID
func (h *Groups) PatchGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@ -226,7 +226,7 @@ func (h *Groups) PatchGroupHandler(w http.ResponseWriter, r *http.Request) {
// CreateGroupHandler handles group creation request
func (h *Groups) CreateGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@ -260,7 +260,7 @@ func (h *Groups) CreateGroupHandler(w http.ResponseWriter, r *http.Request) {
// DeleteGroupHandler handles group deletion request
func (h *Groups) DeleteGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@ -295,7 +295,7 @@ func (h *Groups) DeleteGroupHandler(w http.ResponseWriter, r *http.Request) {
// GetGroupHandler returns a group
func (h *Groups) GetGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return

View File

@ -25,7 +25,7 @@ var TestPeers = map[string]*server.Peer{
"B": &server.Peer{Key: "B", IP: net.ParseIP("200.200.200.200")},
}
func initGroupTestData(groups ...*server.Group) *Groups {
func initGroupTestData(user *server.User, groups ...*server.Group) *Groups {
return &Groups{
accountManager: &mock_server.MockAccountManager{
SaveGroupFunc: func(accountID string, group *server.Group) error {
@ -72,6 +72,9 @@ func initGroupTestData(groups ...*server.Group) *Groups {
Id: claims.AccountId,
Domain: "hotmail.com",
Peers: TestPeers,
Users: map[string]*server.User{
user.Id: user,
},
Groups: map[string]*server.Group{
"id-existed": {ID: "id-existed", Peers: []string{"A", "B"}},
"id-all": {ID: "id-all", Name: "All"}},
@ -120,7 +123,8 @@ func TestGetGroup(t *testing.T) {
Name: "Group",
}
p := initGroupTestData(group)
adminUser := server.NewAdminUser("test_user")
p := initGroupTestData(adminUser, group)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
@ -270,7 +274,8 @@ func TestWriteGroup(t *testing.T) {
},
}
p := initGroupTestData()
adminUser := server.NewAdminUser("test_user")
p := initGroupTestData(adminUser)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {

View File

@ -1,17 +1,12 @@
package middleware
import (
"context"
"fmt"
"net/http"
"github.com/netbirdio/netbird/management/server/jwtclaims"
)
const (
IsUserAdminProperty = "isAdminUser"
)
type IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error)
// AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only
@ -50,10 +45,6 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
}
}
newRequest := r.Clone(context.WithValue(r.Context(), IsUserAdminProperty, ok)) //nolint
// Update the current request with the new context information.
*r = *newRequest
h.ServeHTTP(w, r)
})
}

View File

@ -30,7 +30,7 @@ func NewNameservers(accountManager server.AccountManager, authAudience string) *
// GetAllNameserversHandler returns the list of nameserver groups for the account
func (h *Nameservers) GetAllNameserversHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
@ -53,7 +53,7 @@ func (h *Nameservers) GetAllNameserversHandler(w http.ResponseWriter, r *http.Re
// CreateNameserverGroupHandler handles nameserver group creation request
func (h *Nameservers) CreateNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
@ -84,7 +84,7 @@ func (h *Nameservers) CreateNameserverGroupHandler(w http.ResponseWriter, r *htt
// UpdateNameserverGroupHandler handles update to a nameserver group identified by a given ID
func (h *Nameservers) UpdateNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
@ -131,7 +131,7 @@ func (h *Nameservers) UpdateNameserverGroupHandler(w http.ResponseWriter, r *htt
// PatchNameserverGroupHandler handles patch updates to a nameserver group identified by a given ID
func (h *Nameservers) PatchNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
@ -202,7 +202,7 @@ func (h *Nameservers) PatchNameserverGroupHandler(w http.ResponseWriter, r *http
// DeleteNameserverGroupHandler handles nameserver group deletion request
func (h *Nameservers) DeleteNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
@ -226,7 +226,7 @@ func (h *Nameservers) DeleteNameserverGroupHandler(w http.ResponseWriter, r *htt
// GetNameserverGroupHandler handles a nameserver group Get request identified by ID
func (h *Nameservers) GetNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)

View File

@ -29,6 +29,9 @@ const (
var testingNSAccount = &server.Account{
Id: testNSGroupAccountID,
Domain: "hotmail.com",
Users: map[string]*server.User{
"test_user": server.NewAdminUser("test_user"),
},
}
var baseExistingNSGroup = &nbdns.NameServerGroup{

View File

@ -56,7 +56,7 @@ func (h *Peers) deletePeer(accountId string, peer *server.Peer, w http.ResponseW
}
func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
@ -95,15 +95,20 @@ func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
func (h *Peers) GetPeers(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
peers, err := h.accountManager.GetPeers(account.Id, user.Id)
if err != nil {
return
}
respBody := []*api.Peer{}
for _, peer := range account.Peers {
for _, peer := range peers {
respBody = append(respBody, toPeerResponse(peer, account))
}
writeJSONObject(w, respBody)

View File

@ -16,15 +16,21 @@ import (
"github.com/netbirdio/netbird/management/server/mock_server"
)
func initTestMetaData(peer ...*server.Peer) *Peers {
func initTestMetaData(peers ...*server.Peer) *Peers {
return &Peers{
accountManager: &mock_server.MockAccountManager{
GetPeersFunc: func(accountID, userID string) ([]*server.Peer, error) {
return peers, nil
},
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
return &server.Account{
Id: claims.AccountId,
Domain: "hotmail.com",
Peers: map[string]*server.Peer{
"test_peer": peer[0],
"test_peer": peers[0],
},
Users: map[string]*server.User{
"test_user": server.NewAdminUser("test_user"),
},
}, nil
},

View File

@ -33,16 +33,24 @@ func NewRoutes(accountManager server.AccountManager, authAudience string) *Route
// GetAllRoutesHandler returns the list of routes for the account
func (h *Routes) GetAllRoutesHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
routes, err := h.accountManager.ListRoutes(account.Id)
routes, err := h.accountManager.ListRoutes(account.Id, user.Id)
if err != nil {
log.Error(err)
if e, ok := server.FromError(err); ok {
switch e.Type() {
case server.PermissionDenied:
http.Error(w, e.Error(), http.StatusForbidden)
return
default:
}
}
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
@ -56,7 +64,7 @@ func (h *Routes) GetAllRoutesHandler(w http.ResponseWriter, r *http.Request) {
// CreateRouteHandler handles route creation request
func (h *Routes) CreateRouteHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@ -104,7 +112,7 @@ func (h *Routes) CreateRouteHandler(w http.ResponseWriter, r *http.Request) {
// UpdateRouteHandler handles update to a route identified by a given ID
func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@ -117,7 +125,7 @@ func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) {
return
}
_, err = h.accountManager.GetRoute(account.Id, routeID)
_, err = h.accountManager.GetRoute(account.Id, routeID, "")
if err != nil {
http.Error(w, fmt.Sprintf("couldn't find route for ID %s", routeID), http.StatusNotFound)
return
@ -177,7 +185,7 @@ func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) {
// PatchRouteHandler handles patch updates to a route identified by a given ID
func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@ -190,7 +198,7 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
return
}
_, err = h.accountManager.GetRoute(account.Id, routeID)
_, err = h.accountManager.GetRoute(account.Id, routeID, "")
if err != nil {
log.Error(err)
http.Error(w, fmt.Sprintf("couldn't find route ID %s", routeID), http.StatusNotFound)
@ -334,7 +342,7 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
// DeleteRouteHandler handles route deletion request
func (h *Routes) DeleteRouteHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@ -363,7 +371,7 @@ func (h *Routes) DeleteRouteHandler(w http.ResponseWriter, r *http.Request) {
// GetRouteHandler handles a route Get request identified by ID
func (h *Routes) GetRouteHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@ -375,7 +383,7 @@ func (h *Routes) GetRouteHandler(w http.ResponseWriter, r *http.Request) {
return
}
foundRoute, err := h.accountManager.GetRoute(account.Id, routeID)
foundRoute, err := h.accountManager.GetRoute(account.Id, routeID, user.Id)
if err != nil {
http.Error(w, "route not found", http.StatusNotFound)
return

View File

@ -51,12 +51,15 @@ var testingAccount = &server.Account{
IP: netip.MustParseAddr(existingPeerID).AsSlice(),
},
},
Users: map[string]*server.User{
"test_user": server.NewAdminUser("test_user"),
},
}
func initRoutesTestData() *Routes {
return &Routes{
accountManager: &mock_server.MockAccountManager{
GetRouteFunc: func(_, routeID string) (*route.Route, error) {
GetRouteFunc: func(_, routeID, _ string) (*route.Route, error) {
if routeID == existingRouteID {
return baseExistingRoute, nil
}

View File

@ -31,15 +31,29 @@ func NewRules(accountManager server.AccountManager, authAudience string) *Rules
// GetAllRulesHandler list for the account
func (h *Rules) GetAllRulesHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
accountRules, err := h.accountManager.ListRules(account.Id, user.Id)
if err != nil {
log.Error(err)
if e, ok := server.FromError(err); ok {
switch e.Type() {
case server.PermissionDenied:
http.Error(w, e.Error(), http.StatusForbidden)
return
default:
}
}
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
rules := []*api.Rule{}
for _, r := range account.Rules {
for _, r := range accountRules {
rules = append(rules, toRuleResponse(account, r))
}
@ -48,7 +62,7 @@ func (h *Rules) GetAllRulesHandler(w http.ResponseWriter, r *http.Request) {
// UpdateRuleHandler handles update to a rule identified by a given ID
func (h *Rules) UpdateRuleHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@ -118,7 +132,7 @@ func (h *Rules) UpdateRuleHandler(w http.ResponseWriter, r *http.Request) {
// PatchRuleHandler handles patch updates to a rule identified by a given ID
func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@ -275,7 +289,7 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
// CreateRuleHandler handles rule creation request
func (h *Rules) CreateRuleHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@ -332,7 +346,7 @@ func (h *Rules) CreateRuleHandler(w http.ResponseWriter, r *http.Request) {
// DeleteRuleHandler handles rule deletion request
func (h *Rules) DeleteRuleHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@ -356,7 +370,7 @@ func (h *Rules) DeleteRuleHandler(w http.ResponseWriter, r *http.Request) {
// GetRuleHandler handles a group Get request identified by ID
func (h *Rules) GetRuleHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@ -370,7 +384,7 @@ func (h *Rules) GetRuleHandler(w http.ResponseWriter, r *http.Request) {
return
}
rule, err := h.accountManager.GetRule(account.Id, ruleID)
rule, err := h.accountManager.GetRule(account.Id, ruleID, user.Id)
if err != nil {
http.Error(w, "rule not found", http.StatusNotFound)
return

View File

@ -28,7 +28,7 @@ func initRulesTestData(rules ...*server.Rule) *Rules {
}
return nil
},
GetRuleFunc: func(_, ruleID string) (*server.Rule, error) {
GetRuleFunc: func(_, ruleID, _ string) (*server.Rule, error) {
if ruleID != "idoftherule" {
return nil, fmt.Errorf("not found")
}
@ -75,6 +75,9 @@ func initRulesTestData(rules ...*server.Rule) *Rules {
"F": {ID: "F"},
"G": {ID: "G"},
},
Users: map[string]*server.User{
"test_user": server.NewAdminUser("test_user"),
},
}, nil
},
},

View File

@ -6,15 +6,12 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/jwtclaims"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"net/http"
"strings"
"time"
"unicode/utf8"
)
// SetupKeys is a handler that returns a list of setup keys of the account
@ -34,7 +31,7 @@ func NewSetupKeysHandler(accountManager server.AccountManager, authAudience stri
// CreateSetupKeyHandler is a POST requests that creates a new SetupKey
func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
@ -84,7 +81,7 @@ func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request
// GetSetupKeyHandler is a GET request to get a SetupKey by ID
func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
@ -98,7 +95,7 @@ func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
return
}
key, err := h.accountManager.GetSetupKey(account.Id, keyID)
key, err := h.accountManager.GetSetupKey(account.Id, user.Id, keyID)
if err != nil {
errStatus, ok := status.FromError(err)
if ok && errStatus.Code() == codes.NotFound {
@ -110,17 +107,12 @@ func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
return
}
isUserAdmin, ok := r.Context().Value(middleware.IsUserAdminProperty).(bool)
if !ok || !isUserAdmin {
key = hideKey(key)
}
writeSuccess(w, key)
}
// UpdateSetupKeyHandler is a PUT request to update server.SetupKey
func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
@ -176,19 +168,14 @@ func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request
// GetAllSetupKeysHandler is a GET request that returns a list of SetupKey
func (h *SetupKeys) GetAllSetupKeysHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
isUserAdmin, ok := r.Context().Value(middleware.IsUserAdminProperty).(bool)
if !ok {
isUserAdmin = false
}
setupKeys, err := h.accountManager.ListSetupKeys(account.Id)
setupKeys, err := h.accountManager.ListSetupKeys(account.Id, user.Id)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
@ -196,23 +183,12 @@ func (h *SetupKeys) GetAllSetupKeysHandler(w http.ResponseWriter, r *http.Reques
}
apiSetupKeys := make([]*api.SetupKey, 0)
for _, key := range setupKeys {
k := key.Copy()
if !isUserAdmin {
k = hideKey(key)
}
apiSetupKeys = append(apiSetupKeys, toResponseBody(k))
apiSetupKeys = append(apiSetupKeys, toResponseBody(key))
}
writeJSONObject(w, apiSetupKeys)
}
func hideKey(key *server.SetupKey) *server.SetupKey {
k := key.Copy()
prefix := k.Key[0:5]
k.Key = prefix + strings.Repeat("*", utf8.RuneCountInString(key.Key)-len(prefix))
return k
}
func writeSuccess(w http.ResponseWriter, key *server.SetupKey) {
w.WriteHeader(200)
w.Header().Set("Content-Type", "application/json")

View File

@ -2,7 +2,6 @@ package http
import (
"bytes"
"context"
"encoding/json"
"fmt"
"github.com/gorilla/mux"
@ -29,13 +28,17 @@ const (
notFoundSetupKeyID = "notFoundSetupKeyID"
)
func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.SetupKey, updatedSetupKey *server.SetupKey) *SetupKeys {
func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.SetupKey, updatedSetupKey *server.SetupKey,
user *server.User) *SetupKeys {
return &SetupKeys{
accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
return &server.Account{
Id: testAccountID,
Domain: "hotmail.com",
Users: map[string]*server.User{
user.Id: user,
},
SetupKeys: map[string]*server.SetupKey{
defaultKey.Key: defaultKey,
},
@ -50,7 +53,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
}
return nil, fmt.Errorf("failed creating setup key")
},
GetSetupKeyFunc: func(accountID string, keyID string) (*server.SetupKey, error) {
GetSetupKeyFunc: func(accountID, userID, keyID string) (*server.SetupKey, error) {
switch keyID {
case defaultKey.Id:
return defaultKey, nil
@ -68,7 +71,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
return nil, status.Errorf(codes.NotFound, "key %s not found", key.Id)
},
ListSetupKeysFunc: func(accountID string) ([]*server.SetupKey, error) {
ListSetupKeysFunc: func(accountID, userID string) ([]*server.SetupKey, error) {
return []*server.SetupKey{defaultKey}, nil
},
},
@ -76,7 +79,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
jwtExtractor: jwtclaims.ClaimsExtractor{
ExtractClaimsFromRequestContext: func(r *http.Request, authAudience string) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: "test_user",
UserId: user.Id,
Domain: "hotmail.com",
AccountId: testAccountID,
}
@ -89,6 +92,8 @@ func TestSetupKeysHandlers(t *testing.T) {
defaultSetupKey := server.GenerateDefaultSetupKey()
defaultSetupKey.Id = existingSetupKeyID
adminUser := server.NewAdminUser("test_user")
newSetupKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"})
updatedDefaultSetupKey := defaultSetupKey.Copy()
updatedDefaultSetupKey.AutoGroups = []string{"group-1"}
@ -154,13 +159,12 @@ func TestSetupKeysHandlers(t *testing.T) {
},
}
handler := initSetupKeysTestMetaData(defaultSetupKey, newSetupKey, updatedDefaultSetupKey)
handler := initSetupKeysTestMetaData(defaultSetupKey, newSetupKey, updatedDefaultSetupKey, adminUser)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
req = req.Clone(context.WithValue(context.TODO(), "isAdminUser", true)) //nolint
router := mux.NewRouter()
router.HandleFunc("/api/setup-keys", handler.GetAllSetupKeysHandler).Methods("GET", "OPTIONS")

View File

@ -34,7 +34,7 @@ func (h *UserHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
http.Error(w, "", http.StatusBadRequest)
}
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
@ -87,7 +87,7 @@ func (h *UserHandler) CreateUserHandler(w http.ResponseWriter, r *http.Request)
http.Error(w, "", http.StatusNotFound)
}
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
}
@ -132,12 +132,13 @@ func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) {
http.Error(w, "", http.StatusBadRequest)
}
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
data, err := h.accountManager.GetUsersFromAccount(account.Id)
data, err := h.accountManager.GetUsersFromAccount(account.Id, user.Id)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)

View File

@ -27,7 +27,7 @@ func initUsers(user ...*server.User) *UserHandler {
Users: users,
}, nil
},
GetUsersFromAccountFunc: func(accountID string) ([]*server.UserInfo, error) {
GetUsersFromAccountFunc: func(accountID, userID string) ([]*server.UserInfo, error) {
users := make([]*server.UserInfo, 0)
for _, v := range user {
users = append(users, &server.UserInfo{
@ -44,7 +44,7 @@ func initUsers(user ...*server.User) *UserHandler {
jwtExtractor: jwtclaims.ClaimsExtractor{
ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: "test_user",
UserId: "1",
Domain: "hotmail.com",
AccountId: "test_id",
}

View File

@ -56,16 +56,22 @@ func (d *Duration) UnmarshalJSON(b []byte) error {
func getJWTAccount(accountManager server.AccountManager,
jwtExtractor jwtclaims.ClaimsExtractor,
authAudience string, r *http.Request) (*server.Account, error) {
authAudience string, r *http.Request) (*server.Account, *server.User, error) {
jwtClaims := jwtExtractor.ExtractClaimsFromRequestContext(r, authAudience)
claims := jwtExtractor.ExtractClaimsFromRequestContext(r, authAudience)
account, err := accountManager.GetAccountFromToken(jwtClaims)
account, err := accountManager.GetAccountFromToken(claims)
if err != nil {
return nil, fmt.Errorf("failed getting account of a user %s: %v", jwtClaims.UserId, err)
return nil, nil, fmt.Errorf("failed getting account of a user %s: %v", claims.UserId, err)
}
return account, nil
user := account.Users[claims.UserId]
if user == nil {
// this is not really possible because we got an account by user ID
return nil, nil, fmt.Errorf("user %s not found", claims.UserId)
}
return account, user, nil
}
func toHTTPError(err error, w http.ResponseWriter) {

View File

@ -14,12 +14,13 @@ type MockAccountManager struct {
GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error)
GetAccountByUserFunc func(userId string) (*server.Account, error)
CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration, autoGroups []string) (*server.SetupKey, error)
GetSetupKeyFunc func(accountID string, keyID string) (*server.SetupKey, error)
GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error)
GetAccountByIdFunc func(accountId string) (*server.Account, error)
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error)
AccountExistsFunc func(accountId string) (*bool, error)
GetPeerFunc func(peerKey string) (*server.Peer, error)
GetPeersFunc func(accountID, userID string) ([]*server.Peer, error)
MarkPeerConnectedFunc func(peerKey string, connected bool) error
RenamePeerFunc func(accountId string, peerKey string, newName string) (*server.Peer, error)
DeletePeerFunc func(accountId string, peerKey string) (*server.Peer, error)
@ -35,23 +36,23 @@ type MockAccountManager struct {
GroupAddPeerFunc func(accountID, groupID, peerKey string) error
GroupDeletePeerFunc func(accountID, groupID, peerKey string) error
GroupListPeersFunc func(accountID, groupID string) ([]*server.Peer, error)
GetRuleFunc func(accountID, ruleID string) (*server.Rule, error)
GetRuleFunc func(accountID, ruleID, userID string) (*server.Rule, error)
SaveRuleFunc func(accountID string, rule *server.Rule) error
UpdateRuleFunc func(accountID string, ruleID string, operations []server.RuleUpdateOperation) (*server.Rule, error)
DeleteRuleFunc func(accountID, ruleID string) error
ListRulesFunc func(accountID string) ([]*server.Rule, error)
GetUsersFromAccountFunc func(accountID string) ([]*server.UserInfo, error)
ListRulesFunc func(accountID, userID string) ([]*server.Rule, error)
GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error)
UpdatePeerMetaFunc func(peerKey string, meta server.PeerSystemMeta) error
UpdatePeerSSHKeyFunc func(peerKey string, sshKey string) error
UpdatePeerFunc func(accountID string, peer *server.Peer) (*server.Peer, error)
CreateRouteFunc func(accountID string, prefix, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error)
GetRouteFunc func(accountID, routeID string) (*route.Route, error)
GetRouteFunc func(accountID, routeID, userID string) (*route.Route, error)
SaveRouteFunc func(accountID string, route *route.Route) error
UpdateRouteFunc func(accountID string, routeID string, operations []server.RouteUpdateOperation) (*route.Route, error)
DeleteRouteFunc func(accountID, routeID string) error
ListRoutesFunc func(accountID string) ([]*route.Route, error)
ListRoutesFunc func(accountID, userID string) ([]*route.Route, error)
SaveSetupKeyFunc func(accountID string, key *server.SetupKey) (*server.SetupKey, error)
ListSetupKeysFunc func(accountID string) ([]*server.SetupKey, error)
ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error)
SaveUserFunc func(accountID string, user *server.User) (*server.UserInfo, error)
GetNameServerGroupFunc func(accountID, nsGroupID string) (*nbdns.NameServerGroup, error)
CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool) (*nbdns.NameServerGroup, error)
@ -64,9 +65,9 @@ type MockAccountManager struct {
}
// GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface
func (am *MockAccountManager) GetUsersFromAccount(accountID string) ([]*server.UserInfo, error) {
func (am *MockAccountManager) GetUsersFromAccount(accountID string, userID string) ([]*server.UserInfo, error) {
if am.GetUsersFromAccountFunc != nil {
return am.GetUsersFromAccountFunc(accountID)
return am.GetUsersFromAccountFunc(accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetUsersFromAccount is not implemented")
}
@ -272,9 +273,9 @@ func (am *MockAccountManager) GroupListPeers(accountID, groupID string) ([]*serv
}
// GetRule mock implementation of GetRule from server.AccountManager interface
func (am *MockAccountManager) GetRule(accountID, ruleID string) (*server.Rule, error) {
func (am *MockAccountManager) GetRule(accountID, ruleID, userID string) (*server.Rule, error) {
if am.GetRuleFunc != nil {
return am.GetRuleFunc(accountID, ruleID)
return am.GetRuleFunc(accountID, ruleID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetRule is not implemented")
}
@ -304,9 +305,9 @@ func (am *MockAccountManager) DeleteRule(accountID, ruleID string) error {
}
// ListRules mock implementation of ListRules from server.AccountManager interface
func (am *MockAccountManager) ListRules(accountID string) ([]*server.Rule, error) {
func (am *MockAccountManager) ListRules(accountID, userID string) ([]*server.Rule, error) {
if am.ListRulesFunc != nil {
return am.ListRulesFunc(accountID)
return am.ListRulesFunc(accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method ListRules is not implemented")
}
@ -345,16 +346,16 @@ func (am *MockAccountManager) UpdatePeer(accountID string, peer *server.Peer) (*
// CreateRoute mock implementation of CreateRoute from server.AccountManager interface
func (am *MockAccountManager) CreateRoute(accountID string, network, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error) {
if am.GetRouteFunc != nil {
if am.CreateRouteFunc != nil {
return am.CreateRouteFunc(accountID, network, peer, description, netID, masquerade, metric, enabled)
}
return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented")
}
// GetRoute mock implementation of GetRoute from server.AccountManager interface
func (am *MockAccountManager) GetRoute(accountID, routeID string) (*route.Route, error) {
func (am *MockAccountManager) GetRoute(accountID, routeID, userID string) (*route.Route, error) {
if am.GetRouteFunc != nil {
return am.GetRouteFunc(accountID, routeID)
return am.GetRouteFunc(accountID, routeID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetRoute is not implemented")
}
@ -384,9 +385,9 @@ func (am *MockAccountManager) DeleteRoute(accountID, routeID string) error {
}
// ListRoutes mock implementation of ListRoutes from server.AccountManager interface
func (am *MockAccountManager) ListRoutes(accountID string) ([]*route.Route, error) {
func (am *MockAccountManager) ListRoutes(accountID, userID string) ([]*route.Route, error) {
if am.ListRoutesFunc != nil {
return am.ListRoutesFunc(accountID)
return am.ListRoutesFunc(accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method ListRoutes is not implemented")
}
@ -401,18 +402,18 @@ func (am *MockAccountManager) SaveSetupKey(accountID string, key *server.SetupKe
}
// GetSetupKey mocks GetSetupKey of the AccountManager interface
func (am *MockAccountManager) GetSetupKey(accountID, keyID string) (*server.SetupKey, error) {
func (am *MockAccountManager) GetSetupKey(accountID, userID, keyID string) (*server.SetupKey, error) {
if am.GetSetupKeyFunc != nil {
return am.GetSetupKeyFunc(accountID, keyID)
return am.GetSetupKeyFunc(accountID, userID, keyID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetSetupKey is not implemented")
}
// ListSetupKeys mocks ListSetupKeys of the AccountManager interface
func (am *MockAccountManager) ListSetupKeys(accountID string) ([]*server.SetupKey, error) {
func (am *MockAccountManager) ListSetupKeys(accountID, userID string) ([]*server.SetupKey, error) {
if am.ListSetupKeysFunc != nil {
return am.ListSetupKeysFunc(accountID)
return am.ListSetupKeysFunc(accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method ListSetupKeys is not implemented")
@ -489,3 +490,11 @@ func (am *MockAccountManager) GetAccountFromToken(claims jwtclaims.Authorization
}
return nil, status.Errorf(codes.Unimplemented, "method GetAccountFromToken is not implemented")
}
// GetPeers mocks GetPeers of the AccountManager interface
func (am *MockAccountManager) GetPeers(accountID, userID string) ([]*server.Peer, error) {
if am.GetAccountFromTokenFunc != nil {
return am.GetPeersFunc(accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetPeersFunc is not implemented")
}

View File

@ -81,6 +81,33 @@ func (am *DefaultAccountManager) GetPeer(peerKey string) (*Peer, error) {
return peer, nil
}
// GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if
// the current user is not an admin.
func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*Peer, error) {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, err
}
user, err := account.FindUser(userID)
if err != nil {
return nil, err
}
peers := make([]*Peer, 0, len(account.Peers))
for _, peer := range account.Peers {
if !user.IsAdmin() && user.Id != peer.UserID {
// only display peers that belong to the current user if the current user is not an admin
continue
}
peers = append(peers, peer.Copy())
}
return peers, nil
}
// MarkPeerConnected marks peer as connected (true) or disconnected (false)
func (am *DefaultAccountManager) MarkPeerConnected(peerKey string, connected bool) error {
am.mux.Lock()

View File

@ -134,7 +134,7 @@ func TestAccountManager_GetNetworkMapWithRule(t *testing.T) {
return
}
rules, err := manager.ListRules(account.Id)
rules, err := manager.ListRules(account.Id, userId)
if err != nil {
t.Errorf("expecting to get a list of rules, got failure %v", err)
return

View File

@ -60,7 +60,7 @@ type RouteUpdateOperation struct {
}
// GetRoute gets a route object from account and route IDs
func (am *DefaultAccountManager) GetRoute(accountID, routeID string) (*route.Route, error) {
func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*route.Route, error) {
am.mux.Lock()
defer am.mux.Unlock()
@ -69,6 +69,15 @@ func (am *DefaultAccountManager) GetRoute(accountID, routeID string) (*route.Rou
return nil, status.Errorf(codes.NotFound, "account not found")
}
user, err := account.FindUser(userID)
if err != nil {
return nil, err
}
if !user.IsAdmin() {
return nil, Errorf(PermissionDenied, "Only administrators can view Network Routes")
}
wantedRoute, found := account.Routes[routeID]
if found {
return wantedRoute, nil
@ -325,7 +334,7 @@ func (am *DefaultAccountManager) DeleteRoute(accountID, routeID string) error {
}
// ListRoutes returns a list of routes from account
func (am *DefaultAccountManager) ListRoutes(accountID string) ([]*route.Route, error) {
func (am *DefaultAccountManager) ListRoutes(accountID, userID string) ([]*route.Route, error) {
am.mux.Lock()
defer am.mux.Unlock()
@ -334,6 +343,15 @@ func (am *DefaultAccountManager) ListRoutes(accountID string) ([]*route.Route, e
return nil, status.Errorf(codes.NotFound, "account not found")
}
user, err := account.FindUser(userID)
if err != nil {
return nil, err
}
if !user.IsAdmin() {
return nil, Errorf(PermissionDenied, "Only administrators can view Network Routes")
}
routes := make([]*route.Route, 0, len(account.Routes))
for _, item := range account.Routes {
routes = append(routes, item)

View File

@ -740,7 +740,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
err = am.SaveGroup(account.Id, newGroup)
require.NoError(t, err)
rules, err := am.ListRules(account.Id)
rules, err := am.ListRules(account.Id, "testingUser")
require.NoError(t, err)
defaultRule := rules[0]

View File

@ -89,7 +89,7 @@ func (r *Rule) Copy() *Rule {
}
// GetRule of ACL from the store
func (am *DefaultAccountManager) GetRule(accountID, ruleID string) (*Rule, error) {
func (am *DefaultAccountManager) GetRule(accountID, ruleID, userID string) (*Rule, error) {
am.mux.Lock()
defer am.mux.Unlock()
@ -98,6 +98,15 @@ func (am *DefaultAccountManager) GetRule(accountID, ruleID string) (*Rule, error
return nil, status.Errorf(codes.NotFound, "account not found")
}
user, err := account.FindUser(userID)
if err != nil {
return nil, err
}
if !user.IsAdmin() {
return nil, Errorf(PermissionDenied, "only admins are allowed to view rules")
}
rule, ok := account.Rules[ruleID]
if ok {
return rule, nil
@ -222,7 +231,7 @@ func (am *DefaultAccountManager) DeleteRule(accountID, ruleID string) error {
}
// ListRules of ACL from the store
func (am *DefaultAccountManager) ListRules(accountID string) ([]*Rule, error) {
func (am *DefaultAccountManager) ListRules(accountID, userID string) ([]*Rule, error) {
am.mux.Lock()
defer am.mux.Unlock()
@ -231,6 +240,15 @@ func (am *DefaultAccountManager) ListRules(accountID string) ([]*Rule, error) {
return nil, status.Errorf(codes.NotFound, "account not found")
}
user, err := account.FindUser(userID)
if err != nil {
return nil, err
}
if !user.IsAdmin() {
return nil, Errorf(PermissionDenied, "Only Administrators can view Access Rules")
}
rules := make([]*Rule, 0, len(account.Rules))
for _, item := range account.Rules {
rules = append(rules, item)

View File

@ -9,6 +9,7 @@ import (
"strconv"
"strings"
"time"
"unicode/utf8"
)
const (
@ -100,6 +101,15 @@ func (key *SetupKey) Copy() *SetupKey {
}
}
// HiddenCopy returns a copy of the key with a Key value hidden with "*" and a 5 character prefix.
// E.g., "831F6*******************************"
func (key *SetupKey) HiddenCopy() *SetupKey {
k := key.Copy()
prefix := k.Key[0:5]
k.Key = prefix + strings.Repeat("*", utf8.RuneCountInString(key.Key)-len(prefix))
return k
}
// IncrementUsage makes a copy of a key, increments the UsedTimes by 1 and sets LastUsed to now
func (key *SetupKey) IncrementUsage() *SetupKey {
c := key.Copy()
@ -238,7 +248,7 @@ func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *Setup
}
// ListSetupKeys returns a list of all setup keys of the account
func (am *DefaultAccountManager) ListSetupKeys(accountID string) ([]*SetupKey, error) {
func (am *DefaultAccountManager) ListSetupKeys(accountID, userID string) ([]*SetupKey, error) {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccount(accountID)
@ -246,16 +256,27 @@ func (am *DefaultAccountManager) ListSetupKeys(accountID string) ([]*SetupKey, e
return nil, status.Errorf(codes.NotFound, "account not found")
}
user, err := account.FindUser(userID)
if err != nil {
return nil, err
}
keys := make([]*SetupKey, 0, len(account.SetupKeys))
for _, key := range account.SetupKeys {
keys = append(keys, key.Copy())
var k *SetupKey
if !user.IsAdmin() {
k = key.HiddenCopy()
} else {
k = key.Copy()
}
keys = append(keys, k)
}
return keys, nil
}
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
func (am *DefaultAccountManager) GetSetupKey(accountID, keyID string) (*SetupKey, error) {
func (am *DefaultAccountManager) GetSetupKey(accountID, userID, keyID string) (*SetupKey, error) {
am.mux.Lock()
defer am.mux.Unlock()
@ -264,6 +285,11 @@ func (am *DefaultAccountManager) GetSetupKey(accountID, keyID string) (*SetupKey
return nil, status.Errorf(codes.NotFound, "account not found")
}
user, err := account.FindUser(userID)
if err != nil {
return nil, err
}
var foundKey *SetupKey
for _, key := range account.SetupKeys {
if key.Id == keyID {
@ -280,5 +306,9 @@ func (am *DefaultAccountManager) GetSetupKey(accountID, keyID string) (*SetupKey
foundKey.UpdatedAt = foundKey.CreatedAt
}
if !user.IsAdmin() {
foundKey = foundKey.HiddenCopy()
}
return foundKey, nil
}

View File

@ -46,6 +46,11 @@ type User struct {
AutoGroups []string
}
// IsAdmin returns true if user is an admin, false otherwise
func (u *User) IsAdmin() bool {
return u.Role == UserRoleAdmin
}
// toUserInfo converts a User object to a UserInfo object.
func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) {
autoGroups := u.AutoGroups
@ -288,13 +293,19 @@ func (am *DefaultAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaim
return user.Role == UserRoleAdmin, nil
}
// GetUsersFromAccount performs a batched request for users from IDP by account ID
func (am *DefaultAccountManager) GetUsersFromAccount(accountID string) ([]*UserInfo, error) {
// GetUsersFromAccount performs a batched request for users from IDP by account ID apply filter on what data to return
// based on provided user role.
func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error) {
account, err := am.GetAccountById(accountID)
if err != nil {
return nil, err
}
user, err := account.FindUser(userID)
if err != nil {
return nil, err
}
queriedUsers := make([]*idp.UserData, 0)
if !isNil(am.idpManager) {
users := make(map[string]struct{}, len(account.Users))
@ -311,8 +322,12 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID string) ([]*UserI
// in case of self-hosted, or IDP doesn't return anything, we will return the locally stored userInfo
if len(queriedUsers) == 0 {
for _, user := range account.Users {
info, err := user.toUserInfo(nil)
for _, accountUser := range account.Users {
if !user.IsAdmin() && user.Id != accountUser.Id {
// if user is not an admin then show only current user and do not show other users
continue
}
info, err := accountUser.toUserInfo(nil)
if err != nil {
return nil, err
}
@ -322,6 +337,10 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID string) ([]*UserI
}
for _, queriedUser := range queriedUsers {
if !user.IsAdmin() && user.Id != queriedUser.ID {
// if user is not an admin then show only current user and do not show other users
continue
}
if localUser, contains := account.Users[queriedUser.ID]; contains {
info, err := localUser.toUserInfo(queriedUser)