Merge pull request #774 from netbirdio/feature/add_pat_middleware

Feature/add pat middleware
This commit is contained in:
pascal-fischer 2023-04-03 12:09:11 +02:00 committed by GitHub
commit c54fb9643c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 537 additions and 295 deletions

View File

@ -19,25 +19,28 @@ import (
"github.com/google/uuid"
"github.com/miekg/dns"
"github.com/netbirdio/netbird/management/server/activity/sqlite"
httpapi "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/metrics"
"github.com/netbirdio/netbird/management/server/telemetry"
"golang.org/x/crypto/acme/autocert"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
"github.com/netbirdio/netbird/management/server/activity/sqlite"
httpapi "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/metrics"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/encryption"
mgmtProto "github.com/netbirdio/netbird/management/proto"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/encryption"
mgmtProto "github.com/netbirdio/netbird/management/proto"
)
// ManagementLegacyPort is the port that was used before by the Management gRPC server.
@ -179,13 +182,22 @@ var (
tlsEnabled = true
}
jwtValidator, err := jwtclaims.NewJWTValidator(
config.HttpConfig.AuthIssuer,
config.HttpConfig.AuthAudience,
config.HttpConfig.AuthKeysLocation,
)
if err != nil {
return fmt.Errorf("failed creating JWT validator: %v", err)
}
httpAPIAuthCfg := httpapi.AuthCfg{
Issuer: config.HttpConfig.AuthIssuer,
Audience: config.HttpConfig.AuthAudience,
UserIDClaim: config.HttpConfig.AuthUserIDClaim,
KeysLocation: config.HttpConfig.AuthKeysLocation,
}
httpAPIHandler, err := httpapi.APIHandler(accountManager, appMetrics, httpAPIAuthCfg)
httpAPIHandler, err := httpapi.APIHandler(accountManager, *jwtValidator, appMetrics, httpAPIAuthCfg)
if err != nil {
return fmt.Errorf("failed creating HTTP API handler: %v", err)
}

View File

