Merge pull request #774 from netbirdio/feature/add_pat_middleware

Feature/add pat middleware
This commit is contained in:
pascal-fischer 2023-04-03 12:09:11 +02:00 committed by GitHub
commit c54fb9643c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 537 additions and 295 deletions

View File

@ -19,25 +19,28 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/netbirdio/netbird/management/server/activity/sqlite"
httpapi "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/metrics"
"github.com/netbirdio/netbird/management/server/telemetry"
"golang.org/x/crypto/acme/autocert" "golang.org/x/crypto/acme/autocert"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/h2c" "golang.org/x/net/http2/h2c"
"github.com/netbirdio/netbird/management/server/activity/sqlite"
httpapi "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/metrics"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/encryption"
mgmtProto "github.com/netbirdio/netbird/management/proto"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/encryption"
mgmtProto "github.com/netbirdio/netbird/management/proto"
) )
// ManagementLegacyPort is the port that was used before by the Management gRPC server. // ManagementLegacyPort is the port that was used before by the Management gRPC server.
@ -179,13 +182,22 @@ var (
tlsEnabled = true tlsEnabled = true
} }
jwtValidator, err := jwtclaims.NewJWTValidator(
config.HttpConfig.AuthIssuer,
config.HttpConfig.AuthAudience,
config.HttpConfig.AuthKeysLocation,
)
if err != nil {
return fmt.Errorf("failed creating JWT validator: %v", err)
}
httpAPIAuthCfg := httpapi.AuthCfg{ httpAPIAuthCfg := httpapi.AuthCfg{
Issuer: config.HttpConfig.AuthIssuer, Issuer: config.HttpConfig.AuthIssuer,
Audience: config.HttpConfig.AuthAudience, Audience: config.HttpConfig.AuthAudience,
UserIDClaim: config.HttpConfig.AuthUserIDClaim, UserIDClaim: config.HttpConfig.AuthUserIDClaim,
KeysLocation: config.HttpConfig.AuthKeysLocation, KeysLocation: config.HttpConfig.AuthKeysLocation,
} }
httpAPIHandler, err := httpapi.APIHandler(accountManager, appMetrics, httpAPIAuthCfg) httpAPIHandler, err := httpapi.APIHandler(accountManager, *jwtValidator, appMetrics, httpAPIAuthCfg)
if err != nil { if err != nil {
return fmt.Errorf("failed creating HTTP API handler: %v", err) return fmt.Errorf("failed creating HTTP API handler: %v", err)
} }

View File

