mirror of
https://github.com/netbirdio/netbird.git
synced 2025-02-18 03:01:31 +01:00
Refactor auth middleware
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
parent
0ee56e14d9
commit
2de0777f7a
@ -47,7 +47,7 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa
|
|||||||
)
|
)
|
||||||
|
|
||||||
authMiddleware := middleware.NewAuthMiddleware(
|
authMiddleware := middleware.NewAuthMiddleware(
|
||||||
accountManager.GetAccountFromPAT,
|
accountManager.GetAccountInfoFromPAT,
|
||||||
jwtValidator.ValidateAndParse,
|
jwtValidator.ValidateAndParse,
|
||||||
accountManager.MarkPATUsed,
|
accountManager.MarkPATUsed,
|
||||||
accountManager.CheckUserAccessByJWTGroups,
|
accountManager.CheckUserAccessByJWTGroups,
|
||||||
|
@ -19,8 +19,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GetAccountFromPATFunc function
|
// GetAccountInfoFromPATFunc function
|
||||||
type GetAccountFromPATFunc func(ctx context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
|
type GetAccountInfoFromPATFunc func(ctx context.Context, token string) (user *server.User, pat *server.PersonalAccessToken, domain string, category string, err error)
|
||||||
|
|
||||||
// ValidateAndParseTokenFunc function
|
// ValidateAndParseTokenFunc function
|
||||||
type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error)
|
type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error)
|
||||||
@ -33,7 +33,7 @@ type CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.A
|
|||||||
|
|
||||||
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
|
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
|
||||||
type AuthMiddleware struct {
|
type AuthMiddleware struct {
|
||||||
getAccountFromPAT GetAccountFromPATFunc
|
getAccountInfoFromPAT GetAccountInfoFromPATFunc
|
||||||
validateAndParseToken ValidateAndParseTokenFunc
|
validateAndParseToken ValidateAndParseTokenFunc
|
||||||
markPATUsed MarkPATUsedFunc
|
markPATUsed MarkPATUsedFunc
|
||||||
checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc
|
checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc
|
||||||
@ -47,7 +47,7 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// NewAuthMiddleware instance constructor
|
// NewAuthMiddleware instance constructor
|
||||||
func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
|
func NewAuthMiddleware(getAccountInfoFromPAT GetAccountInfoFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
|
||||||
markPATUsed MarkPATUsedFunc, checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc, claimsExtractor *jwtclaims.ClaimsExtractor,
|
markPATUsed MarkPATUsedFunc, checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc, claimsExtractor *jwtclaims.ClaimsExtractor,
|
||||||
audience string, userIdClaim string) *AuthMiddleware {
|
audience string, userIdClaim string) *AuthMiddleware {
|
||||||
if userIdClaim == "" {
|
if userIdClaim == "" {
|
||||||
@ -55,7 +55,7 @@ func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParse
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &AuthMiddleware{
|
return &AuthMiddleware{
|
||||||
getAccountFromPAT: getAccountFromPAT,
|
getAccountInfoFromPAT: getAccountInfoFromPAT,
|
||||||
validateAndParseToken: validateAndParseToken,
|
validateAndParseToken: validateAndParseToken,
|
||||||
markPATUsed: markPATUsed,
|
markPATUsed: markPATUsed,
|
||||||
checkUserAccessByJWTGroups: checkUserAccessByJWTGroups,
|
checkUserAccessByJWTGroups: checkUserAccessByJWTGroups,
|
||||||
@ -151,13 +151,11 @@ func (m *AuthMiddleware) verifyUserAccess(ctx context.Context, validatedToken *j
|
|||||||
// CheckPATFromRequest checks if the PAT is valid
|
// CheckPATFromRequest checks if the PAT is valid
|
||||||
func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error {
|
func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error {
|
||||||
token, err := getTokenFromPATRequest(auth)
|
token, err := getTokenFromPATRequest(auth)
|
||||||
|
|
||||||
// If an error occurs, call the error handler and return an error
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Error extracting token: %w", err)
|
return fmt.Errorf("error extracting token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
account, user, pat, err := m.getAccountFromPAT(r.Context(), token)
|
user, pat, accDomain, accCategory, err := m.getAccountInfoFromPAT(r.Context(), token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("invalid Token: %w", err)
|
return fmt.Errorf("invalid Token: %w", err)
|
||||||
}
|
}
|
||||||
@ -172,9 +170,9 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ
|
|||||||
|
|
||||||
claimMaps := jwt.MapClaims{}
|
claimMaps := jwt.MapClaims{}
|
||||||
claimMaps[m.userIDClaim] = user.Id
|
claimMaps[m.userIDClaim] = user.Id
|
||||||
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id
|
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = user.AccountID
|
||||||
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain
|
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = accDomain
|
||||||
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory
|
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = accCategory
|
||||||
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
|
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
|
||||||
newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint
|
newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint
|
||||||
// Update the current request with the new context information.
|
// Update the current request with the new context information.
|
||||||
|
@ -34,6 +34,7 @@ var testAccount = &server.Account{
|
|||||||
Users: map[string]*server.User{
|
Users: map[string]*server.User{
|
||||||
userID: {
|
userID: {
|
||||||
Id: userID,
|
Id: userID,
|
||||||
|
AccountID: accountID,
|
||||||
PATs: map[string]*server.PersonalAccessToken{
|
PATs: map[string]*server.PersonalAccessToken{
|
||||||
tokenID: {
|
tokenID: {
|
||||||
ID: tokenID,
|
ID: tokenID,
|
||||||
@ -49,11 +50,11 @@ var testAccount = &server.Account{
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func mockGetAccountFromPAT(_ context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) {
|
func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *server.User, pat *server.PersonalAccessToken, domain string, category string, err error) {
|
||||||
if token == PAT {
|
if token == PAT {
|
||||||
return testAccount, testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], nil
|
return testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], testAccount.Domain, testAccount.DomainCategory, nil
|
||||||
}
|
}
|
||||||
return nil, nil, nil, fmt.Errorf("PAT invalid")
|
return nil, nil, "", "", fmt.Errorf("PAT invalid")
|
||||||
}
|
}
|
||||||
|
|
||||||
func mockValidateAndParseToken(_ context.Context, token string) (*jwt.Token, error) {
|
func mockValidateAndParseToken(_ context.Context, token string) (*jwt.Token, error) {
|
||||||
@ -165,7 +166,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
authMiddleware := NewAuthMiddleware(
|
authMiddleware := NewAuthMiddleware(
|
||||||
mockGetAccountFromPAT,
|
mockGetAccountInfoFromPAT,
|
||||||
mockValidateAndParseToken,
|
mockValidateAndParseToken,
|
||||||
mockMarkPATUsed,
|
mockMarkPATUsed,
|
||||||
mockCheckUserAccessByJWTGroups,
|
mockCheckUserAccessByJWTGroups,
|
||||||
|
Loading…
Reference in New Issue
Block a user