@ -56,6 +56,7 @@ type AccountManager interface {
GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error)
GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error)
MarkPATUsed(tokenID string) error
IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error)
AccountExists(accountId string) (*bool, error)
GetPeerByKey(peerKey string) (*Peer, error)
@ -1122,6 +1123,39 @@ func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) e
return nil
}
// MarkPATUsed marks a personal access token as used
func (am *DefaultAccountManager) MarkPATUsed(tokenID string) error {
unlock := am.Store.AcquireGlobalLock()
user, err := am.Store.GetUserByTokenID(tokenID)
if err != nil {
return err
}
account, err := am.Store.GetAccountByUser(user.Id)
if err != nil {
return err
}
unlock()
unlock = am.Store.AcquireAccountLock(account.Id)
defer unlock()
account, err = am.Store.GetAccountByUser(user.Id)
if err != nil {
return err
}
pat, ok := account.Users[user.Id].PATs[tokenID]
if !ok {
return fmt.Errorf("token not found")
}
pat.LastUsed = time.Now()
return am.Store.SaveAccount(account)
}
// GetAccountFromPAT returns Account and User associated with a personal access token
func (am *DefaultAccountManager) GetAccountFromPAT(token string) (*Account, *User, *PersonalAccessToken, error) {
if len(token) != PATLength {

View File

@ -495,6 +495,44 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) {
assert.Equal(t, account.Users["someUser"].PATs["tokenId"], pat)
}
func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
store := newStore(t)
account := newAccountWithId("account_id", "testuser", "")
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
hashedToken := sha256.Sum256([]byte(token))
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
account.Users["someUser"] = &User{
Id: "someUser",
PATs: map[string]*PersonalAccessToken{
"tokenId": {
ID: "tokenId",
HashedToken: encodedHashedToken,
LastUsed: time.Time{},
},
},
}
err := store.SaveAccount(account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{
Store: store,
}
err = am.MarkPATUsed("tokenId")
if err != nil {
t.Fatalf("Error when marking PAT used: %s", err)
}
account, err = am.Store.GetAccount("account_id")
if err != nil {
t.Fatalf("Error when getting account: %s", err)
}
assert.True(t, !account.Users["someUser"].PATs["tokenId"].LastUsed.IsZero())
}
func TestAccountManager_PrivateAccount(t *testing.T) {
manager, err := createManager(t)
if err != nil {

View File

@ -6,11 +6,10 @@ import (
"strings"
"time"
pb "github.com/golang/protobuf/proto" //nolint
pb "github.com/golang/protobuf/proto" // nolint
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/golang/protobuf/ptypes/timestamp"
@ -33,7 +32,7 @@ type GRPCServer struct {
peersUpdateManager *PeersUpdateManager
config *Config
turnCredentialsManager TURNCredentialsManager
jwtMiddleware *middleware.JWTMiddleware
jwtValidator *jwtclaims.JWTValidator
jwtClaimsExtractor *jwtclaims.ClaimsExtractor
appMetrics telemetry.AppMetrics
}
@ -47,10 +46,10 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
return nil, err
}
var jwtMiddleware *middleware.JWTMiddleware
var jwtValidator *jwtclaims.JWTValidator
if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) {
jwtMiddleware, err = middleware.NewJwtMiddleware(
jwtValidator, err = jwtclaims.NewJWTValidator(
config.HttpConfig.AuthIssuer,
config.HttpConfig.AuthAudience,
config.HttpConfig.AuthKeysLocation)
@ -88,7 +87,7 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
accountManager: accountManager,
config: config,
turnCredentialsManager: turnCredentialsManager,
jwtMiddleware: jwtMiddleware,
jwtValidator: jwtValidator,
jwtClaimsExtractor: jwtClaimsExtractor,
appMetrics: appMetrics,
}, nil
@ -189,11 +188,11 @@ func (s *GRPCServer) cancelPeerRoutines(peer *Peer) {
}
func (s *GRPCServer) validateToken(jwtToken string) (string, error) {
if s.jwtMiddleware == nil {
return "", status.Error(codes.Internal, "no jwt middleware set")
if s.jwtValidator == nil {
return "", status.Error(codes.Internal, "no jwt validator set")
}
token, err := s.jwtMiddleware.ValidateAndParse(jwtToken)
token, err := s.jwtValidator.ValidateAndParse(jwtToken)
if err != nil {
return "", status.Errorf(codes.InvalidArgument, "invalid jwt token, err: %v", err)
}

View File

@ -8,6 +8,7 @@ import (
s "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/telemetry"
)
@ -25,18 +26,17 @@ type apiHandler struct {
AuthCfg AuthCfg
}
// EmptyObject is an empty struct used to return empty JSON object
type emptyObject struct {
}
// APIHandler creates the Management service HTTP API handler registering all the available endpoints.
func APIHandler(accountManager s.AccountManager, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) {
jwtMiddleware, err := middleware.NewJwtMiddleware(
authCfg.Issuer,
authCfg.Audience,
authCfg.KeysLocation)
if err != nil {
return nil, err
}
func APIHandler(accountManager s.AccountManager, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) {
authMiddleware := middleware.NewAuthMiddleware(
accountManager.GetAccountFromPAT,
jwtValidator.ValidateAndParse,
accountManager.MarkPATUsed,
authCfg.Audience)
corsMiddleware := cors.AllowAll()
@ -49,7 +49,7 @@ func APIHandler(accountManager s.AccountManager, appMetrics telemetry.AppMetrics
metricsMiddleware := appMetrics.HTTPMiddleware()
router := rootRouter.PathPrefix("/api").Subrouter()
router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, jwtMiddleware.Handler, acMiddleware.Handler)
router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler, acMiddleware.Handler)
api := apiHandler{
Router: router,
@ -70,7 +70,7 @@ func APIHandler(accountManager s.AccountManager, appMetrics telemetry.AppMetrics
api.addDNSSettingEndpoint()
api.addEventsEndpoint()
err = api.Router.Walk(func(route *mux.Route, _ *mux.Router, _ []*mux.Route) error {
err := api.Router.Walk(func(route *mux.Route, _ *mux.Router, _ []*mux.Route) error {
methods, err := route.GetMethods()
if err != nil {
return err

View File

@ -2,6 +2,9 @@ package middleware
import (
"net/http"
"regexp"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/status"
@ -39,10 +42,22 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w)
return
}
if !ok {
switch r.Method {
case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut:
ok, err := regexp.MatchString(`^.*/api/users/.*/tokens.*$`, r.URL.Path)
if err != nil {
log.Debugf("Regex failed")
util.WriteError(status.Errorf(status.Internal, ""), w)
return
}
if ok {
log.Debugf("Valid Path")
h.ServeHTTP(w, r)
return
}
util.WriteError(status.Errorf(status.PermissionDenied, "only admin can perform this operation"), w)
return
}

View File

@ -0,0 +1,173 @@
package middleware
import (
"context"
"errors"
"fmt"
"net/http"
"strings"
"time"
"github.com/golang-jwt/jwt"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
)
// GetAccountFromPATFunc function
type GetAccountFromPATFunc func(token string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
// ValidateAndParseTokenFunc function
type ValidateAndParseTokenFunc func(token string) (*jwt.Token, error)
// MarkPATUsedFunc function
type MarkPATUsedFunc func(token string) error
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
type AuthMiddleware struct {
getAccountFromPAT GetAccountFromPATFunc
validateAndParseToken ValidateAndParseTokenFunc
markPATUsed MarkPATUsedFunc
audience string
}
const (
userProperty = "user"
)
// NewAuthMiddleware instance constructor
func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc, markPATUsed MarkPATUsedFunc, audience string) *AuthMiddleware {
return &AuthMiddleware{
getAccountFromPAT: getAccountFromPAT,
validateAndParseToken: validateAndParseToken,
markPATUsed: markPATUsed,
audience: audience,
}
}
// Handler method of the middleware which authenticates a user either by JWT claims or by PAT
func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
auth := strings.Split(r.Header.Get("Authorization"), " ")
authType := auth[0]
switch strings.ToLower(authType) {
case "bearer":
err := m.CheckJWTFromRequest(w, r)
if err != nil {
log.Debugf("Error when validating JWT claims: %s", err.Error())
util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w)
return
}
h.ServeHTTP(w, r)
case "token":
err := m.CheckPATFromRequest(w, r)
if err != nil {
log.Debugf("Error when validating PAT claims: %s", err.Error())
util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w)
return
}
h.ServeHTTP(w, r)
default:
util.WriteError(status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
return
}
})
}
// CheckJWTFromRequest checks if the JWT is valid
func (m *AuthMiddleware) CheckJWTFromRequest(w http.ResponseWriter, r *http.Request) error {
token, err := getTokenFromJWTRequest(r)
// If an error occurs, call the error handler and return an error
if err != nil {
return fmt.Errorf("Error extracting token: %w", err)
}
validatedToken, err := m.validateAndParseToken(token)
if err != nil {
return err
}
if validatedToken == nil {
return nil
}
// If we get here, everything worked and we can set the
// user property in context.
newRequest := r.WithContext(context.WithValue(r.Context(), userProperty, validatedToken)) //nolint
// Update the current request with the new context information.
*r = *newRequest
return nil
}
// CheckPATFromRequest checks if the PAT is valid
func (m *AuthMiddleware) CheckPATFromRequest(w http.ResponseWriter, r *http.Request) error {
token, err := getTokenFromPATRequest(r)
// If an error occurs, call the error handler and return an error
if err != nil {
return fmt.Errorf("Error extracting token: %w", err)
}
account, user, pat, err := m.getAccountFromPAT(token)
if err != nil {
return fmt.Errorf("invalid Token: %w", err)
}
if time.Now().After(pat.ExpirationDate) {
return fmt.Errorf("token expired")
}
err = m.markPATUsed(pat.ID)
if err != nil {
return err
}
claimMaps := jwt.MapClaims{}
claimMaps[jwtclaims.UserIDClaim] = user.Id
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint
// Update the current request with the new context information.
*r = *newRequest
return nil
}
// getTokenFromJWTRequest is a "TokenExtractor" that takes a give request and extracts
// the JWT token from the Authorization header.
func getTokenFromJWTRequest(r *http.Request) (string, error) {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
return "", nil // No error, just no token
}
// TODO: Make this a bit more robust, parsing-wise
authHeaderParts := strings.Fields(authHeader)
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
return "", errors.New("Authorization header format must be Bearer {token}")
}
return authHeaderParts[1], nil
}
// getTokenFromPATRequest is a "TokenExtractor" that takes a give request and extracts
// the PAT token from the Authorization header.
func getTokenFromPATRequest(r *http.Request) (string, error) {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
return "", nil // No error, just no token
}
// TODO: Make this a bit more robust, parsing-wise
authHeaderParts := strings.Fields(authHeader)
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "token" {
return "", errors.New("Authorization header format must be Token {token}")
}
return authHeaderParts[1], nil
}

View File

@ -0,0 +1,123 @@
package middleware
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/golang-jwt/jwt"
"github.com/netbirdio/netbird/management/server"
)
const (
audience = "audience"
accountID = "accountID"
domain = "domain"
userID = "userID"
tokenID = "tokenID"
PAT = "PAT"
JWT = "JWT"
wrongToken = "wrongToken"
)
var testAccount = &server.Account{
Id: accountID,
Domain: domain,
Users: map[string]*server.User{
userID: {
Id: userID,
PATs: map[string]*server.PersonalAccessToken{
tokenID: {
ID: tokenID,
Name: "My first token",
HashedToken: "someHash",
ExpirationDate: time.Now().AddDate(0, 0, 7),
CreatedBy: userID,
CreatedAt: time.Now(),
LastUsed: time.Now(),
},
},
},
},
}
func mockGetAccountFromPAT(token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) {
if token == PAT {
return testAccount, testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], nil
}
return nil, nil, nil, fmt.Errorf("PAT invalid")
}
func mockValidateAndParseToken(token string) (*jwt.Token, error) {
if token == JWT {
return &jwt.Token{}, nil
}
return nil, fmt.Errorf("JWT invalid")
}
func mockMarkPATUsed(token string) error {
if token == tokenID {
return nil
}
return fmt.Errorf("Should never get reached")
}
func TestAuthMiddleware_Handler(t *testing.T) {
tt := []struct {
name string
authHeader string
expectedStatusCode int
}{
{
name: "Valid PAT Token",
authHeader: "Token " + PAT,
expectedStatusCode: 200,
},
{
name: "Invalid PAT Token",
authHeader: "Token " + wrongToken,
expectedStatusCode: 401,
},
{
name: "Valid JWT Token",
authHeader: "Bearer " + JWT,
expectedStatusCode: 200,
},
{
name: "Invalid JWT Token",
authHeader: "Bearer " + wrongToken,
expectedStatusCode: 401,
},
{
name: "Basic Auth",
authHeader: "Basic " + PAT,
expectedStatusCode: 401,
},
}
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// do nothing
})
authMiddleware := NewAuthMiddleware(mockGetAccountFromPAT, mockValidateAndParseToken, mockMarkPATUsed, audience)
handlerToTest := authMiddleware.Handler(nextHandler)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "http://testing", nil)
req.Header.Set("Authorization", tc.authHeader)
rec := httptest.NewRecorder()
handlerToTest.ServeHTTP(rec, req)
if rec.Result().StatusCode != tc.expectedStatusCode {
t.Errorf("expected status code %d, got %d", tc.expectedStatusCode, rec.Result().StatusCode)
}
})
}
}

