Refactor auth middleware

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga
2024-11-01 14:02:09 +03:00
parent e73b5da42b
commit fed48de83f
9 changed files with 102 additions and 104 deletions

View File

@ -9,9 +9,9 @@ import (
"time"
"github.com/golang-jwt/jwt"
"github.com/netbirdio/netbird/management/server"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server"
nbContext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
"github.com/netbirdio/netbird/management/server/http/util"
@ -19,8 +19,8 @@ import (
"github.com/netbirdio/netbird/management/server/status"
)
// GetAccountFromPATFunc function
type GetAccountFromPATFunc func(ctx context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
// GetAccountInfoFromPATFunc function
type GetAccountInfoFromPATFunc func(ctx context.Context, token string) (user *server.User, pat *server.PersonalAccessToken, domain string, category string, err error)
// ValidateAndParseTokenFunc function
type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error)
@ -33,7 +33,7 @@ type CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.A
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
type AuthMiddleware struct {
getAccountFromPAT GetAccountFromPATFunc
getAccountInfoFromPAT GetAccountInfoFromPATFunc
validateAndParseToken ValidateAndParseTokenFunc
markPATUsed MarkPATUsedFunc
checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc
@ -47,7 +47,7 @@ const (
)
// NewAuthMiddleware instance constructor
func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
func NewAuthMiddleware(getAccountInfoFromPAT GetAccountInfoFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
markPATUsed MarkPATUsedFunc, checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc, claimsExtractor *jwtclaims.ClaimsExtractor,
audience string, userIdClaim string) *AuthMiddleware {
if userIdClaim == "" {
@ -55,7 +55,7 @@ func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParse
}
return &AuthMiddleware{
getAccountFromPAT: getAccountFromPAT,
getAccountInfoFromPAT: getAccountInfoFromPAT,
validateAndParseToken: validateAndParseToken,
markPATUsed: markPATUsed,
checkUserAccessByJWTGroups: checkUserAccessByJWTGroups,
@ -116,7 +116,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ
// If an error occurs, call the error handler and return an error
if err != nil {
return fmt.Errorf("Error extracting token: %w", err)
return fmt.Errorf("error extracting token: %w", err)
}
validatedToken, err := m.validateAndParseToken(r.Context(), token)
@ -151,13 +151,11 @@ func (m *AuthMiddleware) verifyUserAccess(ctx context.Context, validatedToken *j
// CheckPATFromRequest checks if the PAT is valid
func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error {
token, err := getTokenFromPATRequest(auth)
// If an error occurs, call the error handler and return an error
if err != nil {
return fmt.Errorf("Error extracting token: %w", err)
return fmt.Errorf("error extracting token: %w", err)
}
account, user, pat, err := m.getAccountFromPAT(r.Context(), token)
user, pat, accDomain, accCategory, err := m.getAccountInfoFromPAT(r.Context(), token)
if err != nil {
return fmt.Errorf("invalid Token: %w", err)
}
@ -172,9 +170,9 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ
claimMaps := jwt.MapClaims{}
claimMaps[m.userIDClaim] = user.Id
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = user.AccountID
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = accDomain
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = accCategory
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint
// Update the current request with the new context information.