diff --git a/management/server/account.go b/management/server/account.go index 7159aa9ac..858b0bcda 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -81,6 +81,7 @@ type AccountManager interface { GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error) DeleteAccount(ctx context.Context, accountID, userID string) error MarkPATUsed(ctx context.Context, tokenID string) error + GetUserByID(ctx context.Context, userID string) (*User, error) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) ListUsers(ctx context.Context, accountID string) ([]*User, error) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) @@ -2033,26 +2034,25 @@ func (am *DefaultAccountManager) GetDNSDomain() string { // CheckUserAccessByJWTGroups checks if the user has access, particularly in cases where the admin enabled JWT // group propagation and set the list of groups with access permissions. func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error { - account, _, err := am.GetAccountFromToken(ctx, claims) + accountID := claims.AccountId + if accountID == "" { + user, err := am.GetUserByID(ctx, claims.UserId) + if err != nil { + return fmt.Errorf("failed to retrieve account for user %s: %v", claims.UserId, err) + } + accountID = user.AccountID + } + + settings, err := am.Store.GetAccountSettings(ctx, accountID) if err != nil { return err } // Ensures JWT group synchronization to the management is enabled before, // filtering access based on the allowed groups. - if account.Settings != nil && account.Settings.JWTGroupsEnabled { - if allowedGroups := account.Settings.JWTAllowGroups; len(allowedGroups) > 0 { - userJWTGroups := make([]string, 0) - - if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok { - if claimGroups, ok := claim.([]interface{}); ok { - for _, g := range claimGroups { - if group, ok := g.(string); ok { - userJWTGroups = append(userJWTGroups, group) - } - } - } - } + if settings != nil && settings.JWTGroupsEnabled { + if allowedGroups := settings.JWTAllowGroups; len(allowedGroups) > 0 { + userJWTGroups := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims) if !userHasAllowedGroup(allowedGroups, userJWTGroups) { return fmt.Errorf("user does not belong to any of the allowed JWT groups") @@ -2185,6 +2185,25 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Ac return acc } +// extractJWTGroups extracts the group names from a JWT token's claims. +func extractJWTGroups(ctx context.Context, claimName string, claims jwtclaims.AuthorizationClaims) []string { + userJWTGroups := make([]string, 0) + + if claim, ok := claims.Raw[claimName]; ok { + if claimGroups, ok := claim.([]interface{}); ok { + for _, g := range claimGroups { + if group, ok := g.(string); ok { + userJWTGroups = append(userJWTGroups, group) + } else { + log.WithContext(ctx).Debugf("JWT claim %q contains a non-string group (type: %T): %v", claimName, g, g) + } + } + } + } + + return userJWTGroups +} + // userHasAllowedGroup checks if a user belongs to any of the allowed groups. func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool { for _, userGroup := range userGroups { diff --git a/management/server/http/handler.go b/management/server/http/handler.go index ef94f22b9..4dd3ecef1 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -66,7 +66,7 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa acMiddleware := middleware.NewAccessControl( authCfg.Audience, authCfg.UserIDClaim, - accountManager.GetUser) + accountManager.GetUserByID) rootRouter := mux.NewRouter() metricsMiddleware := appMetrics.HTTPMiddleware() diff --git a/management/server/http/middleware/access_control.go b/management/server/http/middleware/access_control.go index 0ad250f43..d774eb72a 100644 --- a/management/server/http/middleware/access_control.go +++ b/management/server/http/middleware/access_control.go @@ -15,8 +15,8 @@ import ( "github.com/netbirdio/netbird/management/server/jwtclaims" ) -// GetUser function defines a function to fetch user from Account by jwtclaims.AuthorizationClaims -type GetUser func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) +// GetUser function defines a function to fetch user from Account by user id. +type GetUser func(ctx context.Context, id string) (*server.User, error) // AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only type AccessControl struct { @@ -47,7 +47,7 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler { claims := a.claimsExtract.FromRequestContext(r) - user, err := a.getUser(r.Context(), claims) + user, err := a.getUser(r.Context(), claims.UserId) if err != nil { log.WithContext(r.Context()).Errorf("failed to get user from claims: %s", err) util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "invalid JWT"), w) diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 495325252..476edf19f 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -27,6 +27,7 @@ type MockAccountManager struct { expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error) GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error) GetAccountByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (*server.Account, error) + GetUserByIDFunc func(ctx context.Context, userID string) (*server.User, error) GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error) GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) @@ -408,6 +409,14 @@ func (am *MockAccountManager) UpdatePeerMeta(ctx context.Context, peerID string, return status.Errorf(codes.Unimplemented, "method UpdatePeerMeta is not implemented") } +// GetUserByID mock implementation of GetUserByID from server.AccountManager interface +func (am *MockAccountManager) GetUserByID(ctx context.Context, userID string) (*server.User, error) { + if am.GetUserByIDFunc != nil { + return am.GetUserByIDFunc(ctx, userID) + } + return nil, status.Errorf(codes.Unimplemented, "method GetUser is not implemented") +} + // GetUser mock implementation of GetUser from server.AccountManager interface func (am *MockAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) { if am.GetUserFunc != nil { diff --git a/management/server/user.go b/management/server/user.go index 727bc5c6b..c8561685f 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -361,6 +361,11 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u return newUser.ToUserInfo(idpUser, account.Settings) } +// GetUserByID looks up a user by provided user id. +func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*User, error) { + return am.Store.GetUserByUserID(ctx, id) +} + // GetUser looks up a user by provided authorization claims. // It will also create an account if didn't exist for this user before. func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) {