View File

@ -1,254 +0,0 @@
package middleware
import (
"context"
"errors"
"fmt"
"net/http"
"strings"
"github.com/golang-jwt/jwt"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/status"
)
// A function called whenever an error is encountered
type errorHandler func(w http.ResponseWriter, r *http.Request, err string)
// TokenExtractor is a function that takes a request as input and returns
// either a token or an error. An error should only be returned if an attempt
// to specify a token was found, but the information was somehow incorrectly
// formed. In the case where a token is simply not present, this should not
// be treated as an error. An empty string should be returned in that case.
type TokenExtractor func(r *http.Request) (string, error)
// Options is a struct for specifying configuration options for the middleware.
type Options struct {
// The function that will return the Key to validate the JWT.
// It can be either a shared secret or a public key.
// Default value: nil
ValidationKeyGetter jwt.Keyfunc
// The name of the property in the request where the user information
// from the JWT will be stored.
// Default value: "user"
UserProperty string
// The function that will be called when there's an error validating the token
// Default value:
ErrorHandler errorHandler
// A boolean indicating if the credentials are required or not
// Default value: false
CredentialsOptional bool
// A function that extracts the token from the request
// Default: FromAuthHeader (i.e., from Authorization header as bearer token)
Extractor TokenExtractor
// Debug flag turns on debugging output
// Default: false
Debug bool
// When set, all requests with the OPTIONS method will use authentication
// Default: false
EnableAuthOnOptions bool
// When set, the middelware verifies that tokens are signed with the specific signing algorithm
// If the signing method is not constant the ValidationKeyGetter callback can be used to implement additional checks
// Important to avoid security issues described here: https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/
// Default: nil
SigningMethod jwt.SigningMethod
}
type JWTMiddleware struct {
Options Options
}
func OnError(w http.ResponseWriter, r *http.Request, err string) {
util.WriteError(status.Errorf(status.Unauthorized, ""), w)
}
// New constructs a new Secure instance with supplied options.
func New(options ...Options) *JWTMiddleware {
var opts Options
if len(options) == 0 {
opts = Options{}
} else {
opts = options[0]
}
if opts.UserProperty == "" {
opts.UserProperty = "user"
}
if opts.ErrorHandler == nil {
opts.ErrorHandler = OnError
}
if opts.Extractor == nil {
opts.Extractor = FromAuthHeader
}
return &JWTMiddleware{
Options: opts,
}
}
func (m *JWTMiddleware) logf(format string, args ...interface{}) {
if m.Options.Debug {
log.Printf(format, args...)
}
}
// HandlerWithNext is a special implementation for Negroni, but could be used elsewhere.
func (m *JWTMiddleware) HandlerWithNext(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
err := m.CheckJWTFromRequest(w, r)
// If there was an error, do not call next.
if err == nil && next != nil {
next(w, r)
}
}
func (m *JWTMiddleware) Handler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Let secure process the request. If it returns an error,
// that indicates the request should not continue.
err := m.CheckJWTFromRequest(w, r)
// If there was an error, do not continue.
if err != nil {
log.Errorf("received an error while validating the JWT token: %s. "+
"Review your IDP configuration and ensure that "+
"settings are in sync between dashboard and management", err)
return
}
h.ServeHTTP(w, r)
})
}
// FromAuthHeader is a "TokenExtractor" that takes a give request and extracts
// the JWT token from the Authorization header.
func FromAuthHeader(r *http.Request) (string, error) {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
return "", nil // No error, just no token
}
// TODO: Make this a bit more robust, parsing-wise
authHeaderParts := strings.Fields(authHeader)
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
return "", errors.New("Authorization header format must be Bearer {token}")
}
return authHeaderParts[1], nil
}
// FromParameter returns a function that extracts the token from the specified
// query string parameter
func FromParameter(param string) TokenExtractor {
return func(r *http.Request) (string, error) {
return r.URL.Query().Get(param), nil
}
}
// FromFirst returns a function that runs multiple token extractors and takes the
// first token it finds
func FromFirst(extractors ...TokenExtractor) TokenExtractor {
return func(r *http.Request) (string, error) {
for _, ex := range extractors {
token, err := ex(r)
if err != nil {
return "", err
}
if token != "" {
return token, nil
}
}
return "", nil
}
}
func (m *JWTMiddleware) CheckJWTFromRequest(w http.ResponseWriter, r *http.Request) error {
if !m.Options.EnableAuthOnOptions {
if r.Method == "OPTIONS" {
return nil
}
}
// Use the specified token extractor to extract a token from the request
token, err := m.Options.Extractor(r)
// If debugging is turned on, log the outcome
if err != nil {
m.logf("Error extracting JWT: %v", err)
} else {
m.logf("Token extracted: %s", token)
}
// If an error occurs, call the error handler and return an error
if err != nil {
m.Options.ErrorHandler(w, r, err.Error())
return fmt.Errorf("Error extracting token: %w", err)
}
validatedToken, err := m.ValidateAndParse(token)
if err != nil {
m.Options.ErrorHandler(w, r, err.Error())
return err
}
if validatedToken == nil {
return nil
}
// If we get here, everything worked and we can set the
// user property in context.
newRequest := r.WithContext(context.WithValue(r.Context(), m.Options.UserProperty, validatedToken)) //nolint
// Update the current request with the new context information.
*r = *newRequest
return nil
}
// ValidateAndParse validates and parses a given access token against jwt standards and signing methods
func (m *JWTMiddleware) ValidateAndParse(token string) (*jwt.Token, error) {
// If the token is empty...
if token == "" {
// Check if it was required
if m.Options.CredentialsOptional {
m.logf("no credentials found (CredentialsOptional=true)")
// No error, just no token (and that is ok given that CredentialsOptional is true)
return nil, nil
}
// If we get here, the required token is missing
errorMsg := "required authorization token not found"
m.logf(" Error: No credentials found (CredentialsOptional=false)")
return nil, fmt.Errorf(errorMsg)
}
// Now parse the token
parsedToken, err := jwt.Parse(token, m.Options.ValidationKeyGetter)
// Check if there was an error in parsing...
if err != nil {
m.logf("error parsing token: %v", err)
return nil, fmt.Errorf("Error parsing token: %w", err)
}
if m.Options.SigningMethod != nil && m.Options.SigningMethod.Alg() != parsedToken.Header["alg"] {
errorMsg := fmt.Sprintf("Expected %s signing method but token specified %s",
m.Options.SigningMethod.Alg(),
parsedToken.Header["alg"])
m.logf("error validating token algorithm: %s", errorMsg)
return nil, fmt.Errorf("error validating token algorithm: %s", errorMsg)
}
// Check if the parsed token is valid...
if !parsedToken.Valid {
errorMsg := "token is invalid"
m.logf(errorMsg)
return nil, errors.New(errorMsg)
}
return parsedToken, nil
}