@ -56,6 +56,7 @@ type AccountManager interface {
GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error) GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error)
GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error) GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error)
MarkPATUsed(tokenID string) error
IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error)
AccountExists(accountId string) (*bool, error) AccountExists(accountId string) (*bool, error)
GetPeerByKey(peerKey string) (*Peer, error) GetPeerByKey(peerKey string) (*Peer, error)
@ -1122,6 +1123,39 @@ func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) e
return nil return nil
} }
// MarkPATUsed marks a personal access token as used
func (am *DefaultAccountManager) MarkPATUsed(tokenID string) error {
unlock := am.Store.AcquireGlobalLock()
user, err := am.Store.GetUserByTokenID(tokenID)
if err != nil {
return err
}
account, err := am.Store.GetAccountByUser(user.Id)
if err != nil {
return err
}
unlock()
unlock = am.Store.AcquireAccountLock(account.Id)
defer unlock()
account, err = am.Store.GetAccountByUser(user.Id)
if err != nil {
return err
}
pat, ok := account.Users[user.Id].PATs[tokenID]
if !ok {
return fmt.Errorf("token not found")
}
pat.LastUsed = time.Now()
return am.Store.SaveAccount(account)
}
// GetAccountFromPAT returns Account and User associated with a personal access token // GetAccountFromPAT returns Account and User associated with a personal access token
func (am *DefaultAccountManager) GetAccountFromPAT(token string) (*Account, *User, *PersonalAccessToken, error) { func (am *DefaultAccountManager) GetAccountFromPAT(token string) (*Account, *User, *PersonalAccessToken, error) {
if len(token) != PATLength { if len(token) != PATLength {

View File

@ -495,6 +495,44 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) {
assert.Equal(t, account.Users["someUser"].PATs["tokenId"], pat) assert.Equal(t, account.Users["someUser"].PATs["tokenId"], pat)
} }
func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
store := newStore(t)
account := newAccountWithId("account_id", "testuser", "")
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
hashedToken := sha256.Sum256([]byte(token))
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
account.Users["someUser"] = &User{
Id: "someUser",
PATs: map[string]*PersonalAccessToken{
"tokenId": {
ID: "tokenId",
HashedToken: encodedHashedToken,
LastUsed: time.Time{},
},
},
}
err := store.SaveAccount(account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{
Store: store,
}
err = am.MarkPATUsed("tokenId")
if err != nil {
t.Fatalf("Error when marking PAT used: %s", err)
}
account, err = am.Store.GetAccount("account_id")
if err != nil {
t.Fatalf("Error when getting account: %s", err)
}
assert.True(t, !account.Users["someUser"].PATs["tokenId"].LastUsed.IsZero())
}
func TestAccountManager_PrivateAccount(t *testing.T) { func TestAccountManager_PrivateAccount(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
if err != nil { if err != nil {

View File

@ -10,7 +10,6 @@ import (
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/golang/protobuf/ptypes/timestamp" "github.com/golang/protobuf/ptypes/timestamp"
@ -33,7 +32,7 @@ type GRPCServer struct {
peersUpdateManager *PeersUpdateManager peersUpdateManager *PeersUpdateManager
config *Config config *Config
turnCredentialsManager TURNCredentialsManager turnCredentialsManager TURNCredentialsManager
jwtMiddleware *middleware.JWTMiddleware jwtValidator *jwtclaims.JWTValidator
jwtClaimsExtractor *jwtclaims.ClaimsExtractor jwtClaimsExtractor *jwtclaims.ClaimsExtractor
appMetrics telemetry.AppMetrics appMetrics telemetry.AppMetrics
} }
@ -47,10 +46,10 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
return nil, err return nil, err
} }
var jwtMiddleware *middleware.JWTMiddleware var jwtValidator *jwtclaims.JWTValidator
if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) { if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) {
jwtMiddleware, err = middleware.NewJwtMiddleware( jwtValidator, err = jwtclaims.NewJWTValidator(
config.HttpConfig.AuthIssuer, config.HttpConfig.AuthIssuer,
config.HttpConfig.AuthAudience, config.HttpConfig.AuthAudience,
config.HttpConfig.AuthKeysLocation) config.HttpConfig.AuthKeysLocation)
@ -88,7 +87,7 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
accountManager: accountManager, accountManager: accountManager,
config: config, config: config,
turnCredentialsManager: turnCredentialsManager, turnCredentialsManager: turnCredentialsManager,
jwtMiddleware: jwtMiddleware, jwtValidator: jwtValidator,
jwtClaimsExtractor: jwtClaimsExtractor, jwtClaimsExtractor: jwtClaimsExtractor,
appMetrics: appMetrics, appMetrics: appMetrics,
}, nil }, nil
@ -189,11 +188,11 @@ func (s *GRPCServer) cancelPeerRoutines(peer *Peer) {
} }
func (s *GRPCServer) validateToken(jwtToken string) (string, error) { func (s *GRPCServer) validateToken(jwtToken string) (string, error) {
if s.jwtMiddleware == nil { if s.jwtValidator == nil {
return "", status.Error(codes.Internal, "no jwt middleware set") return "", status.Error(codes.Internal, "no jwt validator set")
} }
token, err := s.jwtMiddleware.ValidateAndParse(jwtToken) token, err := s.jwtValidator.ValidateAndParse(jwtToken)
if err != nil { if err != nil {
return "", status.Errorf(codes.InvalidArgument, "invalid jwt token, err: %v", err) return "", status.Errorf(codes.InvalidArgument, "invalid jwt token, err: %v", err)
} }

View File

@ -8,6 +8,7 @@ import (
s "github.com/netbirdio/netbird/management/server" s "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/middleware" "github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
) )
@ -25,18 +26,17 @@ type apiHandler struct {
AuthCfg AuthCfg AuthCfg AuthCfg
} }
// EmptyObject is an empty struct used to return empty JSON object
type emptyObject struct { type emptyObject struct {
} }
// APIHandler creates the Management service HTTP API handler registering all the available endpoints. // APIHandler creates the Management service HTTP API handler registering all the available endpoints.
func APIHandler(accountManager s.AccountManager, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) { func APIHandler(accountManager s.AccountManager, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) {
jwtMiddleware, err := middleware.NewJwtMiddleware( authMiddleware := middleware.NewAuthMiddleware(
authCfg.Issuer, accountManager.GetAccountFromPAT,
authCfg.Audience, jwtValidator.ValidateAndParse,
authCfg.KeysLocation) accountManager.MarkPATUsed,
if err != nil { authCfg.Audience)
return nil, err
}
corsMiddleware := cors.AllowAll() corsMiddleware := cors.AllowAll()
@ -49,7 +49,7 @@ func APIHandler(accountManager s.AccountManager, appMetrics telemetry.AppMetrics
metricsMiddleware := appMetrics.HTTPMiddleware() metricsMiddleware := appMetrics.HTTPMiddleware()
router := rootRouter.PathPrefix("/api").Subrouter() router := rootRouter.PathPrefix("/api").Subrouter()
router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, jwtMiddleware.Handler, acMiddleware.Handler) router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler, acMiddleware.Handler)
api := apiHandler{ api := apiHandler{
Router: router, Router: router,
@ -70,7 +70,7 @@ func APIHandler(accountManager s.AccountManager, appMetrics telemetry.AppMetrics
api.addDNSSettingEndpoint() api.addDNSSettingEndpoint()
api.addEventsEndpoint() api.addEventsEndpoint()
err = api.Router.Walk(func(route *mux.Route, _ *mux.Router, _ []*mux.Route) error { err := api.Router.Walk(func(route *mux.Route, _ *mux.Router, _ []*mux.Route) error {
methods, err := route.GetMethods() methods, err := route.GetMethods()
if err != nil { if err != nil {
return err return err

View File

@ -2,6 +2,9 @@ package middleware
import ( import (
"net/http" "net/http"
"regexp"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
@ -39,10 +42,22 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w) util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w)
return return
} }
if !ok { if !ok {
switch r.Method { switch r.Method {
case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut: case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut:
ok, err := regexp.MatchString(`^.*/api/users/.*/tokens.*$`, r.URL.Path)
if err != nil {
log.Debugf("Regex failed")
util.WriteError(status.Errorf(status.Internal, ""), w)
return
}
if ok {
log.Debugf("Valid Path")
h.ServeHTTP(w, r)
return
}
util.WriteError(status.Errorf(status.PermissionDenied, "only admin can perform this operation"), w) util.WriteError(status.Errorf(status.PermissionDenied, "only admin can perform this operation"), w)
return return
} }

View File

@ -0,0 +1,173 @@
package middleware
import (
"context"
"errors"
"fmt"
"net/http"
"strings"
"time"
"github.com/golang-jwt/jwt"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
)
// GetAccountFromPATFunc function
type GetAccountFromPATFunc func(token string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
// ValidateAndParseTokenFunc function
type ValidateAndParseTokenFunc func(token string) (*jwt.Token, error)
// MarkPATUsedFunc function
type MarkPATUsedFunc func(token string) error
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
type AuthMiddleware struct {
getAccountFromPAT GetAccountFromPATFunc
validateAndParseToken ValidateAndParseTokenFunc
markPATUsed MarkPATUsedFunc
audience string
}
const (
userProperty = "user"
)
// NewAuthMiddleware instance constructor
func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc, markPATUsed MarkPATUsedFunc, audience string) *AuthMiddleware {
return &AuthMiddleware{
getAccountFromPAT: getAccountFromPAT,
validateAndParseToken: validateAndParseToken,
markPATUsed: markPATUsed,
audience: audience,
}
}
// Handler method of the middleware which authenticates a user either by JWT claims or by PAT
func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
auth := strings.Split(r.Header.Get("Authorization"), " ")
authType := auth[0]
switch strings.ToLower(authType) {
case "bearer":
err := m.CheckJWTFromRequest(w, r)
if err != nil {
log.Debugf("Error when validating JWT claims: %s", err.Error())
util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w)
return
}
h.ServeHTTP(w, r)
case "token":
err := m.CheckPATFromRequest(w, r)
if err != nil {
log.Debugf("Error when validating PAT claims: %s", err.Error())
util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w)
return
}
h.ServeHTTP(w, r)
default:
util.WriteError(status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
return
}
})
}
// CheckJWTFromRequest checks if the JWT is valid
func (m *AuthMiddleware) CheckJWTFromRequest(w http.ResponseWriter, r *http.Request) error {
token, err := getTokenFromJWTRequest(r)
// If an error occurs, call the error handler and return an error
if err != nil {
return fmt.Errorf("Error extracting token: %w", err)
}
validatedToken, err := m.validateAndParseToken(token)
if err != nil {
return err
}
if validatedToken == nil {
return nil
}
// If we get here, everything worked and we can set the
// user property in context.
newRequest := r.WithContext(context.WithValue(r.Context(), userProperty, validatedToken)) //nolint
// Update the current request with the new context information.
*r = *newRequest
return nil
}
// CheckPATFromRequest checks if the PAT is valid
func (m *AuthMiddleware) CheckPATFromRequest(w http.ResponseWriter, r *http.Request) error {
token, err := getTokenFromPATRequest(r)
// If an error occurs, call the error handler and return an error
if err != nil {
return fmt.Errorf("Error extracting token: %w", err)
}
account, user, pat, err := m.getAccountFromPAT(token)
if err != nil {
return fmt.Errorf("invalid Token: %w", err)
}
if time.Now().After(pat.ExpirationDate) {
return fmt.Errorf("token expired")
}
err = m.markPATUsed(pat.ID)
if err != nil {
return err
}
claimMaps := jwt.MapClaims{}
claimMaps[jwtclaims.UserIDClaim] = user.Id
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint
// Update the current request with the new context information.
*r = *newRequest
return nil
}
// getTokenFromJWTRequest is a "TokenExtractor" that takes a give request and extracts
// the JWT token from the Authorization header.
func getTokenFromJWTRequest(r *http.Request) (string, error) {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
return "", nil // No error, just no token
}
// TODO: Make this a bit more robust, parsing-wise
authHeaderParts := strings.Fields(authHeader)
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
return "", errors.New("Authorization header format must be Bearer {token}")
}
return authHeaderParts[1], nil
}
// getTokenFromPATRequest is a "TokenExtractor" that takes a give request and extracts
// the PAT token from the Authorization header.
func getTokenFromPATRequest(r *http.Request) (string, error) {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
return "", nil // No error, just no token
}
// TODO: Make this a bit more robust, parsing-wise
authHeaderParts := strings.Fields(authHeader)
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "token" {
return "", errors.New("Authorization header format must be Token {token}")
}
return authHeaderParts[1], nil
}

View File

@ -0,0 +1,123 @@
package middleware
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/golang-jwt/jwt"
"github.com/netbirdio/netbird/management/server"
)
const (
audience = "audience"
accountID = "accountID"
domain = "domain"
userID = "userID"
tokenID = "tokenID"
PAT = "PAT"
JWT = "JWT"
wrongToken = "wrongToken"
)
var testAccount = &server.Account{
Id: accountID,
Domain: domain,
Users: map[string]*server.User{
userID: {
Id: userID,
PATs: map[string]*server.PersonalAccessToken{
tokenID: {
ID: tokenID,
Name: "My first token",
HashedToken: "someHash",
ExpirationDate: time.Now().AddDate(0, 0, 7),
CreatedBy: userID,
CreatedAt: time.Now(),
LastUsed: time.Now(),
},
},
},
},
}
func mockGetAccountFromPAT(token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) {
if token == PAT {
return testAccount, testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], nil
}
return nil, nil, nil, fmt.Errorf("PAT invalid")
}
func mockValidateAndParseToken(token string) (*jwt.Token, error) {
if token == JWT {
return &jwt.Token{}, nil
}
return nil, fmt.Errorf("JWT invalid")
}
func mockMarkPATUsed(token string) error {
if token == tokenID {
return nil
}
return fmt.Errorf("Should never get reached")
}
func TestAuthMiddleware_Handler(t *testing.T) {
tt := []struct {
name string
authHeader string
expectedStatusCode int
}{
{
name: "Valid PAT Token",
authHeader: "Token " + PAT,
expectedStatusCode: 200,
},
{
name: "Invalid PAT Token",
authHeader: "Token " + wrongToken,
expectedStatusCode: 401,
},
{
name: "Valid JWT Token",
authHeader: "Bearer " + JWT,
expectedStatusCode: 200,
},
{
name: "Invalid JWT Token",
authHeader: "Bearer " + wrongToken,
expectedStatusCode: 401,
},
{
name: "Basic Auth",
authHeader: "Basic " + PAT,
expectedStatusCode: 401,
},
}
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// do nothing
})
authMiddleware := NewAuthMiddleware(mockGetAccountFromPAT, mockValidateAndParseToken, mockMarkPATUsed, audience)
handlerToTest := authMiddleware.Handler(nextHandler)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "http://testing", nil)
req.Header.Set("Authorization", tc.authHeader)
rec := httptest.NewRecorder()
handlerToTest.ServeHTTP(rec, req)
if rec.Result().StatusCode != tc.expectedStatusCode {
t.Errorf("expected status code %d, got %d", tc.expectedStatusCode, rec.Result().StatusCode)
}
})
}
}

