refactor jwt token validation and add PAT to middleware auth

This commit is contained in:
Pascal Fischer 2023-03-30 10:54:09 +02:00
parent ecc4f8a10d
commit db3a9f0aa2
10 changed files with 341 additions and 285 deletions

View File

@ -19,25 +19,28 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/netbirdio/netbird/management/server/activity/sqlite"
httpapi "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/metrics"
"github.com/netbirdio/netbird/management/server/telemetry"
"golang.org/x/crypto/acme/autocert" "golang.org/x/crypto/acme/autocert"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/h2c" "golang.org/x/net/http2/h2c"
"github.com/netbirdio/netbird/management/server/activity/sqlite"
httpapi "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/metrics"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/encryption"
mgmtProto "github.com/netbirdio/netbird/management/proto"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/encryption"
mgmtProto "github.com/netbirdio/netbird/management/proto"
) )
// ManagementLegacyPort is the port that was used before by the Management gRPC server. // ManagementLegacyPort is the port that was used before by the Management gRPC server.
@ -179,13 +182,22 @@ var (
tlsEnabled = true tlsEnabled = true
} }
jwtValidator, err := jwtclaims.NewJWTValidator(
config.HttpConfig.AuthIssuer,
config.HttpConfig.AuthAudience,
config.HttpConfig.AuthKeysLocation,
)
if err != nil {
return fmt.Errorf("failed creating JWT validator: %v", err)
}
httpAPIAuthCfg := httpapi.AuthCfg{ httpAPIAuthCfg := httpapi.AuthCfg{
Issuer: config.HttpConfig.AuthIssuer, Issuer: config.HttpConfig.AuthIssuer,
Audience: config.HttpConfig.AuthAudience, Audience: config.HttpConfig.AuthAudience,
UserIDClaim: config.HttpConfig.AuthUserIDClaim, UserIDClaim: config.HttpConfig.AuthUserIDClaim,
KeysLocation: config.HttpConfig.AuthKeysLocation, KeysLocation: config.HttpConfig.AuthKeysLocation,
} }
httpAPIHandler, err := httpapi.APIHandler(accountManager, appMetrics, httpAPIAuthCfg) httpAPIHandler, err := httpapi.APIHandler(accountManager, *jwtValidator, appMetrics, httpAPIAuthCfg)
if err != nil { if err != nil {
return fmt.Errorf("failed creating HTTP API handler: %v", err) return fmt.Errorf("failed creating HTTP API handler: %v", err)
} }

View File