View File

@ -4,10 +4,13 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/netbirdio/netbird/management/server/status"
log "github.com/sirupsen/logrus"
"net/http"
"strings"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/status"
)
// WriteJSONObject simply writes object to the HTTP reponse in JSON format
@ -93,9 +96,11 @@ func WriteError(err error, w http.ResponseWriter) {
httpStatus = http.StatusInternalServerError
case status.InvalidArgument:
httpStatus = http.StatusUnprocessableEntity
case status.Unauthorized:
httpStatus = http.StatusUnauthorized
default:
}
msg = err.Error()
msg = strings.ToLower(err.Error())
} else {
unhandledMSG := fmt.Sprintf("got unhandled error code, error: %s", err.Error())
log.Error(unhandledMSG)

View File

@ -7,14 +7,19 @@ import (
)
const (
TokenUserProperty = "user"
AccountIDSuffix = "wt_account_id"
DomainIDSuffix = "wt_account_domain"
// TokenUserProperty key for the user property in the request context
TokenUserProperty = "user"
// AccountIDSuffix suffix for the account id claim
AccountIDSuffix = "wt_account_id"
// DomainIDSuffix suffix for the domain id claim
DomainIDSuffix = "wt_account_domain"
// DomainCategorySuffix suffix for the domain category claim
DomainCategorySuffix = "wt_account_domain_category"
UserIDClaim = "sub"
// UserIDClaim claim for the user id
UserIDClaim = "sub"
)
// Extract function type
// ExtractClaims Extract function type
type ExtractClaims func(r *http.Request) AuthorizationClaims
// ClaimsExtractor struct that holds the extract function

View File

@ -26,7 +26,7 @@ func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audiance st
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
r, err := http.NewRequest(http.MethodGet, "http://localhost", nil)
require.NoError(t, err, "creating testing request failed")
testRequest := r.WithContext(context.WithValue(r.Context(), TokenUserProperty, token)) //nolint
testRequest := r.WithContext(context.WithValue(r.Context(), TokenUserProperty, token)) // nolint
return testRequest
}