View File

@ -1,254 +0,0 @@
package middleware
import (
"context"
"errors"
"fmt"
"net/http"
"strings"
"github.com/golang-jwt/jwt"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/status"
)
// A function called whenever an error is encountered
type errorHandler func(w http.ResponseWriter, r *http.Request, err string)
// TokenExtractor is a function that takes a request as input and returns
// either a token or an error. An error should only be returned if an attempt
// to specify a token was found, but the information was somehow incorrectly
// formed. In the case where a token is simply not present, this should not
// be treated as an error. An empty string should be returned in that case.
type TokenExtractor func(r *http.Request) (string, error)
// Options is a struct for specifying configuration options for the middleware.
type Options struct {
// The function that will return the Key to validate the JWT.
// It can be either a shared secret or a public key.
// Default value: nil
ValidationKeyGetter jwt.Keyfunc
// The name of the property in the request where the user information
// from the JWT will be stored.
// Default value: "user"
UserProperty string
// The function that will be called when there's an error validating the token
// Default value:
ErrorHandler errorHandler
// A boolean indicating if the credentials are required or not
// Default value: false
CredentialsOptional bool
// A function that extracts the token from the request
// Default: FromAuthHeader (i.e., from Authorization header as bearer token)
Extractor TokenExtractor
// Debug flag turns on debugging output
// Default: false
Debug bool
// When set, all requests with the OPTIONS method will use authentication
// Default: false
EnableAuthOnOptions bool
// When set, the middelware verifies that tokens are signed with the specific signing algorithm
// If the signing method is not constant the ValidationKeyGetter callback can be used to implement additional checks
// Important to avoid security issues described here: https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/
// Default: nil
SigningMethod jwt.SigningMethod
}
type JWTMiddleware struct {
Options Options
}
func OnError(w http.ResponseWriter, r *http.Request, err string) {
util.WriteError(status.Errorf(status.Unauthorized, ""), w)
}
// New constructs a new Secure instance with supplied options.
func New(options ...Options) *JWTMiddleware {
var opts Options
if len(options) == 0 {
opts = Options{}
} else {
opts = options[0]
}
if opts.UserProperty == "" {
opts.UserProperty = "user"
}
if opts.ErrorHandler == nil {
opts.ErrorHandler = OnError
}
if opts.Extractor == nil {
opts.Extractor = FromAuthHeader
}
return &JWTMiddleware{
Options: opts,
}
}
func (m *JWTMiddleware) logf(format string, args ...interface{}) {
if m.Options.Debug {
log.Printf(format, args...)
}
}
// HandlerWithNext is a special implementation for Negroni, but could be used elsewhere.
func (m *JWTMiddleware) HandlerWithNext(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
err := m.CheckJWTFromRequest(w, r)
// If there was an error, do not call next.
if err == nil && next != nil {
next(w, r)
}
}
func (m *JWTMiddleware) Handler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Let secure process the request. If it returns an error,
// that indicates the request should not continue.
err := m.CheckJWTFromRequest(w, r)
// If there was an error, do not continue.
if err != nil {
log.Errorf("received an error while validating the JWT token: %s. "+
"Review your IDP configuration and ensure that "+
"settings are in sync between dashboard and management", err)
return
}
h.ServeHTTP(w, r)
})
}
// FromAuthHeader is a "TokenExtractor" that takes a give request and extracts
// the JWT token from the Authorization header.
func FromAuthHeader(r *http.Request) (string, error) {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
return "", nil // No error, just no token
}
// TODO: Make this a bit more robust, parsing-wise
authHeaderParts := strings.Fields(authHeader)
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
return "", errors.New("Authorization header format must be Bearer {token}")
}
return authHeaderParts[1], nil
}
// FromParameter returns a function that extracts the token from the specified
// query string parameter
func FromParameter(param string) TokenExtractor {
return func(r *http.Request) (string, error) {
return r.URL.Query().Get(param), nil
}
}
// FromFirst returns a function that runs multiple token extractors and takes the
// first token it finds
func FromFirst(extractors ...TokenExtractor) TokenExtractor {
return func(r *http.Request) (string, error) {
for _, ex := range extractors {
token, err := ex(r)
if err != nil {
return "", err
}
if token != "" {
return token, nil
}
}
return "", nil
}
}
func (m *JWTMiddleware) CheckJWTFromRequest(w http.ResponseWriter, r *http.Request) error {
if !m.Options.EnableAuthOnOptions {
if r.Method == "OPTIONS" {
return nil
}
}
// Use the specified token extractor to extract a token from the request
token, err := m.Options.Extractor(r)
// If debugging is turned on, log the outcome
if err != nil {
m.logf("Error extracting JWT: %v", err)
} else {
m.logf("Token extracted: %s", token)
}
// If an error occurs, call the error handler and return an error
if err != nil {
m.Options.ErrorHandler(w, r, err.Error())
return fmt.Errorf("Error extracting token: %w", err)
}
validatedToken, err := m.ValidateAndParse(token)
if err != nil {
m.Options.ErrorHandler(w, r, err.Error())
return err
}
if validatedToken == nil {
return nil
}
// If we get here, everything worked and we can set the
// user property in context.
newRequest := r.WithContext(context.WithValue(r.Context(), m.Options.UserProperty, validatedToken)) //nolint
// Update the current request with the new context information.
*r = *newRequest
return nil
}
// ValidateAndParse validates and parses a given access token against jwt standards and signing methods
func (m *JWTMiddleware) ValidateAndParse(token string) (*jwt.Token, error) {
// If the token is empty...
if token == "" {
// Check if it was required
if m.Options.CredentialsOptional {
m.logf("no credentials found (CredentialsOptional=true)")
// No error, just no token (and that is ok given that CredentialsOptional is true)
return nil, nil
}
// If we get here, the required token is missing
errorMsg := "required authorization token not found"
m.logf(" Error: No credentials found (CredentialsOptional=false)")
return nil, fmt.Errorf(errorMsg)
}
// Now parse the token
parsedToken, err := jwt.Parse(token, m.Options.ValidationKeyGetter)
// Check if there was an error in parsing...
if err != nil {
m.logf("error parsing token: %v", err)
return nil, fmt.Errorf("Error parsing token: %w", err)
}
if m.Options.SigningMethod != nil && m.Options.SigningMethod.Alg() != parsedToken.Header["alg"] {
errorMsg := fmt.Sprintf("Expected %s signing method but token specified %s",
m.Options.SigningMethod.Alg(),
parsedToken.Header["alg"])
m.logf("error validating token algorithm: %s", errorMsg)
return nil, fmt.Errorf("error validating token algorithm: %s", errorMsg)
}
// Check if the parsed token is valid...
if !parsedToken.Valid {
errorMsg := "token is invalid"
m.logf(errorMsg)
return nil, errors.New(errorMsg)
}
return parsedToken, nil
}

