mirror of
https://github.com/netbirdio/netbird.git
synced 2025-04-02 20:36:25 +02:00
Add JWT group-based access control for adding new peers (#1383)
* Added function to check user access by JWT groups in the account management mock server and account manager * Refactor auth middleware for group-based JWT access control * Add group-based JWT access control on adding new peer with JWT * Remove mapping error as the token validation error is already present in grpc error codes * use GetAccountFromToken to prevent single mode issues * handle foreground login message --------- Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
This commit is contained in:
parent
65247de48d
commit
cba3c549e9
@ -151,13 +151,21 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C
|
|||||||
jwtToken = tokenInfo.GetTokenToUse()
|
jwtToken = tokenInfo.GetTokenToUse()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var lastError error
|
||||||
|
|
||||||
err = WithBackOff(func() error {
|
err = WithBackOff(func() error {
|
||||||
err := internal.Login(ctx, config, setupKey, jwtToken)
|
err := internal.Login(ctx, config, setupKey, jwtToken)
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||||
|
lastError = err
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
|
|
||||||
|
if lastError != nil {
|
||||||
|
return fmt.Errorf("login failed: %v", lastError)
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -17,11 +17,12 @@ import (
|
|||||||
|
|
||||||
"github.com/eko/gocache/v3/cache"
|
"github.com/eko/gocache/v3/cache"
|
||||||
cacheStore "github.com/eko/gocache/v3/store"
|
cacheStore "github.com/eko/gocache/v3/store"
|
||||||
"github.com/netbirdio/management-integrations/additions"
|
|
||||||
gocache "github.com/patrickmn/go-cache"
|
gocache "github.com/patrickmn/go-cache"
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/management-integrations/additions"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/base62"
|
"github.com/netbirdio/netbird/base62"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/server/account"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
@ -66,6 +67,7 @@ type AccountManager interface {
|
|||||||
GetSetupKey(accountID, userID, keyID string) (*SetupKey, error)
|
GetSetupKey(accountID, userID, keyID string) (*SetupKey, error)
|
||||||
GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error)
|
GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error)
|
||||||
GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
|
GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
|
||||||
|
CheckUserAccessByJWTGroups(claims jwtclaims.AuthorizationClaims) error
|
||||||
GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error)
|
GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error)
|
||||||
DeleteAccount(accountID, userID string) error
|
DeleteAccount(accountID, userID string) error
|
||||||
MarkPATUsed(tokenID string) error
|
MarkPATUsed(tokenID string) error
|
||||||
@ -1697,6 +1699,39 @@ func (am *DefaultAccountManager) GetDNSDomain() string {
|
|||||||
return am.dnsDomain
|
return am.dnsDomain
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CheckUserAccessByJWTGroups checks if the user has access, particularly in cases where the admin enabled JWT
|
||||||
|
// group propagation and set the list of groups with access permissions.
|
||||||
|
func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(claims jwtclaims.AuthorizationClaims) error {
|
||||||
|
account, _, err := am.GetAccountFromToken(claims)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensures JWT group synchronization to the management is enabled before,
|
||||||
|
// filtering access based on the allowed groups.
|
||||||
|
if account.Settings != nil && account.Settings.JWTGroupsEnabled {
|
||||||
|
if allowedGroups := account.Settings.JWTAllowGroups; len(allowedGroups) > 0 {
|
||||||
|
userJWTGroups := make([]string, 0)
|
||||||
|
|
||||||
|
if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok {
|
||||||
|
if claimGroups, ok := claim.([]interface{}); ok {
|
||||||
|
for _, g := range claimGroups {
|
||||||
|
if group, ok := g.(string); ok {
|
||||||
|
userJWTGroups = append(userJWTGroups, group)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !userHasAllowedGroup(allowedGroups, userJWTGroups) {
|
||||||
|
return fmt.Errorf("user does not belong to any of the allowed JWT groups")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// addAllGroup to account object if it doesn't exists
|
// addAllGroup to account object if it doesn't exists
|
||||||
func addAllGroup(account *Account) error {
|
func addAllGroup(account *Account) error {
|
||||||
if len(account.Groups) == 0 {
|
if len(account.Groups) == 0 {
|
||||||
@ -1768,3 +1803,15 @@ func newAccountWithId(accountID, userID, domain string) *Account {
|
|||||||
}
|
}
|
||||||
return acc
|
return acc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// userHasAllowedGroup checks if a user belongs to any of the allowed groups.
|
||||||
|
func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool {
|
||||||
|
for _, userGroup := range userGroups {
|
||||||
|
for _, allowedGroup := range allowedGroups {
|
||||||
|
if userGroup == allowedGroup {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
@ -220,6 +220,10 @@ func (s *GRPCServer) validateToken(jwtToken string) (string, error) {
|
|||||||
return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err)
|
return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := s.accountManager.CheckUserAccessByJWTGroups(claims); err != nil {
|
||||||
|
return "", status.Errorf(codes.PermissionDenied, err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
return claims.UserId, nil
|
return claims.UserId, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -312,7 +316,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
|
|||||||
userID, err = s.validateToken(loginReq.GetJwtToken())
|
userID, err = s.validateToken(loginReq.GetJwtToken())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed validating JWT token sent from peer %s", peerKey)
|
log.Warnf("failed validating JWT token sent from peer %s", peerKey)
|
||||||
return nil, mapError(err)
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
var sshKey []byte
|
var sshKey []byte
|
||||||
|
@ -43,7 +43,7 @@ func APIHandler(accountManager s.AccountManager, jwtValidator jwtclaims.JWTValid
|
|||||||
accountManager.GetAccountFromPAT,
|
accountManager.GetAccountFromPAT,
|
||||||
jwtValidator.ValidateAndParse,
|
jwtValidator.ValidateAndParse,
|
||||||
accountManager.MarkPATUsed,
|
accountManager.MarkPATUsed,
|
||||||
accountManager.GetAccountFromToken,
|
accountManager.CheckUserAccessByJWTGroups,
|
||||||
claimsExtractor,
|
claimsExtractor,
|
||||||
authCfg.Audience,
|
authCfg.Audience,
|
||||||
authCfg.UserIDClaim,
|
authCfg.UserIDClaim,
|
||||||
|
@ -26,18 +26,18 @@ type ValidateAndParseTokenFunc func(token string) (*jwt.Token, error)
|
|||||||
// MarkPATUsedFunc function
|
// MarkPATUsedFunc function
|
||||||
type MarkPATUsedFunc func(token string) error
|
type MarkPATUsedFunc func(token string) error
|
||||||
|
|
||||||
// GetAccountFromTokenFunc function
|
// CheckUserAccessByJWTGroupsFunc function
|
||||||
type GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error)
|
type CheckUserAccessByJWTGroupsFunc func(claims jwtclaims.AuthorizationClaims) error
|
||||||
|
|
||||||
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
|
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
|
||||||
type AuthMiddleware struct {
|
type AuthMiddleware struct {
|
||||||
getAccountFromPAT GetAccountFromPATFunc
|
getAccountFromPAT GetAccountFromPATFunc
|
||||||
validateAndParseToken ValidateAndParseTokenFunc
|
validateAndParseToken ValidateAndParseTokenFunc
|
||||||
markPATUsed MarkPATUsedFunc
|
markPATUsed MarkPATUsedFunc
|
||||||
getAccountFromToken GetAccountFromTokenFunc
|
checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc
|
||||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||||
audience string
|
audience string
|
||||||
userIDClaim string
|
userIDClaim string
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -46,20 +46,20 @@ const (
|
|||||||
|
|
||||||
// NewAuthMiddleware instance constructor
|
// NewAuthMiddleware instance constructor
|
||||||
func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
|
func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
|
||||||
markPATUsed MarkPATUsedFunc, getAccountFromToken GetAccountFromTokenFunc, claimsExtractor *jwtclaims.ClaimsExtractor,
|
markPATUsed MarkPATUsedFunc, checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc, claimsExtractor *jwtclaims.ClaimsExtractor,
|
||||||
audience string, userIdClaim string) *AuthMiddleware {
|
audience string, userIdClaim string) *AuthMiddleware {
|
||||||
if userIdClaim == "" {
|
if userIdClaim == "" {
|
||||||
userIdClaim = jwtclaims.UserIDClaim
|
userIdClaim = jwtclaims.UserIDClaim
|
||||||
}
|
}
|
||||||
|
|
||||||
return &AuthMiddleware{
|
return &AuthMiddleware{
|
||||||
getAccountFromPAT: getAccountFromPAT,
|
getAccountFromPAT: getAccountFromPAT,
|
||||||
validateAndParseToken: validateAndParseToken,
|
validateAndParseToken: validateAndParseToken,
|
||||||
markPATUsed: markPATUsed,
|
markPATUsed: markPATUsed,
|
||||||
getAccountFromToken: getAccountFromToken,
|
checkUserAccessByJWTGroups: checkUserAccessByJWTGroups,
|
||||||
claimsExtractor: claimsExtractor,
|
claimsExtractor: claimsExtractor,
|
||||||
audience: audience,
|
audience: audience,
|
||||||
userIDClaim: userIdClaim,
|
userIDClaim: userIdClaim,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -134,34 +134,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ
|
|||||||
// group propagation and designated certain groups with access permissions.
|
// group propagation and designated certain groups with access permissions.
|
||||||
func (m *AuthMiddleware) verifyUserAccess(validatedToken *jwt.Token) error {
|
func (m *AuthMiddleware) verifyUserAccess(validatedToken *jwt.Token) error {
|
||||||
authClaims := m.claimsExtractor.FromToken(validatedToken)
|
authClaims := m.claimsExtractor.FromToken(validatedToken)
|
||||||
account, _, err := m.getAccountFromToken(authClaims)
|
return m.checkUserAccessByJWTGroups(authClaims)
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to get the account from token: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensures JWT group synchronization to the management is enabled before,
|
|
||||||
// filtering access based on the allowed groups.
|
|
||||||
if account.Settings != nil && account.Settings.JWTGroupsEnabled {
|
|
||||||
if allowedGroups := account.Settings.JWTAllowGroups; len(allowedGroups) > 0 {
|
|
||||||
userJWTGroups := make([]string, 0)
|
|
||||||
|
|
||||||
if claim, ok := authClaims.Raw[account.Settings.JWTGroupsClaimName]; ok {
|
|
||||||
if claimGroups, ok := claim.([]interface{}); ok {
|
|
||||||
for _, g := range claimGroups {
|
|
||||||
if group, ok := g.(string); ok {
|
|
||||||
userJWTGroups = append(userJWTGroups, group)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !userHasAllowedGroup(allowedGroups, userJWTGroups) {
|
|
||||||
return fmt.Errorf("user does not belong to any of the allowed JWT groups")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// CheckPATFromRequest checks if the PAT is valid
|
// CheckPATFromRequest checks if the PAT is valid
|
||||||
@ -217,15 +190,3 @@ func getTokenFromPATRequest(authHeaderParts []string) (string, error) {
|
|||||||
|
|
||||||
return authHeaderParts[1], nil
|
return authHeaderParts[1], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// userHasAllowedGroup checks if a user belongs to any of the allowed groups.
|
|
||||||
func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool {
|
|
||||||
for _, userGroup := range userGroups {
|
|
||||||
for _, allowedGroup := range allowedGroups {
|
|
||||||
if userGroup == allowedGroup {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
@ -73,17 +73,16 @@ func mockMarkPATUsed(token string) error {
|
|||||||
return fmt.Errorf("Should never get reached")
|
return fmt.Errorf("Should never get reached")
|
||||||
}
|
}
|
||||||
|
|
||||||
func mockGetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
func mockCheckUserAccessByJWTGroups(claims jwtclaims.AuthorizationClaims) error {
|
||||||
if testAccount.Id != claims.AccountId {
|
if testAccount.Id != claims.AccountId {
|
||||||
return nil, nil, fmt.Errorf("account with id %s does not exist", claims.AccountId)
|
return fmt.Errorf("account with id %s does not exist", claims.AccountId)
|
||||||
}
|
}
|
||||||
|
|
||||||
user, ok := testAccount.Users[claims.UserId]
|
if _, ok := testAccount.Users[claims.UserId]; !ok {
|
||||||
if !ok {
|
return fmt.Errorf("user with id %s does not exist", claims.UserId)
|
||||||
return nil, nil, fmt.Errorf("user with id %s does not exist", claims.UserId)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return testAccount, user, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAuthMiddleware_Handler(t *testing.T) {
|
func TestAuthMiddleware_Handler(t *testing.T) {
|
||||||
@ -137,7 +136,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
|||||||
mockGetAccountFromPAT,
|
mockGetAccountFromPAT,
|
||||||
mockValidateAndParseToken,
|
mockValidateAndParseToken,
|
||||||
mockMarkPATUsed,
|
mockMarkPATUsed,
|
||||||
mockGetAccountFromToken,
|
mockCheckUserAccessByJWTGroups,
|
||||||
claimsExtractor,
|
claimsExtractor,
|
||||||
audience,
|
audience,
|
||||||
userIDClaim,
|
userIDClaim,
|
||||||
|
@ -69,6 +69,7 @@ type MockAccountManager struct {
|
|||||||
ListNameServerGroupsFunc func(accountID string) ([]*nbdns.NameServerGroup, error)
|
ListNameServerGroupsFunc func(accountID string) ([]*nbdns.NameServerGroup, error)
|
||||||
CreateUserFunc func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error)
|
CreateUserFunc func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error)
|
||||||
GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error)
|
GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error)
|
||||||
|
CheckUserAccessByJWTGroupsFunc func(claims jwtclaims.AuthorizationClaims) error
|
||||||
DeleteAccountFunc func(accountID, userID string) error
|
DeleteAccountFunc func(accountID, userID string) error
|
||||||
GetDNSDomainFunc func() string
|
GetDNSDomainFunc func() string
|
||||||
StoreEventFunc func(initiatorID, targetID, accountID string, activityID activity.Activity, meta map[string]any)
|
StoreEventFunc func(initiatorID, targetID, accountID string, activityID activity.Activity, meta map[string]any)
|
||||||
@ -543,6 +544,13 @@ func (am *MockAccountManager) GetAccountFromToken(claims jwtclaims.Authorization
|
|||||||
return nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromToken is not implemented")
|
return nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromToken is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (am *MockAccountManager) CheckUserAccessByJWTGroups(claims jwtclaims.AuthorizationClaims) error {
|
||||||
|
if am.CheckUserAccessByJWTGroupsFunc != nil {
|
||||||
|
return am.CheckUserAccessByJWTGroupsFunc(claims)
|
||||||
|
}
|
||||||
|
return status.Errorf(codes.Unimplemented, "method CheckUserAccessByJWTGroups is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
// GetPeers mocks GetPeers of the AccountManager interface
|
// GetPeers mocks GetPeers of the AccountManager interface
|
||||||
func (am *MockAccountManager) GetPeers(accountID, userID string) ([]*nbpeer.Peer, error) {
|
func (am *MockAccountManager) GetPeers(accountID, userID string) ([]*nbpeer.Peer, error) {
|
||||||
if am.GetPeersFunc != nil {
|
if am.GetPeersFunc != nil {
|
||||||
|
Loading…
Reference in New Issue
Block a user