mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-25 17:43:38 +01:00
Refactor auth middleware
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
parent
0ee56e14d9
commit
2de0777f7a
@ -47,7 +47,7 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa
|
||||
)
|
||||
|
||||
authMiddleware := middleware.NewAuthMiddleware(
|
||||
accountManager.GetAccountFromPAT,
|
||||
accountManager.GetAccountInfoFromPAT,
|
||||
jwtValidator.ValidateAndParse,
|
||||
accountManager.MarkPATUsed,
|
||||
accountManager.CheckUserAccessByJWTGroups,
|
||||
|
@ -19,8 +19,8 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
)
|
||||
|
||||
// GetAccountFromPATFunc function
|
||||
type GetAccountFromPATFunc func(ctx context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
|
||||
// GetAccountInfoFromPATFunc function
|
||||
type GetAccountInfoFromPATFunc func(ctx context.Context, token string) (user *server.User, pat *server.PersonalAccessToken, domain string, category string, err error)
|
||||
|
||||
// ValidateAndParseTokenFunc function
|
||||
type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error)
|
||||
@ -33,7 +33,7 @@ type CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.A
|
||||
|
||||
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
|
||||
type AuthMiddleware struct {
|
||||
getAccountFromPAT GetAccountFromPATFunc
|
||||
getAccountInfoFromPAT GetAccountInfoFromPATFunc
|
||||
validateAndParseToken ValidateAndParseTokenFunc
|
||||
markPATUsed MarkPATUsedFunc
|
||||
checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc
|
||||
@ -47,7 +47,7 @@ const (
|
||||
)
|
||||
|
||||
// NewAuthMiddleware instance constructor
|
||||
func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
|
||||
func NewAuthMiddleware(getAccountInfoFromPAT GetAccountInfoFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
|
||||
markPATUsed MarkPATUsedFunc, checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc, claimsExtractor *jwtclaims.ClaimsExtractor,
|
||||
audience string, userIdClaim string) *AuthMiddleware {
|
||||
if userIdClaim == "" {
|
||||
@ -55,7 +55,7 @@ func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParse
|
||||
}
|
||||
|
||||
return &AuthMiddleware{
|
||||
getAccountFromPAT: getAccountFromPAT,
|
||||
getAccountInfoFromPAT: getAccountInfoFromPAT,
|
||||
validateAndParseToken: validateAndParseToken,
|
||||
markPATUsed: markPATUsed,
|
||||
checkUserAccessByJWTGroups: checkUserAccessByJWTGroups,
|
||||
@ -151,13 +151,11 @@ func (m *AuthMiddleware) verifyUserAccess(ctx context.Context, validatedToken *j
|
||||
// CheckPATFromRequest checks if the PAT is valid
|
||||
func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error {
|
||||
token, err := getTokenFromPATRequest(auth)
|
||||
|
||||
// If an error occurs, call the error handler and return an error
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error extracting token: %w", err)
|
||||
return fmt.Errorf("error extracting token: %w", err)
|
||||
}
|
||||
|
||||
account, user, pat, err := m.getAccountFromPAT(r.Context(), token)
|
||||
user, pat, accDomain, accCategory, err := m.getAccountInfoFromPAT(r.Context(), token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid Token: %w", err)
|
||||
}
|
||||
@ -172,9 +170,9 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ
|
||||
|
||||
claimMaps := jwt.MapClaims{}
|
||||
claimMaps[m.userIDClaim] = user.Id
|
||||
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id
|
||||
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain
|
||||
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory
|
||||
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = user.AccountID
|
||||
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = accDomain
|
||||
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = accCategory
|
||||
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
|
||||
newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint
|
||||
// Update the current request with the new context information.
|
||||
|
@ -34,6 +34,7 @@ var testAccount = &server.Account{
|
||||
Users: map[string]*server.User{
|
||||
userID: {
|
||||
Id: userID,
|
||||
AccountID: accountID,
|
||||
PATs: map[string]*server.PersonalAccessToken{
|
||||
tokenID: {
|
||||
ID: tokenID,
|
||||
@ -49,11 +50,11 @@ var testAccount = &server.Account{
|
||||
},
|
||||
}
|
||||
|
||||
func mockGetAccountFromPAT(_ context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) {
|
||||
func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *server.User, pat *server.PersonalAccessToken, domain string, category string, err error) {
|
||||
if token == PAT {
|
||||
return testAccount, testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], nil
|
||||
return testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], testAccount.Domain, testAccount.DomainCategory, nil
|
||||
}
|
||||
return nil, nil, nil, fmt.Errorf("PAT invalid")
|
||||
return nil, nil, "", "", fmt.Errorf("PAT invalid")
|
||||
}
|
||||
|
||||
func mockValidateAndParseToken(_ context.Context, token string) (*jwt.Token, error) {
|
||||
@ -165,7 +166,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
||||
)
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockGetAccountFromPAT,
|
||||
mockGetAccountInfoFromPAT,
|
||||
mockValidateAndParseToken,
|
||||
mockMarkPATUsed,
|
||||
mockCheckUserAccessByJWTGroups,
|
||||
|
Loading…
Reference in New Issue
Block a user