diff --git a/management/server/account.go b/management/server/account.go index be51e745d..27bf5606e 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1123,6 +1123,7 @@ func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) e return nil } +// MarkPATUsed marks a personal access token as used func (am *DefaultAccountManager) MarkPATUsed(tokenID string) error { unlock := am.Store.AcquireGlobalLock() defer unlock() diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index 54889466e..d211a6ef2 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -17,8 +17,13 @@ import ( "github.com/netbirdio/netbird/management/server/status" ) +// GetAccountFromPATFunc function type GetAccountFromPATFunc func(token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) + +// ValidateAndParseTokenFunc function type ValidateAndParseTokenFunc func(token string) (*jwt.Token, error) + +// MarkPATUsedFunc function type MarkPATUsedFunc func(token string) error // AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens @@ -29,8 +34,10 @@ type AuthMiddleware struct { audience string } +type key string + const ( - userProperty = "user" + userProperty key = "user" ) // NewAuthMiddleware instance constructor @@ -44,13 +51,13 @@ func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParse } // Handler method of the middleware which authenticates a user either by JWT claims or by PAT -func (a *AuthMiddleware) Handler(h http.Handler) http.Handler { +func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { auth := strings.Split(r.Header.Get("Authorization"), " ") authType := auth[0] switch strings.ToLower(authType) { case "bearer": - err := a.CheckJWTFromRequest(w, r) + err := m.CheckJWTFromRequest(w, r) if err != nil { log.Debugf("Error when validating JWT claims: %s", err.Error()) util.WriteError(status.Errorf(status.Unauthorized, "Token invalid"), w) @@ -58,7 +65,7 @@ func (a *AuthMiddleware) Handler(h http.Handler) http.Handler { } h.ServeHTTP(w, r) case "token": - err := a.CheckPATFromRequest(w, r) + err := m.CheckPATFromRequest(w, r) if err != nil { log.Debugf("Error when validating PAT claims: %s", err.Error()) util.WriteError(status.Errorf(status.Unauthorized, "Token invalid"), w) @@ -93,7 +100,7 @@ func (m *AuthMiddleware) CheckJWTFromRequest(w http.ResponseWriter, r *http.Requ // If we get here, everything worked and we can set the // user property in context. - newRequest := r.WithContext(context.WithValue(r.Context(), userProperty, validatedToken)) // nolint + newRequest := r.WithContext(context.WithValue(r.Context(), string(userProperty), validatedToken)) // nolint // Update the current request with the new context information. *r = *newRequest return nil diff --git a/management/server/jwtclaims/extractor.go b/management/server/jwtclaims/extractor.go index 5063d7b91..b5f30b8d6 100644 --- a/management/server/jwtclaims/extractor.go +++ b/management/server/jwtclaims/extractor.go @@ -9,11 +9,16 @@ import ( type key string const ( - TokenUserProperty key = "user" - AccountIDSuffix key = "wt_account_id" - DomainIDSuffix key = "wt_account_domain" + // TokenUserProperty key for the user property in the request context + TokenUserProperty key = "user" + // AccountIDSuffix suffix for the account id claim + AccountIDSuffix key = "wt_account_id" + // DomainIDSuffix suffix for the domain id claim + DomainIDSuffix key = "wt_account_domain" + // DomainCategorySuffix suffix for the domain category claim DomainCategorySuffix key = "wt_account_domain_category" - UserIDClaim key = "sub" + // UserIDClaim claim for the user id + UserIDClaim key = "sub" ) // Extract function type diff --git a/management/server/jwtclaims/jwtValidator.go b/management/server/jwtclaims/jwtValidator.go index d324c5ab3..ee9513c57 100644 --- a/management/server/jwtclaims/jwtValidator.go +++ b/management/server/jwtclaims/jwtValidator.go @@ -58,10 +58,12 @@ type JSONWebKey struct { X5c []string `json:"x5c"` } +// JWTValidator struct to handle token validation and parsing type JWTValidator struct { options Options } +// NewJWTValidator constructor func NewJWTValidator(issuer string, audience string, keysLocation string) (*JWTValidator, error) { keys, err := getPemKeys(keysLocation) if err != nil { @@ -102,6 +104,7 @@ func NewJWTValidator(issuer string, audience string, keysLocation string) (*JWTV }, nil } +// ValidateAndParse validates the token and returns the parsed token func (m *JWTValidator) ValidateAndParse(token string) (*jwt.Token, error) { // If the token is empty... if token == "" {