refactor access control middleware and user access by JWT groups

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga
2024-09-16 13:33:36 +03:00
parent 1ef51a4ffa
commit 258b30cf48
5 changed files with 51 additions and 18 deletions

View File

@@ -81,6 +81,7 @@ type AccountManager interface {
GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error)
DeleteAccount(ctx context.Context, accountID, userID string) error
MarkPATUsed(ctx context.Context, tokenID string) error
GetUserByID(ctx context.Context, userID string) (*User, error)
GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error)
ListUsers(ctx context.Context, accountID string) ([]*User, error)
GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
@@ -2033,26 +2034,25 @@ func (am *DefaultAccountManager) GetDNSDomain() string {
// CheckUserAccessByJWTGroups checks if the user has access, particularly in cases where the admin enabled JWT
// group propagation and set the list of groups with access permissions.
func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error {
account, _, err := am.GetAccountFromToken(ctx, claims)
accountID := claims.AccountId
if accountID == "" {
user, err := am.GetUserByID(ctx, claims.UserId)
if err != nil {
return fmt.Errorf("failed to retrieve account for user %s: %v", claims.UserId, err)
}
accountID = user.AccountID
}
settings, err := am.Store.GetAccountSettings(ctx, accountID)
if err != nil {
return err
}
// Ensures JWT group synchronization to the management is enabled before,
// filtering access based on the allowed groups.
if account.Settings != nil && account.Settings.JWTGroupsEnabled {
if allowedGroups := account.Settings.JWTAllowGroups; len(allowedGroups) > 0 {
userJWTGroups := make([]string, 0)
if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok {
if claimGroups, ok := claim.([]interface{}); ok {
for _, g := range claimGroups {
if group, ok := g.(string); ok {
userJWTGroups = append(userJWTGroups, group)
}
}
}
}
if settings != nil && settings.JWTGroupsEnabled {
if allowedGroups := settings.JWTAllowGroups; len(allowedGroups) > 0 {
userJWTGroups := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
if !userHasAllowedGroup(allowedGroups, userJWTGroups) {
return fmt.Errorf("user does not belong to any of the allowed JWT groups")
@@ -2185,6 +2185,25 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Ac
return acc
}
// extractJWTGroups extracts the group names from a JWT token's claims.
func extractJWTGroups(ctx context.Context, claimName string, claims jwtclaims.AuthorizationClaims) []string {
userJWTGroups := make([]string, 0)
if claim, ok := claims.Raw[claimName]; ok {
if claimGroups, ok := claim.([]interface{}); ok {
for _, g := range claimGroups {
if group, ok := g.(string); ok {
userJWTGroups = append(userJWTGroups, group)
} else {
log.WithContext(ctx).Debugf("JWT claim %q contains a non-string group (type: %T): %v", claimName, g, g)
}
}
}
}
return userJWTGroups
}
// userHasAllowedGroup checks if a user belongs to any of the allowed groups.
func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool {
for _, userGroup := range userGroups {

View File

@@ -66,7 +66,7 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa
acMiddleware := middleware.NewAccessControl(
authCfg.Audience,
authCfg.UserIDClaim,
accountManager.GetUser)
accountManager.GetUserByID)
rootRouter := mux.NewRouter()
metricsMiddleware := appMetrics.HTTPMiddleware()

View File

@@ -15,8 +15,8 @@ import (
"github.com/netbirdio/netbird/management/server/jwtclaims"
)
// GetUser function defines a function to fetch user from Account by jwtclaims.AuthorizationClaims
type GetUser func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
// GetUser function defines a function to fetch user from Account by user id.
type GetUser func(ctx context.Context, id string) (*server.User, error)
// AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only
type AccessControl struct {
@@ -47,7 +47,7 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
claims := a.claimsExtract.FromRequestContext(r)
user, err := a.getUser(r.Context(), claims)
user, err := a.getUser(r.Context(), claims.UserId)
if err != nil {
log.WithContext(r.Context()).Errorf("failed to get user from claims: %s", err)
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "invalid JWT"), w)

View File

@@ -27,6 +27,7 @@ type MockAccountManager struct {
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error)
GetAccountByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (*server.Account, error)
GetUserByIDFunc func(ctx context.Context, userID string) (*server.User, error)
GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error)
GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
@@ -408,6 +409,14 @@ func (am *MockAccountManager) UpdatePeerMeta(ctx context.Context, peerID string,
return status.Errorf(codes.Unimplemented, "method UpdatePeerMeta is not implemented")
}
// GetUserByID mock implementation of GetUserByID from server.AccountManager interface
func (am *MockAccountManager) GetUserByID(ctx context.Context, userID string) (*server.User, error) {
if am.GetUserByIDFunc != nil {
return am.GetUserByIDFunc(ctx, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetUser is not implemented")
}
// GetUser mock implementation of GetUser from server.AccountManager interface
func (am *MockAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) {
if am.GetUserFunc != nil {

View File

@@ -361,6 +361,11 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
return newUser.ToUserInfo(idpUser, account.Settings)
}
// GetUserByID looks up a user by provided user id.
func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*User, error) {
return am.Store.GetUserByUserID(ctx, id)
}
// GetUser looks up a user by provided authorization claims.
// It will also create an account if didn't exist for this user before.
func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) {