View File

@ -4,10 +4,13 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/netbirdio/netbird/management/server/status"
log "github.com/sirupsen/logrus"
"net/http" "net/http"
"strings"
"time" "time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/status"
) )
// WriteJSONObject simply writes object to the HTTP reponse in JSON format // WriteJSONObject simply writes object to the HTTP reponse in JSON format
@ -93,9 +96,11 @@ func WriteError(err error, w http.ResponseWriter) {
httpStatus = http.StatusInternalServerError httpStatus = http.StatusInternalServerError
case status.InvalidArgument: case status.InvalidArgument:
httpStatus = http.StatusUnprocessableEntity httpStatus = http.StatusUnprocessableEntity
case status.Unauthorized:
httpStatus = http.StatusUnauthorized
default: default:
} }
msg = err.Error() msg = strings.ToLower(err.Error())
} else { } else {
unhandledMSG := fmt.Sprintf("got unhandled error code, error: %s", err.Error()) unhandledMSG := fmt.Sprintf("got unhandled error code, error: %s", err.Error())
log.Error(unhandledMSG) log.Error(unhandledMSG)

View File

@ -7,14 +7,19 @@ import (
) )
const ( const (
// TokenUserProperty key for the user property in the request context
TokenUserProperty = "user" TokenUserProperty = "user"
// AccountIDSuffix suffix for the account id claim
AccountIDSuffix = "wt_account_id" AccountIDSuffix = "wt_account_id"
// DomainIDSuffix suffix for the domain id claim
DomainIDSuffix = "wt_account_domain" DomainIDSuffix = "wt_account_domain"
// DomainCategorySuffix suffix for the domain category claim
DomainCategorySuffix = "wt_account_domain_category" DomainCategorySuffix = "wt_account_domain_category"
// UserIDClaim claim for the user id
UserIDClaim = "sub" UserIDClaim = "sub"
) )
// Extract function type // ExtractClaims Extract function type
type ExtractClaims func(r *http.Request) AuthorizationClaims type ExtractClaims func(r *http.Request) AuthorizationClaims
// ClaimsExtractor struct that holds the extract function // ClaimsExtractor struct that holds the extract function

