From 2de0777f7af71c06f27253bbd0da0afd9f9bf8ae Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Tue, 19 Nov 2024 23:33:46 +0300 Subject: [PATCH] Refactor auth middleware Signed-off-by: bcmmbaga --- management/server/http/handler.go | 2 +- .../server/http/middleware/auth_middleware.go | 22 +++++++++---------- .../http/middleware/auth_middleware_test.go | 11 +++++----- 3 files changed, 17 insertions(+), 18 deletions(-) diff --git a/management/server/http/handler.go b/management/server/http/handler.go index c3928bff6..bb6d00209 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -47,7 +47,7 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa ) authMiddleware := middleware.NewAuthMiddleware( - accountManager.GetAccountFromPAT, + accountManager.GetAccountInfoFromPAT, jwtValidator.ValidateAndParse, accountManager.MarkPATUsed, accountManager.CheckUserAccessByJWTGroups, diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index b25aad99c..006a7872c 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -19,8 +19,8 @@ import ( "github.com/netbirdio/netbird/management/server/status" ) -// GetAccountFromPATFunc function -type GetAccountFromPATFunc func(ctx context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) +// GetAccountInfoFromPATFunc function +type GetAccountInfoFromPATFunc func(ctx context.Context, token string) (user *server.User, pat *server.PersonalAccessToken, domain string, category string, err error) // ValidateAndParseTokenFunc function type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error) @@ -33,7 +33,7 @@ type CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.A // AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens type AuthMiddleware struct { - getAccountFromPAT GetAccountFromPATFunc + getAccountInfoFromPAT GetAccountInfoFromPATFunc validateAndParseToken ValidateAndParseTokenFunc markPATUsed MarkPATUsedFunc checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc @@ -47,7 +47,7 @@ const ( ) // NewAuthMiddleware instance constructor -func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc, +func NewAuthMiddleware(getAccountInfoFromPAT GetAccountInfoFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc, markPATUsed MarkPATUsedFunc, checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc, claimsExtractor *jwtclaims.ClaimsExtractor, audience string, userIdClaim string) *AuthMiddleware { if userIdClaim == "" { @@ -55,7 +55,7 @@ func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParse } return &AuthMiddleware{ - getAccountFromPAT: getAccountFromPAT, + getAccountInfoFromPAT: getAccountInfoFromPAT, validateAndParseToken: validateAndParseToken, markPATUsed: markPATUsed, checkUserAccessByJWTGroups: checkUserAccessByJWTGroups, @@ -151,13 +151,11 @@ func (m *AuthMiddleware) verifyUserAccess(ctx context.Context, validatedToken *j // CheckPATFromRequest checks if the PAT is valid func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error { token, err := getTokenFromPATRequest(auth) - - // If an error occurs, call the error handler and return an error if err != nil { - return fmt.Errorf("Error extracting token: %w", err) + return fmt.Errorf("error extracting token: %w", err) } - account, user, pat, err := m.getAccountFromPAT(r.Context(), token) + user, pat, accDomain, accCategory, err := m.getAccountInfoFromPAT(r.Context(), token) if err != nil { return fmt.Errorf("invalid Token: %w", err) } @@ -172,9 +170,9 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ claimMaps := jwt.MapClaims{} claimMaps[m.userIDClaim] = user.Id - claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id - claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain - claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory + claimMaps[m.audience+jwtclaims.AccountIDSuffix] = user.AccountID + claimMaps[m.audience+jwtclaims.DomainIDSuffix] = accDomain + claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = accCategory jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint // Update the current request with the new context information. diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index fdfb0ea24..0e0872d31 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -33,7 +33,8 @@ var testAccount = &server.Account{ Domain: domain, Users: map[string]*server.User{ userID: { - Id: userID, + Id: userID, + AccountID: accountID, PATs: map[string]*server.PersonalAccessToken{ tokenID: { ID: tokenID, @@ -49,11 +50,11 @@ var testAccount = &server.Account{ }, } -func mockGetAccountFromPAT(_ context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) { +func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *server.User, pat *server.PersonalAccessToken, domain string, category string, err error) { if token == PAT { - return testAccount, testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], nil + return testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], testAccount.Domain, testAccount.DomainCategory, nil } - return nil, nil, nil, fmt.Errorf("PAT invalid") + return nil, nil, "", "", fmt.Errorf("PAT invalid") } func mockValidateAndParseToken(_ context.Context, token string) (*jwt.Token, error) { @@ -165,7 +166,7 @@ func TestAuthMiddleware_Handler(t *testing.T) { ) authMiddleware := NewAuthMiddleware( - mockGetAccountFromPAT, + mockGetAccountInfoFromPAT, mockValidateAndParseToken, mockMarkPATUsed, mockCheckUserAccessByJWTGroups,