netbird/management/server/http/middleware/auth_middleware.go
pascal-fischer f89c200ce9
Fix api Auth with PAT when a custom UserIDClaim is configured in management.json (#1120)
The API authentication with PATs was not considering different userIDClaim 
that some of the IdPs are using.
In this PR we read the userIDClaim from the config file 
instead of using the fixed default and only keep 
it as a fallback if none in defined.
2023-09-01 18:09:59 +02:00

179 lines
5.6 KiB
Go

package middleware
import (
"context"
"errors"
"fmt"
"net/http"
"strings"
"time"
"github.com/golang-jwt/jwt"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
)
// GetAccountFromPATFunc function
type GetAccountFromPATFunc func(token string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
// ValidateAndParseTokenFunc function
type ValidateAndParseTokenFunc func(token string) (*jwt.Token, error)
// MarkPATUsedFunc function
type MarkPATUsedFunc func(token string) error
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
type AuthMiddleware struct {
getAccountFromPAT GetAccountFromPATFunc
validateAndParseToken ValidateAndParseTokenFunc
markPATUsed MarkPATUsedFunc
audience string
userIDClaim string
}
const (
userProperty = "user"
)
// NewAuthMiddleware instance constructor
func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc, markPATUsed MarkPATUsedFunc, audience string, userIdClaim string) *AuthMiddleware {
if userIdClaim == "" {
userIdClaim = jwtclaims.UserIDClaim
}
return &AuthMiddleware{
getAccountFromPAT: getAccountFromPAT,
validateAndParseToken: validateAndParseToken,
markPATUsed: markPATUsed,
audience: audience,
userIDClaim: userIdClaim,
}
}
// Handler method of the middleware which authenticates a user either by JWT claims or by PAT
func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
auth := strings.Split(r.Header.Get("Authorization"), " ")
authType := auth[0]
switch strings.ToLower(authType) {
case "bearer":
err := m.CheckJWTFromRequest(w, r)
if err != nil {
log.Errorf("Error when validating JWT claims: %s", err.Error())
util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w)
return
}
h.ServeHTTP(w, r)
case "token":
err := m.CheckPATFromRequest(w, r)
if err != nil {
log.Debugf("Error when validating PAT claims: %s", err.Error())
util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w)
return
}
h.ServeHTTP(w, r)
default:
util.WriteError(status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
return
}
})
}
// CheckJWTFromRequest checks if the JWT is valid
func (m *AuthMiddleware) CheckJWTFromRequest(w http.ResponseWriter, r *http.Request) error {
token, err := getTokenFromJWTRequest(r)
// If an error occurs, call the error handler and return an error
if err != nil {
return fmt.Errorf("Error extracting token: %w", err)
}
validatedToken, err := m.validateAndParseToken(token)
if err != nil {
return err
}
if validatedToken == nil {
return nil
}
// If we get here, everything worked and we can set the
// user property in context.
newRequest := r.WithContext(context.WithValue(r.Context(), userProperty, validatedToken)) //nolint
// Update the current request with the new context information.
*r = *newRequest
return nil
}
// CheckPATFromRequest checks if the PAT is valid
func (m *AuthMiddleware) CheckPATFromRequest(w http.ResponseWriter, r *http.Request) error {
token, err := getTokenFromPATRequest(r)
// If an error occurs, call the error handler and return an error
if err != nil {
return fmt.Errorf("Error extracting token: %w", err)
}
account, user, pat, err := m.getAccountFromPAT(token)
if err != nil {
return fmt.Errorf("invalid Token: %w", err)
}
if time.Now().After(pat.ExpirationDate) {
return fmt.Errorf("token expired")
}
err = m.markPATUsed(pat.ID)
if err != nil {
return err
}
claimMaps := jwt.MapClaims{}
claimMaps[m.userIDClaim] = user.Id
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint
// Update the current request with the new context information.
*r = *newRequest
return nil
}
// getTokenFromJWTRequest is a "TokenExtractor" that takes a give request and extracts
// the JWT token from the Authorization header.
func getTokenFromJWTRequest(r *http.Request) (string, error) {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
return "", nil // No error, just no token
}
// TODO: Make this a bit more robust, parsing-wise
authHeaderParts := strings.Fields(authHeader)
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
return "", errors.New("Authorization header format must be Bearer {token}")
}
return authHeaderParts[1], nil
}
// getTokenFromPATRequest is a "TokenExtractor" that takes a give request and extracts
// the PAT token from the Authorization header.
func getTokenFromPATRequest(r *http.Request) (string, error) {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
return "", nil // No error, just no token
}
// TODO: Make this a bit more robust, parsing-wise
authHeaderParts := strings.Fields(authHeader)
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "token" {
return "", errors.New("Authorization header format must be Token {token}")
}
return authHeaderParts[1], nil
}