View File

@ -1,4 +1,4 @@
package middleware
package jwtclaims
import (
"bytes"
@ -17,6 +17,32 @@ import (
log "github.com/sirupsen/logrus"
)
// Options is a struct for specifying configuration options for the middleware.
type Options struct {
// The function that will return the Key to validate the JWT.
// It can be either a shared secret or a public key.
// Default value: nil
ValidationKeyGetter jwt.Keyfunc
// The name of the property in the request where the user information
// from the JWT will be stored.
// Default value: "user"
UserProperty string
// The function that will be called when there's an error validating the token
// Default value:
CredentialsOptional bool
// A function that extracts the token from the request
// Default: FromAuthHeader (i.e., from Authorization header as bearer token)
Debug bool
// When set, all requests with the OPTIONS method will use authentication
// Default: false
EnableAuthOnOptions bool
// When set, the middelware verifies that tokens are signed with the specific signing algorithm
// If the signing method is not constant the ValidationKeyGetter callback can be used to implement additional checks
// Important to avoid security issues described here: https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/
// Default: nil
SigningMethod jwt.SigningMethod
}
// Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation
type Jwks struct {
Keys []JSONWebKey `json:"keys"`
@ -32,14 +58,19 @@ type JSONWebKey struct {
X5c []string `json:"x5c"`
}
// NewJwtMiddleware creates new middleware to verify the JWT token sent via Authorization header
func NewJwtMiddleware(issuer string, audience string, keysLocation string) (*JWTMiddleware, error) {
// JWTValidator struct to handle token validation and parsing
type JWTValidator struct {
options Options
}
// NewJWTValidator constructor
func NewJWTValidator(issuer string, audience string, keysLocation string) (*JWTValidator, error) {
keys, err := getPemKeys(keysLocation)
if err != nil {
return nil, err
}
return New(Options{
options := Options{
ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) {
// Verify 'aud' claim
checkAud := token.Claims.(jwt.MapClaims).VerifyAudience(audience, false)
@ -62,7 +93,59 @@ func NewJwtMiddleware(issuer string, audience string, keysLocation string) (*JWT
},
SigningMethod: jwt.SigningMethodRS256,
EnableAuthOnOptions: false,
}), nil
}
if options.UserProperty == "" {
options.UserProperty = "user"
}
return &JWTValidator{
options: options,
}, nil
}
// ValidateAndParse validates the token and returns the parsed token
func (m *JWTValidator) ValidateAndParse(token string) (*jwt.Token, error) {
// If the token is empty...
if token == "" {
// Check if it was required
if m.options.CredentialsOptional {
log.Debugf("no credentials found (CredentialsOptional=true)")
// No error, just no token (and that is ok given that CredentialsOptional is true)
return nil, nil
}
// If we get here, the required token is missing
errorMsg := "required authorization token not found"
log.Debugf(" Error: No credentials found (CredentialsOptional=false)")
return nil, fmt.Errorf(errorMsg)
}
// Now parse the token
parsedToken, err := jwt.Parse(token, m.options.ValidationKeyGetter)
// Check if there was an error in parsing...
if err != nil {
log.Debugf("error parsing token: %v", err)
return nil, fmt.Errorf("Error parsing token: %w", err)
}
if m.options.SigningMethod != nil && m.options.SigningMethod.Alg() != parsedToken.Header["alg"] {
errorMsg := fmt.Sprintf("Expected %s signing method but token specified %s",
m.options.SigningMethod.Alg(),
parsedToken.Header["alg"])
log.Debugf("error validating token algorithm: %s", errorMsg)
return nil, fmt.Errorf("error validating token algorithm: %s", errorMsg)
}
// Check if the parsed token is valid...
if !parsedToken.Valid {
errorMsg := "token is invalid"
log.Debugf(errorMsg)
return nil, errors.New(errorMsg)
}
return parsedToken, nil
}
func getPemKeys(keysLocation string) (*Jwks, error) {

View File

@ -48,6 +48,7 @@ type MockAccountManager struct {
ListPoliciesFunc func(accountID, userID string) ([]*server.Policy, error)
GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error)
GetAccountFromPATFunc func(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
MarkPATUsedFunc func(pat string) error
UpdatePeerMetaFunc func(peerID string, meta server.PeerSystemMeta) error
UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error
UpdatePeerFunc func(accountID, userID string, peer *server.Peer) (*server.Peer, error)
@ -188,6 +189,14 @@ func (am *MockAccountManager) GetAccountFromPAT(pat string) (*server.Account, *s
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromPAT is not implemented")
}
// MarkPATUsed mock implementation of MarkPATUsed from server.AccountManager interface
func (am *MockAccountManager) MarkPATUsed(pat string) error {
if am.MarkPATUsedFunc != nil {
return am.MarkPATUsedFunc(pat)
}
return status.Errorf(codes.Unimplemented, "method MarkPATUsed is not implemented")
}
// CreatePAT mock implementation of GetPAT from server.AccountManager interface
func (am *MockAccountManager) CreatePAT(accountID string, executingUserID string, targetUserID string, name string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
if am.CreatePATFunc != nil {