@ -56,6 +56,7 @@ type AccountManager interface {
GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error) GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error)
GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error) GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error)
MarkPATUsed(tokenID string) error
IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error)
AccountExists(accountId string) (*bool, error) AccountExists(accountId string) (*bool, error)
GetPeerByKey(peerKey string) (*Peer, error) GetPeerByKey(peerKey string) (*Peer, error)
@ -1120,6 +1121,33 @@ func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) e
return nil return nil
} }
func (am *DefaultAccountManager) MarkPATUsed(tokenID string) error {
unlock := am.Store.AcquireGlobalLock()
defer unlock()
user, err := am.Store.GetUserByTokenID(tokenID)
log.Debugf("User: %v", user)
if err != nil {
return err
}
account, err := am.Store.GetAccountByUser(user.Id)
if err != nil {
return err
}
pat, ok := account.Users[user.Id].PATs[tokenID]
if !ok {
return fmt.Errorf("token not found")
}
pat.LastUsed = time.Now()
am.Store.SaveAccount(account)
return nil
}
// GetAccountFromPAT returns Account and User associated with a personal access token // GetAccountFromPAT returns Account and User associated with a personal access token
func (am *DefaultAccountManager) GetAccountFromPAT(token string) (*Account, *User, *PersonalAccessToken, error) { func (am *DefaultAccountManager) GetAccountFromPAT(token string) (*Account, *User, *PersonalAccessToken, error) {
if len(token) != PATLength { if len(token) != PATLength {

View File

@ -3,24 +3,25 @@ package server
import ( import (
"context" "context"
"fmt" "fmt"
pb "github.com/golang/protobuf/proto" //nolint
"strings" "strings"
"time" "time"
pb "github.com/golang/protobuf/proto" // nolint
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/golang/protobuf/ptypes/timestamp" "github.com/golang/protobuf/ptypes/timestamp"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/proto"
internalStatus "github.com/netbirdio/netbird/management/server/status"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
gRPCPeer "google.golang.org/grpc/peer" gRPCPeer "google.golang.org/grpc/peer"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/proto"
internalStatus "github.com/netbirdio/netbird/management/server/status"
) )
// GRPCServer an instance of a Management gRPC API server // GRPCServer an instance of a Management gRPC API server
@ -31,7 +32,7 @@ type GRPCServer struct {
peersUpdateManager *PeersUpdateManager peersUpdateManager *PeersUpdateManager
config *Config config *Config
turnCredentialsManager TURNCredentialsManager turnCredentialsManager TURNCredentialsManager
jwtMiddleware *middleware.JWTMiddleware jwtValidator *jwtclaims.JWTValidator
jwtClaimsExtractor *jwtclaims.ClaimsExtractor jwtClaimsExtractor *jwtclaims.ClaimsExtractor
appMetrics telemetry.AppMetrics appMetrics telemetry.AppMetrics
} }
@ -45,10 +46,10 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
return nil, err return nil, err
} }
var jwtMiddleware *middleware.JWTMiddleware var jwtValidator *jwtclaims.JWTValidator
if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) { if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) {
jwtMiddleware, err = middleware.NewJwtMiddleware( jwtValidator, err = jwtclaims.NewJWTValidator(
config.HttpConfig.AuthIssuer, config.HttpConfig.AuthIssuer,
config.HttpConfig.AuthAudience, config.HttpConfig.AuthAudience,
config.HttpConfig.AuthKeysLocation) config.HttpConfig.AuthKeysLocation)
@ -86,7 +87,7 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
accountManager: accountManager, accountManager: accountManager,
config: config, config: config,
turnCredentialsManager: turnCredentialsManager, turnCredentialsManager: turnCredentialsManager,
jwtMiddleware: jwtMiddleware, jwtValidator: jwtValidator,
jwtClaimsExtractor: jwtClaimsExtractor, jwtClaimsExtractor: jwtClaimsExtractor,
appMetrics: appMetrics, appMetrics: appMetrics,
}, nil }, nil
@ -187,11 +188,11 @@ func (s *GRPCServer) cancelPeerRoutines(peer *Peer) {
} }
func (s *GRPCServer) validateToken(jwtToken string) (string, error) { func (s *GRPCServer) validateToken(jwtToken string) (string, error) {
if s.jwtMiddleware == nil { if s.jwtValidator == nil {
return "", status.Error(codes.Internal, "no jwt middleware set") return "", status.Error(codes.Internal, "no jwt validator set")
} }
token, err := s.jwtMiddleware.ValidateAndParse(jwtToken) token, err := s.jwtValidator.ValidateAndParse(jwtToken)
if err != nil { if err != nil {
return "", status.Errorf(codes.InvalidArgument, "invalid jwt token, err: %v", err) return "", status.Errorf(codes.InvalidArgument, "invalid jwt token, err: %v", err)
} }

View File

@ -8,6 +8,7 @@ import (
s "github.com/netbirdio/netbird/management/server" s "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/middleware" "github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
) )
@ -25,18 +26,17 @@ type apiHandler struct {
AuthCfg AuthCfg AuthCfg AuthCfg
} }
// EmptyObject is an empty struct used to return empty JSON object
type emptyObject struct { type emptyObject struct {
} }
// APIHandler creates the Management service HTTP API handler registering all the available endpoints. // APIHandler creates the Management service HTTP API handler registering all the available endpoints.
func APIHandler(accountManager s.AccountManager, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) { func APIHandler(accountManager s.AccountManager, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) {
jwtMiddleware, err := middleware.NewJwtMiddleware( authMiddleware := middleware.NewAuthMiddleware(
authCfg.Issuer, accountManager.GetAccountFromPAT,
authCfg.Audience, jwtValidator.ValidateAndParse,
authCfg.KeysLocation) accountManager.MarkPATUsed,
if err != nil { authCfg.Audience)
return nil, err
}
corsMiddleware := cors.AllowAll() corsMiddleware := cors.AllowAll()
@ -49,7 +49,7 @@ func APIHandler(accountManager s.AccountManager, appMetrics telemetry.AppMetrics
metricsMiddleware := appMetrics.HTTPMiddleware() metricsMiddleware := appMetrics.HTTPMiddleware()
router := rootRouter.PathPrefix("/api").Subrouter() router := rootRouter.PathPrefix("/api").Subrouter()
router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, jwtMiddleware.Handler, acMiddleware.Handler) router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler, acMiddleware.Handler)
api := apiHandler{ api := apiHandler{
Router: router, Router: router,
@ -70,7 +70,7 @@ func APIHandler(accountManager s.AccountManager, appMetrics telemetry.AppMetrics
api.addDNSSettingEndpoint() api.addDNSSettingEndpoint()
api.addEventsEndpoint() api.addEventsEndpoint()
err = api.Router.Walk(func(route *mux.Route, _ *mux.Router, _ []*mux.Route) error { err := api.Router.Walk(func(route *mux.Route, _ *mux.Router, _ []*mux.Route) error {
methods, err := route.GetMethods() methods, err := route.GetMethods()
if err != nil { if err != nil {
return err return err

View File

@ -0,0 +1,170 @@
package middleware
import (
"context"
"errors"
"fmt"
"net/http"
"strings"
"time"
"github.com/golang-jwt/jwt"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
)
type GetAccountFromPATFunc func(token string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
type ValidateAndParseTokenFunc func(token string) (*jwt.Token, error)
type MarkPATUsedFunc func(token string) error
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
type AuthMiddleware struct {
getAccountFromPAT GetAccountFromPATFunc
validateAndParseToken ValidateAndParseTokenFunc
markPATUsed MarkPATUsedFunc
audience string
}
const (
userProperty = "user"
)
// NewAuthMiddleware instance constructor
func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc, markPATUsed MarkPATUsedFunc, audience string) *AuthMiddleware {
return &AuthMiddleware{
getAccountFromPAT: getAccountFromPAT,
validateAndParseToken: validateAndParseToken,
markPATUsed: markPATUsed,
audience: audience,
}
}
// Handler method of the middleware which authenticates a user either by JWT claims or by PAT
func (a *AuthMiddleware) Handler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
auth := strings.Split(r.Header.Get("Authorization"), " ")
authType := auth[0]
switch strings.ToLower(authType) {
case "bearer":
err := a.CheckJWTFromRequest(w, r)
if err != nil {
log.Debugf("Error when validating JWT claims: %s", err.Error())
util.WriteError(status.Errorf(status.Unauthorized, "Token invalid"), w)
return
}
h.ServeHTTP(w, r)
case "token":
err := a.CheckPATFromRequest(w, r)
if err != nil {
log.Debugf("Error when validating PAT claims: %s", err.Error())
util.WriteError(status.Errorf(status.Unauthorized, "Token invalid"), w)
return
}
h.ServeHTTP(w, r)
default:
util.WriteError(status.Errorf(status.Unauthorized, "No valid authentication provided"), w)
return
}
})
}
// CheckJWTFromRequest checks if the JWT is valid
func (m *AuthMiddleware) CheckJWTFromRequest(w http.ResponseWriter, r *http.Request) error {
token, err := getTokenFromJWTRequest(r)
// If an error occurs, call the error handler and return an error
if err != nil {
return fmt.Errorf("Error extracting token: %w", err)
}
validatedToken, err := m.validateAndParseToken(token)
if err != nil {
return err
}
if validatedToken == nil {
return nil
}
// If we get here, everything worked and we can set the
// user property in context.
newRequest := r.WithContext(context.WithValue(r.Context(), userProperty, validatedToken)) // nolint
// Update the current request with the new context information.
*r = *newRequest
return nil
}
// CheckPATFromRequest checks if the PAT is valid
func (m *AuthMiddleware) CheckPATFromRequest(w http.ResponseWriter, r *http.Request) error {
token, err := getTokenFromPATRequest(r)
// If an error occurs, call the error handler and return an error
if err != nil {
return fmt.Errorf("Error extracting token: %w", err)
}
account, user, pat, err := m.getAccountFromPAT(token)
if err != nil {
util.WriteError(status.Errorf(status.Unauthorized, "Token invalid"), w)
return fmt.Errorf("invalid Token: %w", err)
}
if time.Now().After(pat.ExpirationDate) {
util.WriteError(status.Errorf(status.Unauthorized, "Token expired"), w)
return fmt.Errorf("token expired")
}
err = m.markPATUsed(pat.ID)
if err != nil {
return err
}
claimMaps := jwt.MapClaims{}
claimMaps[jwtclaims.UserIDClaim] = user.Id
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken))
// Update the current request with the new context information.
*r = *newRequest
return nil
}
// getTokenFromJWTRequest is a "TokenExtractor" that takes a give request and extracts
// the JWT token from the Authorization header.
func getTokenFromJWTRequest(r *http.Request) (string, error) {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
return "", nil // No error, just no token
}
// TODO: Make this a bit more robust, parsing-wise
authHeaderParts := strings.Fields(authHeader)
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
return "", errors.New("Authorization header format must be Bearer {token}")
}
return authHeaderParts[1], nil
}
// getTokenFromPATRequest is a "TokenExtractor" that takes a give request and extracts
// the PAT token from the Authorization header.
func getTokenFromPATRequest(r *http.Request) (string, error) {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
return "", nil // No error, just no token
}
// TODO: Make this a bit more robust, parsing-wise
authHeaderParts := strings.Fields(authHeader)
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "token" {
return "", errors.New("Authorization header format must be Token {token}")
}
return authHeaderParts[1], nil
}

View File

@ -0,0 +1 @@
package middleware

View File

@ -1,249 +0,0 @@
package middleware
import (
"context"
"errors"
"fmt"
"github.com/golang-jwt/jwt"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/status"
"log"
"net/http"
"strings"
)
// A function called whenever an error is encountered
type errorHandler func(w http.ResponseWriter, r *http.Request, err string)
// TokenExtractor is a function that takes a request as input and returns
// either a token or an error. An error should only be returned if an attempt
// to specify a token was found, but the information was somehow incorrectly
// formed. In the case where a token is simply not present, this should not
// be treated as an error. An empty string should be returned in that case.
type TokenExtractor func(r *http.Request) (string, error)
// Options is a struct for specifying configuration options for the middleware.
type Options struct {
// The function that will return the Key to validate the JWT.
// It can be either a shared secret or a public key.
// Default value: nil
ValidationKeyGetter jwt.Keyfunc
// The name of the property in the request where the user information
// from the JWT will be stored.
// Default value: "user"
UserProperty string
// The function that will be called when there's an error validating the token
// Default value:
ErrorHandler errorHandler
// A boolean indicating if the credentials are required or not
// Default value: false
CredentialsOptional bool
// A function that extracts the token from the request
// Default: FromAuthHeader (i.e., from Authorization header as bearer token)
Extractor TokenExtractor
// Debug flag turns on debugging output
// Default: false
Debug bool
// When set, all requests with the OPTIONS method will use authentication
// Default: false
EnableAuthOnOptions bool
// When set, the middelware verifies that tokens are signed with the specific signing algorithm
// If the signing method is not constant the ValidationKeyGetter callback can be used to implement additional checks
// Important to avoid security issues described here: https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/
// Default: nil
SigningMethod jwt.SigningMethod
}
type JWTMiddleware struct {
Options Options
}
func OnError(w http.ResponseWriter, r *http.Request, err string) {
util.WriteError(status.Errorf(status.Unauthorized, ""), w)
}
// New constructs a new Secure instance with supplied options.
func New(options ...Options) *JWTMiddleware {
var opts Options
if len(options) == 0 {
opts = Options{}
} else {
opts = options[0]
}
if opts.UserProperty == "" {
opts.UserProperty = "user"
}
if opts.ErrorHandler == nil {
opts.ErrorHandler = OnError
}
if opts.Extractor == nil {
opts.Extractor = FromAuthHeader
}
return &JWTMiddleware{
Options: opts,
}
}
func (m *JWTMiddleware) logf(format string, args ...interface{}) {
if m.Options.Debug {
log.Printf(format, args...)
}
}
// HandlerWithNext is a special implementation for Negroni, but could be used elsewhere.
func (m *JWTMiddleware) HandlerWithNext(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
err := m.CheckJWTFromRequest(w, r)
// If there was an error, do not call next.
if err == nil && next != nil {
next(w, r)
}
}
func (m *JWTMiddleware) Handler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Let secure process the request. If it returns an error,
// that indicates the request should not continue.
err := m.CheckJWTFromRequest(w, r)
// If there was an error, do not continue.
if err != nil {
return
}
h.ServeHTTP(w, r)
})
}
// FromAuthHeader is a "TokenExtractor" that takes a give request and extracts
// the JWT token from the Authorization header.
func FromAuthHeader(r *http.Request) (string, error) {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
return "", nil // No error, just no token
}
// TODO: Make this a bit more robust, parsing-wise
authHeaderParts := strings.Fields(authHeader)
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
return "", errors.New("Authorization header format must be Bearer {token}")
}
return authHeaderParts[1], nil
}
// FromParameter returns a function that extracts the token from the specified
// query string parameter
func FromParameter(param string) TokenExtractor {
return func(r *http.Request) (string, error) {
return r.URL.Query().Get(param), nil
}
}
// FromFirst returns a function that runs multiple token extractors and takes the
// first token it finds
func FromFirst(extractors ...TokenExtractor) TokenExtractor {
return func(r *http.Request) (string, error) {
for _, ex := range extractors {
token, err := ex(r)
if err != nil {
return "", err
}
if token != "" {
return token, nil
}
}
return "", nil
}
}
func (m *JWTMiddleware) CheckJWTFromRequest(w http.ResponseWriter, r *http.Request) error {
if !m.Options.EnableAuthOnOptions {
if r.Method == "OPTIONS" {
return nil
}
}
// Use the specified token extractor to extract a token from the request
token, err := m.Options.Extractor(r)
// If debugging is turned on, log the outcome
if err != nil {
m.logf("Error extracting JWT: %v", err)
} else {
m.logf("Token extracted: %s", token)
}
// If an error occurs, call the error handler and return an error
if err != nil {
m.Options.ErrorHandler(w, r, err.Error())
return fmt.Errorf("Error extracting token: %w", err)
}
validatedToken, err := m.ValidateAndParse(token)
if err != nil {
m.Options.ErrorHandler(w, r, err.Error())
return err
}
if validatedToken == nil {
return nil
}
// If we get here, everything worked and we can set the
// user property in context.
newRequest := r.WithContext(context.WithValue(r.Context(), m.Options.UserProperty, validatedToken)) //nolint
// Update the current request with the new context information.
*r = *newRequest
return nil
}
// ValidateAndParse validates and parses a given access token against jwt standards and signing methods
func (m *JWTMiddleware) ValidateAndParse(token string) (*jwt.Token, error) {
// If the token is empty...
if token == "" {
// Check if it was required
if m.Options.CredentialsOptional {
m.logf("no credentials found (CredentialsOptional=true)")
// No error, just no token (and that is ok given that CredentialsOptional is true)
return nil, nil
}
// If we get here, the required token is missing
errorMsg := "required authorization token not found"
m.logf(" Error: No credentials found (CredentialsOptional=false)")
return nil, fmt.Errorf(errorMsg)
}
// Now parse the token
parsedToken, err := jwt.Parse(token, m.Options.ValidationKeyGetter)
// Check if there was an error in parsing...
if err != nil {
m.logf("error parsing token: %v", err)
return nil, fmt.Errorf("Error parsing token: %w", err)
}
if m.Options.SigningMethod != nil && m.Options.SigningMethod.Alg() != parsedToken.Header["alg"] {
errorMsg := fmt.Sprintf("Expected %s signing method but token specified %s",
m.Options.SigningMethod.Alg(),
parsedToken.Header["alg"])
m.logf("error validating token algorithm: %s", errorMsg)
return nil, fmt.Errorf("error validating token algorithm: %s", errorMsg)
}
// Check if the parsed token is valid...
if !parsedToken.Valid {
errorMsg := "token is invalid"
m.logf(errorMsg)
return nil, errors.New(errorMsg)
}
return parsedToken, nil
}

