mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-07 16:54:16 +01:00
Enable JWT group-based user authorization (#1368)
* Extend management API to support list of allowed JWT groups (#1366) * Add JWTAllowGroups settings to account management * Return an empty group list if jwt allow groups is not set * Add JwtAllowGroups to account settings in handler test * Add JWT group-based user authorization (#1373) * Add JWTAllowGroups settings to account management * Return an empty group list if jwt allow groups is not set * Add JwtAllowGroups to account settings in handler test * Implement user access validation authentication based on JWT groups * Remove the slices package import due to compatibility issues with the gitHub workflow(s) Go version * Refactor auth middleware and test for extracted claim handling * Optimize JWT group check in auth middleware to cover nil and empty allowed groups
This commit is contained in:
parent
5ecafef5d2
commit
d275d411aa
@ -164,6 +164,9 @@ type Settings struct {
|
||||
// JWTGroupsClaimName from which we extract groups name to add it to account groups
|
||||
JWTGroupsClaimName string
|
||||
|
||||
// JWTAllowGroups list of groups to which users are allowed access
|
||||
JWTAllowGroups []string `gorm:"serializer:json"`
|
||||
|
||||
// Extra is a dictionary of Account settings
|
||||
Extra *account.ExtraSettings `gorm:"embedded;embeddedPrefix:extra_"`
|
||||
}
|
||||
@ -176,6 +179,7 @@ func (s *Settings) Copy() *Settings {
|
||||
JWTGroupsEnabled: s.JWTGroupsEnabled,
|
||||
JWTGroupsClaimName: s.JWTGroupsClaimName,
|
||||
GroupsPropagationEnabled: s.GroupsPropagationEnabled,
|
||||
JWTAllowGroups: s.JWTAllowGroups,
|
||||
}
|
||||
if s.Extra != nil {
|
||||
settings.Extra = s.Extra.Copy()
|
||||
|
@ -91,6 +91,9 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request)
|
||||
if req.Settings.JwtGroupsClaimName != nil {
|
||||
settings.JWTGroupsClaimName = *req.Settings.JwtGroupsClaimName
|
||||
}
|
||||
if req.Settings.JwtAllowGroups != nil {
|
||||
settings.JWTAllowGroups = *req.Settings.JwtAllowGroups
|
||||
}
|
||||
|
||||
updatedAccount, err := h.accountManager.UpdateAccountSettings(accountID, user.Id, settings)
|
||||
if err != nil {
|
||||
@ -128,12 +131,18 @@ func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
func toAccountResponse(account *server.Account) *api.Account {
|
||||
jwtAllowGroups := account.Settings.JWTAllowGroups
|
||||
if jwtAllowGroups == nil {
|
||||
jwtAllowGroups = []string{}
|
||||
}
|
||||
|
||||
settings := api.AccountSettings{
|
||||
PeerLoginExpiration: int(account.Settings.PeerLoginExpiration.Seconds()),
|
||||
PeerLoginExpirationEnabled: account.Settings.PeerLoginExpirationEnabled,
|
||||
GroupsPropagationEnabled: &account.Settings.GroupsPropagationEnabled,
|
||||
JwtGroupsEnabled: &account.Settings.JWTGroupsEnabled,
|
||||
JwtGroupsClaimName: &account.Settings.JWTGroupsClaimName,
|
||||
JwtAllowGroups: &jwtAllowGroups,
|
||||
}
|
||||
|
||||
if account.Settings.Extra != nil {
|
||||
|
@ -95,6 +95,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
GroupsPropagationEnabled: br(false),
|
||||
JwtGroupsClaimName: sr(""),
|
||||
JwtGroupsEnabled: br(false),
|
||||
JwtAllowGroups: &[]string{},
|
||||
},
|
||||
expectedArray: true,
|
||||
expectedID: accountID,
|
||||
@ -112,6 +113,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
GroupsPropagationEnabled: br(false),
|
||||
JwtGroupsClaimName: sr(""),
|
||||
JwtGroupsEnabled: br(false),
|
||||
JwtAllowGroups: &[]string{},
|
||||
},
|
||||
expectedArray: false,
|
||||
expectedID: accountID,
|
||||
@ -121,7 +123,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
expectedBody: true,
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/accounts/" + accountID,
|
||||
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\"}}"),
|
||||
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\",\"jwt_allow_groups\":[\"test\"]}}"),
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedSettings: api.AccountSettings{
|
||||
PeerLoginExpiration: 15552000,
|
||||
@ -129,6 +131,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
GroupsPropagationEnabled: br(false),
|
||||
JwtGroupsClaimName: sr("roles"),
|
||||
JwtGroupsEnabled: br(true),
|
||||
JwtAllowGroups: &[]string{"test"},
|
||||
},
|
||||
expectedArray: false,
|
||||
expectedID: accountID,
|
||||
@ -146,6 +149,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
GroupsPropagationEnabled: br(true),
|
||||
JwtGroupsClaimName: sr("groups"),
|
||||
JwtGroupsEnabled: br(true),
|
||||
JwtAllowGroups: &[]string{},
|
||||
},
|
||||
expectedArray: false,
|
||||
expectedID: accountID,
|
||||
|
@ -66,6 +66,12 @@ components:
|
||||
description: Name of the claim from which we extract groups names to add it to account groups.
|
||||
type: string
|
||||
example: "roles"
|
||||
jwt_allow_groups:
|
||||
description: List of groups to which users are allowed access
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
example: Administrators
|
||||
extra:
|
||||
$ref: '#/components/schemas/AccountExtraSettings'
|
||||
required:
|
||||
|
@ -160,6 +160,9 @@ type AccountSettings struct {
|
||||
// GroupsPropagationEnabled Allows propagate the new user auto groups to peers that belongs to the user
|
||||
GroupsPropagationEnabled *bool `json:"groups_propagation_enabled,omitempty"`
|
||||
|
||||
// JwtAllowGroups List of groups to which users are allowed access
|
||||
JwtAllowGroups *[]string `json:"jwt_allow_groups,omitempty"`
|
||||
|
||||
// JwtGroupsClaimName Name of the claim from which we extract groups names to add it to account groups.
|
||||
JwtGroupsClaimName *string `json:"jwt_groups_claim_name,omitempty"`
|
||||
|
||||
|
@ -34,12 +34,20 @@ type emptyObject struct {
|
||||
|
||||
// APIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
||||
func APIHandler(accountManager s.AccountManager, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) {
|
||||
claimsExtractor := jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(authCfg.Audience),
|
||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||
)
|
||||
|
||||
authMiddleware := middleware.NewAuthMiddleware(
|
||||
accountManager.GetAccountFromPAT,
|
||||
jwtValidator.ValidateAndParse,
|
||||
accountManager.MarkPATUsed,
|
||||
accountManager.GetAccountFromToken,
|
||||
claimsExtractor,
|
||||
authCfg.Audience,
|
||||
authCfg.UserIDClaim)
|
||||
authCfg.UserIDClaim,
|
||||
)
|
||||
|
||||
corsMiddleware := cors.AllowAll()
|
||||
|
||||
@ -60,11 +68,6 @@ func APIHandler(accountManager s.AccountManager, jwtValidator jwtclaims.JWTValid
|
||||
AuthCfg: authCfg,
|
||||
}
|
||||
|
||||
claimsExtractor := jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(authCfg.Audience),
|
||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||
)
|
||||
|
||||
integrations.RegisterHandlers(api.Router, accountManager, claimsExtractor)
|
||||
api.addAccountsEndpoint()
|
||||
api.addPeersEndpoint()
|
||||
|
@ -26,11 +26,16 @@ type ValidateAndParseTokenFunc func(token string) (*jwt.Token, error)
|
||||
// MarkPATUsedFunc function
|
||||
type MarkPATUsedFunc func(token string) error
|
||||
|
||||
// GetAccountFromTokenFunc function
|
||||
type GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error)
|
||||
|
||||
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
|
||||
type AuthMiddleware struct {
|
||||
getAccountFromPAT GetAccountFromPATFunc
|
||||
validateAndParseToken ValidateAndParseTokenFunc
|
||||
markPATUsed MarkPATUsedFunc
|
||||
getAccountFromToken GetAccountFromTokenFunc
|
||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||
audience string
|
||||
userIDClaim string
|
||||
}
|
||||
@ -40,14 +45,19 @@ const (
|
||||
)
|
||||
|
||||
// NewAuthMiddleware instance constructor
|
||||
func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc, markPATUsed MarkPATUsedFunc, audience string, userIdClaim string) *AuthMiddleware {
|
||||
func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
|
||||
markPATUsed MarkPATUsedFunc, getAccountFromToken GetAccountFromTokenFunc, claimsExtractor *jwtclaims.ClaimsExtractor,
|
||||
audience string, userIdClaim string) *AuthMiddleware {
|
||||
if userIdClaim == "" {
|
||||
userIdClaim = jwtclaims.UserIDClaim
|
||||
}
|
||||
|
||||
return &AuthMiddleware{
|
||||
getAccountFromPAT: getAccountFromPAT,
|
||||
validateAndParseToken: validateAndParseToken,
|
||||
markPATUsed: markPATUsed,
|
||||
getAccountFromToken: getAccountFromToken,
|
||||
claimsExtractor: claimsExtractor,
|
||||
audience: audience,
|
||||
userIDClaim: userIdClaim,
|
||||
}
|
||||
@ -107,6 +117,10 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := m.verifyUserAccess(validatedToken); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If we get here, everything worked and we can set the
|
||||
// user property in context.
|
||||
newRequest := r.WithContext(context.WithValue(r.Context(), userProperty, validatedToken)) //nolint
|
||||
@ -115,6 +129,41 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyUserAccess checks if a user, based on a validated JWT token,
|
||||
// is allowed access, particularly in cases where the admin enabled JWT
|
||||
// group propagation and designated certain groups with access permissions.
|
||||
func (m *AuthMiddleware) verifyUserAccess(validatedToken *jwt.Token) error {
|
||||
authClaims := m.claimsExtractor.FromToken(validatedToken)
|
||||
account, _, err := m.getAccountFromToken(authClaims)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get the account from token: %w", err)
|
||||
}
|
||||
|
||||
// Ensures JWT group synchronization to the management is enabled before,
|
||||
// filtering access based on the allowed groups.
|
||||
if account.Settings != nil && account.Settings.JWTGroupsEnabled {
|
||||
if allowedGroups := account.Settings.JWTAllowGroups; len(allowedGroups) > 0 {
|
||||
userJWTGroups := make([]string, 0)
|
||||
|
||||
if claim, ok := authClaims.Raw[account.Settings.JWTGroupsClaimName]; ok {
|
||||
if claimGroups, ok := claim.([]interface{}); ok {
|
||||
for _, g := range claimGroups {
|
||||
if group, ok := g.(string); ok {
|
||||
userJWTGroups = append(userJWTGroups, group)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !userHasAllowedGroup(allowedGroups, userJWTGroups) {
|
||||
return fmt.Errorf("user does not belong to any of the allowed JWT groups")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckPATFromRequest checks if the PAT is valid
|
||||
func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error {
|
||||
token, err := getTokenFromPATRequest(auth)
|
||||
@ -168,3 +217,15 @@ func getTokenFromPATRequest(authHeaderParts []string) (string, error) {
|
||||
|
||||
return authHeaderParts[1], nil
|
||||
}
|
||||
|
||||
// userHasAllowedGroup checks if a user belongs to any of the allowed groups.
|
||||
func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool {
|
||||
for _, userGroup := range userGroups {
|
||||
for _, allowedGroup := range allowedGroups {
|
||||
if userGroup == allowedGroup {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
"github.com/golang-jwt/jwt"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -54,7 +55,13 @@ func mockGetAccountFromPAT(token string) (*server.Account, *server.User, *server
|
||||
|
||||
func mockValidateAndParseToken(token string) (*jwt.Token, error) {
|
||||
if token == JWT {
|
||||
return &jwt.Token{}, nil
|
||||
return &jwt.Token{
|
||||
Claims: jwt.MapClaims{
|
||||
userIDClaim: userID,
|
||||
audience + jwtclaims.AccountIDSuffix: accountID,
|
||||
},
|
||||
Valid: true,
|
||||
}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("JWT invalid")
|
||||
}
|
||||
@ -66,6 +73,19 @@ func mockMarkPATUsed(token string) error {
|
||||
return fmt.Errorf("Should never get reached")
|
||||
}
|
||||
|
||||
func mockGetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
if testAccount.Id != claims.AccountId {
|
||||
return nil, nil, fmt.Errorf("account with id %s does not exist", claims.AccountId)
|
||||
}
|
||||
|
||||
user, ok := testAccount.Users[claims.UserId]
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("user with id %s does not exist", claims.UserId)
|
||||
}
|
||||
|
||||
return testAccount, user, nil
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler(t *testing.T) {
|
||||
tt := []struct {
|
||||
name string
|
||||
@ -108,7 +128,20 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
||||
// do nothing
|
||||
})
|
||||
|
||||
authMiddleware := NewAuthMiddleware(mockGetAccountFromPAT, mockValidateAndParseToken, mockMarkPATUsed, audience, userIDClaim)
|
||||
claimsExtractor := jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(audience),
|
||||
jwtclaims.WithUserIDClaim(userIDClaim),
|
||||
)
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockGetAccountFromPAT,
|
||||
mockValidateAndParseToken,
|
||||
mockMarkPATUsed,
|
||||
mockGetAccountFromToken,
|
||||
claimsExtractor,
|
||||
audience,
|
||||
userIDClaim,
|
||||
)
|
||||
|
||||
handlerToTest := authMiddleware.Handler(nextHandler)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user