diff --git a/management/server/account.go b/management/server/account.go index 018229278..53672f5c6 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -55,11 +55,10 @@ type AccountManager interface { SaveUser(accountID, userID string, update *User) (*UserInfo, error) GetSetupKey(accountID, userID, keyID string) (*SetupKey, error) GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error) - GetAccountByUserID(userID string) (*Account, error) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error) GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error) MarkPATUsed(tokenID string) error - IsUserAdmin(userID string) (bool, error) + IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) AccountExists(accountId string) (*bool, error) GetPeerByKey(peerKey string) (*Peer, error) GetPeers(accountID, userID string) ([]*Peer, error) diff --git a/management/server/http/middleware/access_control.go b/management/server/http/middleware/access_control.go index c4f79f8cb..5f8389dfa 100644 --- a/management/server/http/middleware/access_control.go +++ b/management/server/http/middleware/access_control.go @@ -12,7 +12,7 @@ import ( "github.com/netbirdio/netbird/management/server/jwtclaims" ) -type IsUserAdminFunc func(userID string) (bool, error) +type IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error) // AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only type AccessControl struct { @@ -37,7 +37,7 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { claims := a.claimsExtract.FromRequestContext(r) - ok, err := a.isUserAdmin(claims.UserId) + ok, err := a.isUserAdmin(claims) if err != nil { util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w) return diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index b39d27b27..5d4eb709f 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -15,12 +15,11 @@ import ( type MockAccountManager struct { GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error) - GetAccountByUserIDFunc func(userID string) (*server.Account, error) CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string) (*server.SetupKey, error) GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error) GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error) - IsUserAdminFunc func(userID string) (bool, error) + IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error) AccountExistsFunc func(accountId string) (*bool, error) GetPeerByKeyFunc func(peerKey string) (*server.Peer, error) GetPeersFunc func(accountID, userID string) ([]*server.Peer, error) @@ -113,14 +112,6 @@ func (am *MockAccountManager) GetOrCreateAccountByUser( ) } -// GetAccountByUserID mock implementation of GetAccountByUserID from server.AccountManager interface -func (am *MockAccountManager) GetAccountByUserID(userID string) (*server.Account, error) { - if am.GetAccountByUserIDFunc != nil { - return am.GetAccountByUserIDFunc(userID) - } - return nil, status.Errorf(codes.Unimplemented, "method GetAccountByUserID is not implemented") -} - // CreateSetupKey mock implementation of CreateSetupKey from server.AccountManager interface func (am *MockAccountManager) CreateSetupKey( accountID string, @@ -395,9 +386,9 @@ func (am *MockAccountManager) UpdatePeerMeta(peerID string, meta server.PeerSyst } // IsUserAdmin mock implementation of IsUserAdmin from server.AccountManager interface -func (am *MockAccountManager) IsUserAdmin(userID string) (bool, error) { +func (am *MockAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) { if am.IsUserAdminFunc != nil { - return am.IsUserAdminFunc(userID) + return am.IsUserAdminFunc(claims) } return false, status.Errorf(codes.Unimplemented, "method IsUserAdmin is not implemented") } diff --git a/management/server/user.go b/management/server/user.go index 00eeb83c7..ea7dba2d2 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -9,6 +9,7 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" ) @@ -573,19 +574,14 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userID, domain string) return account, nil } -// GetAccountByUserID returns an existing account for a given user id -func (am *DefaultAccountManager) GetAccountByUserID(userID string) (*Account, error) { - return am.Store.GetAccountByUser(userID) -} - // IsUserAdmin looks up a user by his ID and returns true if he is an admin -func (am *DefaultAccountManager) IsUserAdmin(userID string) (bool, error) { - account, err := am.GetAccountByUserID(userID) +func (am *DefaultAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) { + account, _, err := am.GetAccountFromToken(claims) if err != nil { return false, fmt.Errorf("get account: %v", err) } - user, ok := account.Users[userID] + user, ok := account.Users[claims.UserId] if !ok { return false, status.Errorf(status.NotFound, "user not found") } diff --git a/management/server/user_test.go b/management/server/user_test.go index b9876be4e..504d231e5 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/jwtclaims" ) const ( @@ -453,7 +454,11 @@ func TestUser_IsUserAdmin_ForAdmin(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - ok, err := am.IsUserAdmin(mockUserID) + claims := jwtclaims.AuthorizationClaims{ + UserId: mockUserID, + } + + ok, err := am.IsUserAdmin(claims) if err != nil { t.Fatalf("Error when checking user role: %s", err) } @@ -479,7 +484,11 @@ func TestUser_IsUserAdmin_ForUser(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - ok, err := am.IsUserAdmin(mockUserID) + claims := jwtclaims.AuthorizationClaims{ + UserId: mockUserID, + } + + ok, err := am.IsUserAdmin(claims) if err != nil { t.Fatalf("Error when checking user role: %s", err) }