mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-18 11:00:06 +02:00
refactor access control middleware and user access by JWT groups
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
@@ -81,6 +81,7 @@ type AccountManager interface {
|
||||
GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error)
|
||||
DeleteAccount(ctx context.Context, accountID, userID string) error
|
||||
MarkPATUsed(ctx context.Context, tokenID string) error
|
||||
GetUserByID(ctx context.Context, userID string) (*User, error)
|
||||
GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error)
|
||||
ListUsers(ctx context.Context, accountID string) ([]*User, error)
|
||||
GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
|
||||
@@ -2033,26 +2034,25 @@ func (am *DefaultAccountManager) GetDNSDomain() string {
|
||||
// CheckUserAccessByJWTGroups checks if the user has access, particularly in cases where the admin enabled JWT
|
||||
// group propagation and set the list of groups with access permissions.
|
||||
func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error {
|
||||
account, _, err := am.GetAccountFromToken(ctx, claims)
|
||||
accountID := claims.AccountId
|
||||
if accountID == "" {
|
||||
user, err := am.GetUserByID(ctx, claims.UserId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to retrieve account for user %s: %v", claims.UserId, err)
|
||||
}
|
||||
accountID = user.AccountID
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensures JWT group synchronization to the management is enabled before,
|
||||
// filtering access based on the allowed groups.
|
||||
if account.Settings != nil && account.Settings.JWTGroupsEnabled {
|
||||
if allowedGroups := account.Settings.JWTAllowGroups; len(allowedGroups) > 0 {
|
||||
userJWTGroups := make([]string, 0)
|
||||
|
||||
if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok {
|
||||
if claimGroups, ok := claim.([]interface{}); ok {
|
||||
for _, g := range claimGroups {
|
||||
if group, ok := g.(string); ok {
|
||||
userJWTGroups = append(userJWTGroups, group)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if settings != nil && settings.JWTGroupsEnabled {
|
||||
if allowedGroups := settings.JWTAllowGroups; len(allowedGroups) > 0 {
|
||||
userJWTGroups := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
|
||||
|
||||
if !userHasAllowedGroup(allowedGroups, userJWTGroups) {
|
||||
return fmt.Errorf("user does not belong to any of the allowed JWT groups")
|
||||
@@ -2185,6 +2185,25 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Ac
|
||||
return acc
|
||||
}
|
||||
|
||||
// extractJWTGroups extracts the group names from a JWT token's claims.
|
||||
func extractJWTGroups(ctx context.Context, claimName string, claims jwtclaims.AuthorizationClaims) []string {
|
||||
userJWTGroups := make([]string, 0)
|
||||
|
||||
if claim, ok := claims.Raw[claimName]; ok {
|
||||
if claimGroups, ok := claim.([]interface{}); ok {
|
||||
for _, g := range claimGroups {
|
||||
if group, ok := g.(string); ok {
|
||||
userJWTGroups = append(userJWTGroups, group)
|
||||
} else {
|
||||
log.WithContext(ctx).Debugf("JWT claim %q contains a non-string group (type: %T): %v", claimName, g, g)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return userJWTGroups
|
||||
}
|
||||
|
||||
// userHasAllowedGroup checks if a user belongs to any of the allowed groups.
|
||||
func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool {
|
||||
for _, userGroup := range userGroups {
|
||||
|
@@ -66,7 +66,7 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa
|
||||
acMiddleware := middleware.NewAccessControl(
|
||||
authCfg.Audience,
|
||||
authCfg.UserIDClaim,
|
||||
accountManager.GetUser)
|
||||
accountManager.GetUserByID)
|
||||
|
||||
rootRouter := mux.NewRouter()
|
||||
metricsMiddleware := appMetrics.HTTPMiddleware()
|
||||
|
@@ -15,8 +15,8 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
)
|
||||
|
||||
// GetUser function defines a function to fetch user from Account by jwtclaims.AuthorizationClaims
|
||||
type GetUser func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
||||
// GetUser function defines a function to fetch user from Account by user id.
|
||||
type GetUser func(ctx context.Context, id string) (*server.User, error)
|
||||
|
||||
// AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only
|
||||
type AccessControl struct {
|
||||
@@ -47,7 +47,7 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
|
||||
|
||||
claims := a.claimsExtract.FromRequestContext(r)
|
||||
|
||||
user, err := a.getUser(r.Context(), claims)
|
||||
user, err := a.getUser(r.Context(), claims.UserId)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed to get user from claims: %s", err)
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "invalid JWT"), w)
|
||||
|
@@ -27,6 +27,7 @@ type MockAccountManager struct {
|
||||
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
|
||||
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error)
|
||||
GetAccountByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (*server.Account, error)
|
||||
GetUserByIDFunc func(ctx context.Context, userID string) (*server.User, error)
|
||||
GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
||||
ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error)
|
||||
GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
|
||||
@@ -408,6 +409,14 @@ func (am *MockAccountManager) UpdatePeerMeta(ctx context.Context, peerID string,
|
||||
return status.Errorf(codes.Unimplemented, "method UpdatePeerMeta is not implemented")
|
||||
}
|
||||
|
||||
// GetUserByID mock implementation of GetUserByID from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetUserByID(ctx context.Context, userID string) (*server.User, error) {
|
||||
if am.GetUserByIDFunc != nil {
|
||||
return am.GetUserByIDFunc(ctx, userID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetUser is not implemented")
|
||||
}
|
||||
|
||||
// GetUser mock implementation of GetUser from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) {
|
||||
if am.GetUserFunc != nil {
|
||||
|
@@ -361,6 +361,11 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
|
||||
return newUser.ToUserInfo(idpUser, account.Settings)
|
||||
}
|
||||
|
||||
// GetUserByID looks up a user by provided user id.
|
||||
func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*User, error) {
|
||||
return am.Store.GetUserByUserID(ctx, id)
|
||||
}
|
||||
|
||||
// GetUser looks up a user by provided authorization claims.
|
||||
// It will also create an account if didn't exist for this user before.
|
||||
func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) {
|
||||
|
Reference in New Issue
Block a user