View File

@ -1,4 +1,4 @@
package middleware package jwtclaims
import ( import (
"bytes" "bytes"
@ -17,6 +17,32 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
// Options is a struct for specifying configuration options for the middleware.
type Options struct {
// The function that will return the Key to validate the JWT.
// It can be either a shared secret or a public key.
// Default value: nil
ValidationKeyGetter jwt.Keyfunc
// The name of the property in the request where the user information
// from the JWT will be stored.
// Default value: "user"
UserProperty string
// The function that will be called when there's an error validating the token
// Default value:
CredentialsOptional bool
// A function that extracts the token from the request
// Default: FromAuthHeader (i.e., from Authorization header as bearer token)
Debug bool
// When set, all requests with the OPTIONS method will use authentication
// Default: false
EnableAuthOnOptions bool
// When set, the middelware verifies that tokens are signed with the specific signing algorithm
// If the signing method is not constant the ValidationKeyGetter callback can be used to implement additional checks
// Important to avoid security issues described here: https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/
// Default: nil
SigningMethod jwt.SigningMethod
}
// Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation // Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation
type Jwks struct { type Jwks struct {
Keys []JSONWebKey `json:"keys"` Keys []JSONWebKey `json:"keys"`
@ -32,14 +58,19 @@ type JSONWebKey struct {
X5c []string `json:"x5c"` X5c []string `json:"x5c"`
} }
// NewJwtMiddleware creates new middleware to verify the JWT token sent via Authorization header // JWTValidator struct to handle token validation and parsing
func NewJwtMiddleware(issuer string, audience string, keysLocation string) (*JWTMiddleware, error) { type JWTValidator struct {
options Options
}
// NewJWTValidator constructor
func NewJWTValidator(issuer string, audience string, keysLocation string) (*JWTValidator, error) {
keys, err := getPemKeys(keysLocation) keys, err := getPemKeys(keysLocation)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return New(Options{ options := Options{
ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) { ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) {
// Verify 'aud' claim // Verify 'aud' claim
checkAud := token.Claims.(jwt.MapClaims).VerifyAudience(audience, false) checkAud := token.Claims.(jwt.MapClaims).VerifyAudience(audience, false)
@ -62,7 +93,59 @@ func NewJwtMiddleware(issuer string, audience string, keysLocation string) (*JWT
}, },
SigningMethod: jwt.SigningMethodRS256, SigningMethod: jwt.SigningMethodRS256,
EnableAuthOnOptions: false, EnableAuthOnOptions: false,
}), nil }
if options.UserProperty == "" {
options.UserProperty = "user"
}
return &JWTValidator{
options: options,
}, nil
}
// ValidateAndParse validates the token and returns the parsed token
func (m *JWTValidator) ValidateAndParse(token string) (*jwt.Token, error) {
// If the token is empty...
if token == "" {
// Check if it was required
if m.options.CredentialsOptional {
log.Debugf("no credentials found (CredentialsOptional=true)")
// No error, just no token (and that is ok given that CredentialsOptional is true)
return nil, nil
}
// If we get here, the required token is missing
errorMsg := "required authorization token not found"
log.Debugf(" Error: No credentials found (CredentialsOptional=false)")
return nil, fmt.Errorf(errorMsg)
}
// Now parse the token
parsedToken, err := jwt.Parse(token, m.options.ValidationKeyGetter)
// Check if there was an error in parsing...
if err != nil {
log.Debugf("error parsing token: %v", err)
return nil, fmt.Errorf("Error parsing token: %w", err)
}
if m.options.SigningMethod != nil && m.options.SigningMethod.Alg() != parsedToken.Header["alg"] {
errorMsg := fmt.Sprintf("Expected %s signing method but token specified %s",
m.options.SigningMethod.Alg(),
parsedToken.Header["alg"])
log.Debugf("error validating token algorithm: %s", errorMsg)
return nil, fmt.Errorf("error validating token algorithm: %s", errorMsg)
}
// Check if the parsed token is valid...
if !parsedToken.Valid {
errorMsg := "token is invalid"
log.Debugf(errorMsg)
return nil, errors.New(errorMsg)
}
return parsedToken, nil
} }
func getPemKeys(keysLocation string) (*Jwks, error) { func getPemKeys(keysLocation string) (*Jwks, error) {

View File

@ -48,6 +48,7 @@ type MockAccountManager struct {
ListPoliciesFunc func(accountID, userID string) ([]*server.Policy, error) ListPoliciesFunc func(accountID, userID string) ([]*server.Policy, error)
GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error) GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error)
GetAccountFromPATFunc func(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) GetAccountFromPATFunc func(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
MarkPATUsedFunc func(pat string) error
UpdatePeerMetaFunc func(peerID string, meta server.PeerSystemMeta) error UpdatePeerMetaFunc func(peerID string, meta server.PeerSystemMeta) error
UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error
UpdatePeerFunc func(accountID, userID string, peer *server.Peer) (*server.Peer, error) UpdatePeerFunc func(accountID, userID string, peer *server.Peer) (*server.Peer, error)
@ -188,6 +189,14 @@ func (am *MockAccountManager) GetAccountFromPAT(pat string) (*server.Account, *s
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromPAT is not implemented") return nil, nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromPAT is not implemented")
} }
// MarkPATUsed mock implementation of MarkPATUsed from server.AccountManager interface
func (am *MockAccountManager) MarkPATUsed(pat string) error {
if am.MarkPATUsedFunc != nil {
return am.MarkPATUsedFunc(pat)
}
return status.Errorf(codes.Unimplemented, "method MarkPATUsed is not implemented")
}
// CreatePAT mock implementation of GetPAT from server.AccountManager interface // CreatePAT mock implementation of GetPAT from server.AccountManager interface
func (am *MockAccountManager) CreatePAT(accountID string, executingUserID string, targetUserID string, name string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) { func (am *MockAccountManager) CreatePAT(accountID string, executingUserID string, targetUserID string, name string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
if am.CreatePATFunc != nil { if am.CreatePATFunc != nil {