mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-26 22:06:21 +02:00
refactor access control middleware and user access by JWT groups
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
@@ -81,6 +81,7 @@ type AccountManager interface {
|
|||||||
GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error)
|
GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error)
|
||||||
DeleteAccount(ctx context.Context, accountID, userID string) error
|
DeleteAccount(ctx context.Context, accountID, userID string) error
|
||||||
MarkPATUsed(ctx context.Context, tokenID string) error
|
MarkPATUsed(ctx context.Context, tokenID string) error
|
||||||
|
GetUserByID(ctx context.Context, userID string) (*User, error)
|
||||||
GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error)
|
GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error)
|
||||||
ListUsers(ctx context.Context, accountID string) ([]*User, error)
|
ListUsers(ctx context.Context, accountID string) ([]*User, error)
|
||||||
GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
|
GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
|
||||||
@@ -2033,26 +2034,25 @@ func (am *DefaultAccountManager) GetDNSDomain() string {
|
|||||||
// CheckUserAccessByJWTGroups checks if the user has access, particularly in cases where the admin enabled JWT
|
// CheckUserAccessByJWTGroups checks if the user has access, particularly in cases where the admin enabled JWT
|
||||||
// group propagation and set the list of groups with access permissions.
|
// group propagation and set the list of groups with access permissions.
|
||||||
func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error {
|
func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error {
|
||||||
account, _, err := am.GetAccountFromToken(ctx, claims)
|
accountID := claims.AccountId
|
||||||
|
if accountID == "" {
|
||||||
|
user, err := am.GetUserByID(ctx, claims.UserId)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to retrieve account for user %s: %v", claims.UserId, err)
|
||||||
|
}
|
||||||
|
accountID = user.AccountID
|
||||||
|
}
|
||||||
|
|
||||||
|
settings, err := am.Store.GetAccountSettings(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensures JWT group synchronization to the management is enabled before,
|
// Ensures JWT group synchronization to the management is enabled before,
|
||||||
// filtering access based on the allowed groups.
|
// filtering access based on the allowed groups.
|
||||||
if account.Settings != nil && account.Settings.JWTGroupsEnabled {
|
if settings != nil && settings.JWTGroupsEnabled {
|
||||||
if allowedGroups := account.Settings.JWTAllowGroups; len(allowedGroups) > 0 {
|
if allowedGroups := settings.JWTAllowGroups; len(allowedGroups) > 0 {
|
||||||
userJWTGroups := make([]string, 0)
|
userJWTGroups := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
|
||||||
|
|
||||||
if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok {
|
|
||||||
if claimGroups, ok := claim.([]interface{}); ok {
|
|
||||||
for _, g := range claimGroups {
|
|
||||||
if group, ok := g.(string); ok {
|
|
||||||
userJWTGroups = append(userJWTGroups, group)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !userHasAllowedGroup(allowedGroups, userJWTGroups) {
|
if !userHasAllowedGroup(allowedGroups, userJWTGroups) {
|
||||||
return fmt.Errorf("user does not belong to any of the allowed JWT groups")
|
return fmt.Errorf("user does not belong to any of the allowed JWT groups")
|
||||||
@@ -2185,6 +2185,25 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Ac
|
|||||||
return acc
|
return acc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// extractJWTGroups extracts the group names from a JWT token's claims.
|
||||||
|
func extractJWTGroups(ctx context.Context, claimName string, claims jwtclaims.AuthorizationClaims) []string {
|
||||||
|
userJWTGroups := make([]string, 0)
|
||||||
|
|
||||||
|
if claim, ok := claims.Raw[claimName]; ok {
|
||||||
|
if claimGroups, ok := claim.([]interface{}); ok {
|
||||||
|
for _, g := range claimGroups {
|
||||||
|
if group, ok := g.(string); ok {
|
||||||
|
userJWTGroups = append(userJWTGroups, group)
|
||||||
|
} else {
|
||||||
|
log.WithContext(ctx).Debugf("JWT claim %q contains a non-string group (type: %T): %v", claimName, g, g)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return userJWTGroups
|
||||||
|
}
|
||||||
|
|
||||||
// userHasAllowedGroup checks if a user belongs to any of the allowed groups.
|
// userHasAllowedGroup checks if a user belongs to any of the allowed groups.
|
||||||
func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool {
|
func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool {
|
||||||
for _, userGroup := range userGroups {
|
for _, userGroup := range userGroups {
|
||||||
|
@@ -66,7 +66,7 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa
|
|||||||
acMiddleware := middleware.NewAccessControl(
|
acMiddleware := middleware.NewAccessControl(
|
||||||
authCfg.Audience,
|
authCfg.Audience,
|
||||||
authCfg.UserIDClaim,
|
authCfg.UserIDClaim,
|
||||||
accountManager.GetUser)
|
accountManager.GetUserByID)
|
||||||
|
|
||||||
rootRouter := mux.NewRouter()
|
rootRouter := mux.NewRouter()
|
||||||
metricsMiddleware := appMetrics.HTTPMiddleware()
|
metricsMiddleware := appMetrics.HTTPMiddleware()
|
||||||
|
@@ -15,8 +15,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GetUser function defines a function to fetch user from Account by jwtclaims.AuthorizationClaims
|
// GetUser function defines a function to fetch user from Account by user id.
|
||||||
type GetUser func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
type GetUser func(ctx context.Context, id string) (*server.User, error)
|
||||||
|
|
||||||
// AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only
|
// AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only
|
||||||
type AccessControl struct {
|
type AccessControl struct {
|
||||||
@@ -47,7 +47,7 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
|
|||||||
|
|
||||||
claims := a.claimsExtract.FromRequestContext(r)
|
claims := a.claimsExtract.FromRequestContext(r)
|
||||||
|
|
||||||
user, err := a.getUser(r.Context(), claims)
|
user, err := a.getUser(r.Context(), claims.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(r.Context()).Errorf("failed to get user from claims: %s", err)
|
log.WithContext(r.Context()).Errorf("failed to get user from claims: %s", err)
|
||||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "invalid JWT"), w)
|
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "invalid JWT"), w)
|
||||||
|
@@ -27,6 +27,7 @@ type MockAccountManager struct {
|
|||||||
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
|
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
|
||||||
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error)
|
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error)
|
||||||
GetAccountByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (*server.Account, error)
|
GetAccountByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (*server.Account, error)
|
||||||
|
GetUserByIDFunc func(ctx context.Context, userID string) (*server.User, error)
|
||||||
GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
||||||
ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error)
|
ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error)
|
||||||
GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
|
GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
|
||||||
@@ -408,6 +409,14 @@ func (am *MockAccountManager) UpdatePeerMeta(ctx context.Context, peerID string,
|
|||||||
return status.Errorf(codes.Unimplemented, "method UpdatePeerMeta is not implemented")
|
return status.Errorf(codes.Unimplemented, "method UpdatePeerMeta is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetUserByID mock implementation of GetUserByID from server.AccountManager interface
|
||||||
|
func (am *MockAccountManager) GetUserByID(ctx context.Context, userID string) (*server.User, error) {
|
||||||
|
if am.GetUserByIDFunc != nil {
|
||||||
|
return am.GetUserByIDFunc(ctx, userID)
|
||||||
|
}
|
||||||
|
return nil, status.Errorf(codes.Unimplemented, "method GetUser is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
// GetUser mock implementation of GetUser from server.AccountManager interface
|
// GetUser mock implementation of GetUser from server.AccountManager interface
|
||||||
func (am *MockAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) {
|
func (am *MockAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) {
|
||||||
if am.GetUserFunc != nil {
|
if am.GetUserFunc != nil {
|
||||||
|
@@ -361,6 +361,11 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
|
|||||||
return newUser.ToUserInfo(idpUser, account.Settings)
|
return newUser.ToUserInfo(idpUser, account.Settings)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetUserByID looks up a user by provided user id.
|
||||||
|
func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*User, error) {
|
||||||
|
return am.Store.GetUserByUserID(ctx, id)
|
||||||
|
}
|
||||||
|
|
||||||
// GetUser looks up a user by provided authorization claims.
|
// GetUser looks up a user by provided authorization claims.
|
||||||
// It will also create an account if didn't exist for this user before.
|
// It will also create an account if didn't exist for this user before.
|
||||||
func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) {
|
func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) {
|
||||||
|
Reference in New Issue
Block a user