View File

@ -4,10 +4,12 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/netbirdio/netbird/management/server/status"
log "github.com/sirupsen/logrus"
"net/http" "net/http"
"time" "time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/status"
) )
// WriteJSONObject simply writes object to the HTTP reponse in JSON format // WriteJSONObject simply writes object to the HTTP reponse in JSON format
@ -93,6 +95,8 @@ func WriteError(err error, w http.ResponseWriter) {
httpStatus = http.StatusInternalServerError httpStatus = http.StatusInternalServerError
case status.InvalidArgument: case status.InvalidArgument:
httpStatus = http.StatusUnprocessableEntity httpStatus = http.StatusUnprocessableEntity
case status.Unauthorized:
httpStatus = http.StatusUnauthorized
default: default:
} }
msg = err.Error() msg = err.Error()

View File

@ -1,4 +1,4 @@
package middleware package jwtclaims
import ( import (
"bytes" "bytes"
@ -17,6 +17,32 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
// Options is a struct for specifying configuration options for the middleware.
type Options struct {
// The function that will return the Key to validate the JWT.
// It can be either a shared secret or a public key.
// Default value: nil
ValidationKeyGetter jwt.Keyfunc
// The name of the property in the request where the user information
// from the JWT will be stored.
// Default value: "user"
UserProperty string
// The function that will be called when there's an error validating the token
// Default value:
CredentialsOptional bool
// A function that extracts the token from the request
// Default: FromAuthHeader (i.e., from Authorization header as bearer token)
Debug bool
// When set, all requests with the OPTIONS method will use authentication
// Default: false
EnableAuthOnOptions bool
// When set, the middelware verifies that tokens are signed with the specific signing algorithm
// If the signing method is not constant the ValidationKeyGetter callback can be used to implement additional checks
// Important to avoid security issues described here: https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/
// Default: nil
SigningMethod jwt.SigningMethod
}
// Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation // Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation
type Jwks struct { type Jwks struct {
Keys []JSONWebKey `json:"keys"` Keys []JSONWebKey `json:"keys"`
@ -32,14 +58,17 @@ type JSONWebKey struct {
X5c []string `json:"x5c"` X5c []string `json:"x5c"`
} }
// NewJwtMiddleware creates new middleware to verify the JWT token sent via Authorization header type JWTValidator struct {
func NewJwtMiddleware(issuer string, audience string, keysLocation string) (*JWTMiddleware, error) { options Options
}
func NewJWTValidator(issuer string, audience string, keysLocation string) (*JWTValidator, error) {
keys, err := getPemKeys(keysLocation) keys, err := getPemKeys(keysLocation)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return New(Options{ options := Options{
ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) { ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) {
// Verify 'aud' claim // Verify 'aud' claim
checkAud := token.Claims.(jwt.MapClaims).VerifyAudience(audience, false) checkAud := token.Claims.(jwt.MapClaims).VerifyAudience(audience, false)
@ -62,7 +91,58 @@ func NewJwtMiddleware(issuer string, audience string, keysLocation string) (*JWT
}, },
SigningMethod: jwt.SigningMethodRS256, SigningMethod: jwt.SigningMethodRS256,
EnableAuthOnOptions: false, EnableAuthOnOptions: false,
}), nil }
if options.UserProperty == "" {
options.UserProperty = "user"
}
return &JWTValidator{
options: options,
}, nil
}
func (m *JWTValidator) ValidateAndParse(token string) (*jwt.Token, error) {
// If the token is empty...
if token == "" {
// Check if it was required
if m.options.CredentialsOptional {
log.Debugf("no credentials found (CredentialsOptional=true)")
// No error, just no token (and that is ok given that CredentialsOptional is true)
return nil, nil
}
// If we get here, the required token is missing
errorMsg := "required authorization token not found"
log.Debugf(" Error: No credentials found (CredentialsOptional=false)")
return nil, fmt.Errorf(errorMsg)
}
// Now parse the token
parsedToken, err := jwt.Parse(token, m.options.ValidationKeyGetter)
// Check if there was an error in parsing...
if err != nil {
log.Debugf("error parsing token: %v", err)
return nil, fmt.Errorf("Error parsing token: %w", err)
}
if m.options.SigningMethod != nil && m.options.SigningMethod.Alg() != parsedToken.Header["alg"] {
errorMsg := fmt.Sprintf("Expected %s signing method but token specified %s",
m.options.SigningMethod.Alg(),
parsedToken.Header["alg"])
log.Debugf("error validating token algorithm: %s", errorMsg)
return nil, fmt.Errorf("error validating token algorithm: %s", errorMsg)
}
// Check if the parsed token is valid...
if !parsedToken.Valid {
errorMsg := "token is invalid"
log.Debugf(errorMsg)
return nil, errors.New(errorMsg)
}
return parsedToken, nil
} }
func getPemKeys(keysLocation string) (*Jwks, error) { func getPemKeys(keysLocation string) (*Jwks, error) {

View File

@ -48,6 +48,7 @@ type MockAccountManager struct {
ListPoliciesFunc func(accountID, userID string) ([]*server.Policy, error) ListPoliciesFunc func(accountID, userID string) ([]*server.Policy, error)
GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error) GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error)
GetAccountFromPATFunc func(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) GetAccountFromPATFunc func(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
MarkPATUsedFunc func(pat string) error
UpdatePeerMetaFunc func(peerID string, meta server.PeerSystemMeta) error UpdatePeerMetaFunc func(peerID string, meta server.PeerSystemMeta) error
UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error
UpdatePeerFunc func(accountID, userID string, peer *server.Peer) (*server.Peer, error) UpdatePeerFunc func(accountID, userID string, peer *server.Peer) (*server.Peer, error)
@ -186,6 +187,14 @@ func (am *MockAccountManager) GetAccountFromPAT(pat string) (*server.Account, *s
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromPAT is not implemented") return nil, nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromPAT is not implemented")
} }
// MarkPATUsed mock implementation of MarkPATUsed from server.AccountManager interface
func (am *MockAccountManager) MarkPATUsed(pat string) error {
if am.MarkPATUsedFunc != nil {
return am.MarkPATUsedFunc(pat)
}
return status.Errorf(codes.Unimplemented, "method MarkPATUsed is not implemented")
}
// AddPATToUser mock implementation of AddPATToUser from server.AccountManager interface // AddPATToUser mock implementation of AddPATToUser from server.AccountManager interface
func (am *MockAccountManager) AddPATToUser(accountID string, userID string, pat *server.PersonalAccessToken) error { func (am *MockAccountManager) AddPATToUser(accountID string, userID string, pat *server.PersonalAccessToken) error {
if am.AddPATToUserFunc != nil { if am.AddPATToUserFunc != nil {