Feature: add custom id claim (#667)

This feature allows using the custom claim in the JWT token as a user ID.

Refactor claims extractor with options support

Add is_current to the user API response
This commit is contained in:
Givi Khojanashvili 2023-02-04 00:47:20 +04:00 committed by GitHub
parent 494e56d1be
commit 3ec8274b8e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
32 changed files with 474 additions and 305 deletions

View File

@ -30,7 +30,7 @@ services:
- $SIGNAL_VOLUMENAME:/var/lib/netbird
ports:
- 10000:80
# # port and command for Let's Encrypt validation
# # port and command for Let's Encrypt validation
# - 443:443
# command: ["--letsencrypt-domain", "$NETBIRD_DOMAIN", "--log-file", "console"]
# Management
@ -45,7 +45,7 @@ services:
- ./management.json:/etc/netbird/management.json
ports:
- $NETBIRD_MGMT_API_PORT:443 #API port
# # command for Let's Encrypt validation without dashboard container
# # command for Let's Encrypt validation without dashboard container
# command: ["--letsencrypt-domain", "$NETBIRD_DOMAIN", "--log-file", "console"]
command: ["--port", "443", "--log-file", "console", "--disable-anonymous-metrics=$NETBIRD_DISABLE_ANONYMOUS_METRICS", "--single-account-mode-domain=$NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN", "--dns-domain=$NETBIRD_MGMT_DNS_DOMAIN"]
# Coturn
@ -60,7 +60,6 @@ services:
network_mode: host
command:
- -c /etc/turnserver.conf
volumes:
$MGMT_VOLUMENAME:
$SIGNAL_VOLUMENAME:

View File

@ -32,6 +32,7 @@
"AuthIssuer": "$NETBIRD_AUTH_AUTHORITY",
"AuthAudience": "$NETBIRD_AUTH_AUDIENCE",
"AuthKeysLocation": "$NETBIRD_AUTH_JWT_CERTS",
"AuthUserIDClaim": "$NETBIRD_AUTH_USER_ID_CLAIM",
"CertFile":"$NETBIRD_MGMT_API_CERT_FILE",
"CertKey":"$NETBIRD_MGMT_API_CERT_KEY_FILE",
"OIDCConfigEndpoint":"$NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT"
@ -49,4 +50,4 @@
"DeviceAuthEndpoint": "$NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT"
}
}
}
}

View File

@ -7,6 +7,8 @@ NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT=""
NETBIRD_AUTH_AUDIENCE=""
# e.g. netbird-client
NETBIRD_AUTH_CLIENT_ID=""
# if you want to use a custom claim for the user ID instead of 'sub', set it here
# NETBIRD_AUTH_USER_ID_CLAIM=""
# indicates whether to use Auth0 or not: true or false
NETBIRD_USE_AUTH0="false"
NETBIRD_AUTH_DEVICE_AUTH_PROVIDER="none"

View File

@ -7,15 +7,6 @@ import (
"errors"
"flag"
"fmt"
"github.com/google/uuid"
"github.com/miekg/dns"
"github.com/netbirdio/netbird/management/server/activity/sqlite"
httpapi "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/metrics"
"github.com/netbirdio/netbird/management/server/telemetry"
"golang.org/x/crypto/acme/autocert"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
"io"
"io/fs"
"net"
@ -26,6 +17,16 @@ import (
"strings"
"time"
"github.com/google/uuid"
"github.com/miekg/dns"
"github.com/netbirdio/netbird/management/server/activity/sqlite"
httpapi "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/metrics"
"github.com/netbirdio/netbird/management/server/telemetry"
"golang.org/x/crypto/acme/autocert"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/util"
@ -178,8 +179,13 @@ var (
tlsEnabled = true
}
httpAPIHandler, err := httpapi.APIHandler(accountManager, config.HttpConfig.AuthIssuer,
config.HttpConfig.AuthAudience, config.HttpConfig.AuthKeysLocation, appMetrics)
httpAPIAuthCfg := httpapi.AuthCfg{
Issuer: config.HttpConfig.AuthIssuer,
Audience: config.HttpConfig.AuthAudience,
UserIDClaim: config.HttpConfig.AuthUserIDClaim,
KeysLocation: config.HttpConfig.AuthKeysLocation,
}
httpAPIHandler, err := httpapi.APIHandler(accountManager, appMetrics, httpAPIAuthCfg)
if err != nil {
return fmt.Errorf("failed creating HTTP API handler: %v", err)
}
@ -415,7 +421,6 @@ type OIDCConfigResponse struct {
// fetchOIDCConfig fetches OIDC configuration from the IDP
func fetchOIDCConfig(oidcEndpoint string) (OIDCConfigResponse, error) {
res, err := http.Get(oidcEndpoint)
if err != nil {
return OIDCConfigResponse{}, fmt.Errorf("failed fetching OIDC configuration fro mendpoint %s %v", oidcEndpoint, err)
@ -445,7 +450,6 @@ func fetchOIDCConfig(oidcEndpoint string) (OIDCConfigResponse, error) {
}
return config, nil
}
func loadTLSConfig(certFile string, certKey string) (*tls.Config, error) {

View File

@ -3,6 +3,15 @@ package server
import (
"context"
"fmt"
"math/rand"
"net"
"net/netip"
"reflect"
"regexp"
"strings"
"sync"
"time"
"github.com/eko/gocache/v3/cache"
cacheStore "github.com/eko/gocache/v3/store"
nbdns "github.com/netbirdio/netbird/dns"
@ -14,14 +23,6 @@ import (
gocache "github.com/patrickmn/go-cache"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"math/rand"
"net"
"net/netip"
"reflect"
"regexp"
"strings"
"sync"
"time"
)
const (
@ -219,7 +220,6 @@ func (a *Account) getEnabledAndDisabledRoutesByPeer(peerID string) ([]*route.Rou
// GetRoutesByPrefix return list of routes by account and route prefix
func (a *Account) GetRoutesByPrefix(prefix netip.Prefix) []*route.Route {
var routes []*route.Route
for _, r := range a.Routes {
if r.Network.String() == prefix.String() {
@ -243,7 +243,6 @@ func (a *Account) GetPeerByIP(peerIP string) *Peer {
// GetPeerRules returns a list of source or destination rules of a given peer.
func (a *Account) GetPeerRules(peerID string) (srcRules []*Rule, dstRules []*Rule) {
// Rules are group based so there is no direct access to peers.
// First, find all groups that the given peer belongs to
peerGroups := make(map[string]struct{})
@ -490,7 +489,8 @@ func (a *Account) GetPeer(peerID string) *Peer {
// BuildManager creates a new DefaultAccountManager with a provided Store
func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager,
singleAccountModeDomain string, dnsDomain string, eventStore activity.Store) (*DefaultAccountManager, error) {
singleAccountModeDomain string, dnsDomain string, eventStore activity.Store,
) (*DefaultAccountManager, error) {
am := &DefaultAccountManager{
Store: store,
peersUpdateManager: peersUpdateManager,
@ -544,14 +544,13 @@ func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManage
err := am.warmupIDPCache()
if err != nil {
log.Warnf("failed warming up cache due to error: %v", err)
//todo retry?
// todo retry?
return
}
}()
}
return am, nil
}
// newAccount creates a new Account with a generated ID and generated default setup keys.
@ -669,7 +668,6 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(email string, accountI
}
return nil, nil
}
// lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil
@ -768,7 +766,8 @@ func (am *DefaultAccountManager) lookupCache(accountUsers map[string]struct{}, a
// updateAccountDomainAttributes updates the account domain attributes and then, saves the account
func (am *DefaultAccountManager) updateAccountDomainAttributes(account *Account, claims jwtclaims.AuthorizationClaims,
primaryDomain bool) error {
primaryDomain bool,
) error {
account.IsDomainPrimaryAccount = primaryDomain
lowerDomain := strings.ToLower(claims.Domain)
@ -826,6 +825,9 @@ func (am *DefaultAccountManager) handleExistingUserAccount(
// handleNewUserAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account,
// otherwise it will create a new account and make it primary account for the domain.
func (am *DefaultAccountManager) handleNewUserAccount(domainAcc *Account, claims jwtclaims.AuthorizationClaims) (*Account, error) {
if claims.UserId == "" {
return nil, fmt.Errorf("user ID is empty")
}
var (
account *Account
err error
@ -897,7 +899,9 @@ func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) e
// GetAccountFromToken returns an account associated with this token
func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error) {
if claims.UserId == "" {
return nil, nil, fmt.Errorf("user ID is empty")
}
if am.singleAccountMode && am.singleAccountModeDomain != "" {
// This section is mostly related to self-hosted installations.
// We override incoming domain claims to group users under a single account.
@ -943,6 +947,9 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat
//
// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain)
func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtclaims.AuthorizationClaims) (*Account, error) {
if claims.UserId == "" {
return nil, fmt.Errorf("user ID is empty")
}
// if Account ID is part of the claims
// it means that we've already classified the domain and user has an account
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
@ -995,7 +1002,6 @@ func isDomainValid(domain string) bool {
// AccountExists checks whether account exists (returns true) or not (returns false)
func (am *DefaultAccountManager) AccountExists(accountID string) (*bool, error) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()

View File

@ -7,8 +7,13 @@ import (
"github.com/netbirdio/netbird/util"
)
type Protocol string
type Provider string
type (
// Protocol type
Protocol string
// Provider authorization flow type
Provider string
)
const (
UDP Protocol = "udp"
@ -45,14 +50,16 @@ type TURNConfig struct {
// HttpServerConfig is a config of the HTTP Management service server
type HttpServerConfig struct {
LetsEncryptDomain string
//CertFile is the location of the certificate
// CertFile is the location of the certificate
CertFile string
//CertKey is the location of the certificate private key
// CertKey is the location of the certificate private key
CertKey string
// AuthAudience identifies the recipients that the JWT is intended for (aud in JWT)
AuthAudience string
// AuthIssuer identifies principal that issued the JWT.
// AuthIssuer identifies principal that issued the JWT
AuthIssuer string
// AuthUserIDClaim is the name of the claim that used as user ID
AuthUserIDClaim string
// AuthKeysLocation is a location of JWT key set containing the public keys used to verify JWT
AuthKeysLocation string
// OIDCConfigEndpoint is the endpoint of an IDP manager to get OIDC configuration

View File

@ -3,11 +3,11 @@ package server
import (
"context"
"fmt"
"github.com/netbirdio/netbird/management/server/telemetry"
gPeer "google.golang.org/grpc/peer"
"strings"
"time"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/jwtclaims"
@ -31,12 +31,14 @@ type GRPCServer struct {
config *Config
turnCredentialsManager TURNCredentialsManager
jwtMiddleware *middleware.JWTMiddleware
jwtClaimsExtractor *jwtclaims.ClaimsExtractor
appMetrics telemetry.AppMetrics
}
// NewServer creates a new Management server
func NewServer(config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager,
turnCredentialsManager TURNCredentialsManager, appMetrics telemetry.AppMetrics) (*GRPCServer, error) {
turnCredentialsManager TURNCredentialsManager, appMetrics telemetry.AppMetrics,
) (*GRPCServer, error) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, err
@ -66,6 +68,16 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
}
}
var audience, userIDClaim string
if config.HttpConfig != nil {
audience = config.HttpConfig.AuthAudience
userIDClaim = config.HttpConfig.AuthUserIDClaim
}
jwtClaimsExtractor := jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(audience),
jwtclaims.WithUserIDClaim(userIDClaim),
)
return &GRPCServer{
wgKey: key,
// peerKey -> event channel
@ -74,6 +86,7 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
config: config,
turnCredentialsManager: turnCredentialsManager,
jwtMiddleware: jwtMiddleware,
jwtClaimsExtractor: jwtClaimsExtractor,
appMetrics: appMetrics,
}, nil
}
@ -113,7 +126,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
peer, err := s.accountManager.GetPeerByKey(peerKey.String())
if err != nil {
p, _ := gPeer.FromContext(srv.Context())
p, _ := gRPCPeer.FromContext(srv.Context())
msg := status.Errorf(codes.PermissionDenied, "provided peer with the key wgPubKey %s is not registered, remote addr is %s", peerKey.String(), p.Addr.String())
log.Debug(msg)
return msg
@ -122,7 +135,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
syncReq := &proto.SyncRequest{}
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, syncReq)
if err != nil {
p, _ := gPeer.FromContext(srv.Context())
p, _ := gRPCPeer.FromContext(srv.Context())
msg := status.Errorf(codes.InvalidArgument, "invalid request message from %s,remote addr is %s", peerKey.String(), p.Addr.String())
log.Debug(msg)
return msg
@ -200,7 +213,7 @@ func (s *GRPCServer) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest)
if err != nil {
return nil, status.Errorf(codes.Internal, "invalid jwt token, err: %v", err)
}
claims := jwtclaims.ExtractClaimsWithToken(token, s.config.HttpConfig.AuthAudience)
claims := s.jwtClaimsExtractor.FromToken(token)
userID = claims.UserId
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
_, _, err = s.accountManager.GetAccountFromToken(claims)
@ -305,7 +318,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
// peer doesn't exist -> check if setup key was provided
if loginReq.GetJwtToken() == "" && loginReq.GetSetupKey() == "" {
// absent setup key or jwt -> permission denied
p, _ := gPeer.FromContext(ctx)
p, _ := gRPCPeer.FromContext(ctx)
msg := status.Errorf(codes.PermissionDenied,
"provided peer with the key wgPubKey %s is not registered and no setup key or jwt was provided,"+
" remote addr is %s", peerKey.String(), p.Addr.String())

View File

@ -46,6 +46,10 @@ components:
type: array
items:
type: string
is_current:
description: Is true if authenticated user is the same as this user
type: boolean
readOnly: true
required:
- id
- email
@ -1703,4 +1707,4 @@ paths:
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
"$ref": "#/components/responses/internal_error"

View File

@ -566,6 +566,9 @@ type User struct {
// Id User ID
Id string `json:"id"`
// IsCurrent Is true if authenticated user is the same as this user
IsCurrent *bool `json:"is_current,omitempty"`
// Name User's name from idp provider
Name string `json:"name"`

View File

@ -2,33 +2,35 @@ package http
import (
"encoding/json"
"net/http"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
log "github.com/sirupsen/logrus"
"net/http"
)
// DNSSettings is a handler that returns the DNS settings of the account
type DNSSettings struct {
jwtExtractor jwtclaims.ClaimsExtractor
accountManager server.AccountManager
authAudience string
accountManager server.AccountManager
claimsExtractor *jwtclaims.ClaimsExtractor
}
// NewDNSSettings returns a new instance of DNSSettings handler
func NewDNSSettings(accountManager server.AccountManager, authAudience string) *DNSSettings {
func NewDNSSettings(accountManager server.AccountManager, authCfg AuthCfg) *DNSSettings {
return &DNSSettings{
accountManager: accountManager,
authAudience: authAudience,
jwtExtractor: *jwtclaims.NewClaimsExtractor(nil),
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
}
}
// GetDNSSettings returns the DNS settings for the account
func (h *DNSSettings) GetDNSSettings(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
log.Error(err)
@ -51,7 +53,7 @@ func (h *DNSSettings) GetDNSSettings(w http.ResponseWriter, r *http.Request) {
// UpdateDNSSettings handles update to DNS settings of an account
func (h *DNSSettings) UpdateDNSSettings(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)

View File

@ -3,14 +3,15 @@ package http
import (
"bytes"
"encoding/json"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/status"
"github.com/stretchr/testify/assert"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/status"
"github.com/stretchr/testify/assert"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/jwtclaims"
@ -52,16 +53,15 @@ func initDNSSettingsTestData() *DNSSettings {
return testingDNSSettingsAccount, testingDNSSettingsAccount.Users[testDNSSettingsUserID], nil
},
},
authAudience: "",
jwtExtractor: jwtclaims.ClaimsExtractor{
ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims {
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: testDNSSettingsAccountID,
}
},
},
}),
),
}
}

View File

@ -2,34 +2,36 @@ package http
import (
"fmt"
"net/http"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
log "github.com/sirupsen/logrus"
"net/http"
)
// Events HTTP handler
type Events struct {
accountManager server.AccountManager
authAudience string
jwtExtractor jwtclaims.ClaimsExtractor
accountManager server.AccountManager
claimsExtractor *jwtclaims.ClaimsExtractor
}
// NewEvents creates a new Events HTTP handler
func NewEvents(accountManager server.AccountManager, authAudience string) *Events {
func NewEvents(accountManager server.AccountManager, authCfg AuthCfg) *Events {
return &Events{
accountManager: accountManager,
authAudience: authAudience,
jwtExtractor: *jwtclaims.NewClaimsExtractor(nil),
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
}
}
// GetEvents list of the given account
func (h *Events) GetEvents(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
log.Error(err)

View File

@ -2,6 +2,13 @@ package http
import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strconv"
"testing"
"time"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
@ -9,12 +16,6 @@ import (
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/stretchr/testify/assert"
"io"
"net/http"
"net/http/httptest"
"strconv"
"testing"
"time"
)
func initEventsTestData(account string, user *server.User, events ...*activity.Event) *Events {
@ -36,16 +37,15 @@ func initEventsTestData(account string, user *server.User, events ...*activity.E
}, user, nil
},
},
authAudience: "",
jwtExtractor: jwtclaims.ClaimsExtractor{
ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims {
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_account",
}
},
},
}),
),
}
}
@ -244,7 +244,6 @@ func TestEvents_GetEvents(t *testing.T) {
assert.Equal(t, expected.Meta["some"], event.Meta["some"])
assert.True(t, expected.Timestamp.Equal(event.Timestamp))
}
})
}
}

View File

@ -2,10 +2,11 @@ package http
import (
"encoding/json"
"net/http"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/status"
"net/http"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/jwtclaims"
@ -17,22 +18,23 @@ import (
// Groups is a handler that returns groups of the account
type Groups struct {
jwtExtractor jwtclaims.ClaimsExtractor
accountManager server.AccountManager
authAudience string
accountManager server.AccountManager
claimsExtractor *jwtclaims.ClaimsExtractor
}
func NewGroups(accountManager server.AccountManager, authAudience string) *Groups {
func NewGroups(accountManager server.AccountManager, authCfg AuthCfg) *Groups {
return &Groups{
accountManager: accountManager,
authAudience: authAudience,
jwtExtractor: *jwtclaims.NewClaimsExtractor(nil),
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
}
}
// GetAllGroupsHandler list for the account
func (h *Groups) GetAllGroupsHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
claims := h.claimsExtractor.FromRequestContext(r)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
log.Error(err)
@ -50,7 +52,7 @@ func (h *Groups) GetAllGroupsHandler(w http.ResponseWriter, r *http.Request) {
// UpdateGroupHandler handles update to a group identified by a given ID
func (h *Groups) UpdateGroupHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
@ -119,7 +121,7 @@ func (h *Groups) UpdateGroupHandler(w http.ResponseWriter, r *http.Request) {
// PatchGroupHandler handles patch updates to a group identified by a given ID
func (h *Groups) PatchGroupHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
claims := h.claimsExtractor.FromRequestContext(r)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
@ -223,7 +225,7 @@ func (h *Groups) PatchGroupHandler(w http.ResponseWriter, r *http.Request) {
// CreateGroupHandler handles group creation request
func (h *Groups) CreateGroupHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
@ -265,7 +267,7 @@ func (h *Groups) CreateGroupHandler(w http.ResponseWriter, r *http.Request) {
// DeleteGroupHandler handles group deletion request
func (h *Groups) DeleteGroupHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
claims := h.claimsExtractor.FromRequestContext(r)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
@ -301,7 +303,7 @@ func (h *Groups) DeleteGroupHandler(w http.ResponseWriter, r *http.Request) {
// GetGroupHandler returns a group
func (h *Groups) GetGroupHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
claims := h.claimsExtractor.FromRequestContext(r)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)

View File

@ -4,8 +4,6 @@ import (
"bytes"
"encoding/json"
"fmt"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/status"
"io"
"net"
"net/http"
@ -13,6 +11,9 @@ import (
"strings"
"testing"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/status"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/jwtclaims"
@ -78,20 +79,20 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *Groups {
},
Groups: map[string]*server.Group{
"id-existed": {ID: "id-existed", Peers: []string{"A", "B"}},
"id-all": {ID: "id-all", Name: "All"}},
"id-all": {ID: "id-all", Name: "All"},
},
}, user, nil
},
},
authAudience: "",
jwtExtractor: jwtclaims.ClaimsExtractor{
ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims {
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",
}
},
},
}),
),
}
}
@ -270,7 +271,8 @@ func TestWriteGroup(t *testing.T) {
PeersCount: 2,
Peers: []api.PeerMinimum{
{Id: "peer-A-ID"},
{Id: "peer-B-ID"}},
{Id: "peer-B-ID"},
},
},
},
}

View File

@ -1,22 +1,29 @@
package http
import (
"net/http"
"github.com/gorilla/mux"
s "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/rs/cors"
"net/http"
)
// AuthCfg contains parameters for authentication middleware
type AuthCfg struct {
Issuer string
Audience string
UserIDClaim string
KeysLocation string
}
// APIHandler creates the Management service HTTP API handler registering all the available endpoints.
func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience string, authKeysLocation string,
appMetrics telemetry.AppMetrics) (http.Handler, error) {
func APIHandler(accountManager s.AccountManager, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) {
jwtMiddleware, err := middleware.NewJwtMiddleware(
authIssuer,
authAudience,
authKeysLocation,
)
authCfg.Issuer,
authCfg.Audience,
authCfg.KeysLocation)
if err != nil {
return nil, err
}
@ -24,7 +31,8 @@ func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience
corsMiddleware := cors.AllowAll()
acMiddleware := middleware.NewAccessControl(
authAudience,
authCfg.Audience,
authCfg.UserIDClaim,
accountManager.IsUserAdmin)
rootRouter := mux.NewRouter()
@ -33,15 +41,15 @@ func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience
apiHandler := rootRouter.PathPrefix("/api").Subrouter()
apiHandler.Use(metricsMiddleware.Handler, corsMiddleware.Handler, jwtMiddleware.Handler, acMiddleware.Handler)
groupsHandler := NewGroups(accountManager, authAudience)
rulesHandler := NewRules(accountManager, authAudience)
peersHandler := NewPeers(accountManager, authAudience)
keysHandler := NewSetupKeysHandler(accountManager, authAudience)
userHandler := NewUserHandler(accountManager, authAudience)
routesHandler := NewRoutes(accountManager, authAudience)
nameserversHandler := NewNameservers(accountManager, authAudience)
eventsHandler := NewEvents(accountManager, authAudience)
dnsSettingsHandler := NewDNSSettings(accountManager, authAudience)
groupsHandler := NewGroups(accountManager, authCfg)
rulesHandler := NewRules(accountManager, authCfg)
peersHandler := NewPeers(accountManager, authCfg)
keysHandler := NewSetupKeysHandler(accountManager, authCfg)
userHandler := NewUserHandler(accountManager, authCfg)
routesHandler := NewRoutes(accountManager, authCfg)
nameserversHandler := NewNameservers(accountManager, authCfg)
eventsHandler := NewEvents(accountManager, authCfg)
dnsSettingsHandler := NewDNSSettings(accountManager, authCfg)
apiHandler.HandleFunc("/peers", peersHandler.GetPeers).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/peers/{id}", peersHandler.HandlePeer).
@ -88,7 +96,7 @@ func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience
apiHandler.HandleFunc("/dns/settings", dnsSettingsHandler.GetDNSSettings).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/dns/settings", dnsSettingsHandler.UpdateDNSSettings).Methods("PUT", "OPTIONS")
err = apiHandler.Walk(func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error {
err = apiHandler.Walk(func(route *mux.Route, _ *mux.Router, _ []*mux.Route) error {
methods, err := route.GetMethods()
if err != nil {
return err
@ -110,5 +118,4 @@ func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience
}
return rootRouter, nil
}

View File

@ -1,9 +1,10 @@
package middleware
import (
"net/http"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/status"
"net/http"
"github.com/netbirdio/netbird/management/server/jwtclaims"
)
@ -12,17 +13,18 @@ type IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error)
// AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only
type AccessControl struct {
jwtExtractor jwtclaims.ClaimsExtractor
isUserAdmin IsUserAdminFunc
audience string
isUserAdmin IsUserAdminFunc
claimsExtract jwtclaims.ClaimsExtractor
}
// NewAccessControl instance constructor
func NewAccessControl(audience string, isUserAdmin IsUserAdminFunc) *AccessControl {
func NewAccessControl(audience, userIDClaim string, isUserAdmin IsUserAdminFunc) *AccessControl {
return &AccessControl{
isUserAdmin: isUserAdmin,
audience: audience,
jwtExtractor: *jwtclaims.NewClaimsExtractor(nil),
isUserAdmin: isUserAdmin,
claimsExtract: *jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(audience),
jwtclaims.WithUserIDClaim(userIDClaim),
),
}
}
@ -30,9 +32,9 @@ func NewAccessControl(audience string, isUserAdmin IsUserAdminFunc) *AccessContr
// It also adds
func (a *AccessControl) Handler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
jwtClaims := a.jwtExtractor.ExtractClaimsFromRequestContext(r, a.audience)
claims := a.claimsExtract.FromRequestContext(r)
ok, err := a.isUserAdmin(jwtClaims)
ok, err := a.isUserAdmin(claims)
if err != nil {
util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w)
return
@ -40,7 +42,6 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
if !ok {
switch r.Method {
case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut:
util.WriteError(status.Errorf(status.PermissionDenied, "only admin can perform this operation"), w)
return

View File

@ -10,10 +10,11 @@ import (
"encoding/pem"
"errors"
"fmt"
"github.com/golang-jwt/jwt"
log "github.com/sirupsen/logrus"
"math/big"
"net/http"
"github.com/golang-jwt/jwt"
log "github.com/sirupsen/logrus"
)
// Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation
@ -33,7 +34,6 @@ type JSONWebKey struct {
// NewJwtMiddleware creates new middleware to verify the JWT token sent via Authorization header
func NewJwtMiddleware(issuer string, audience string, keysLocation string) (*JWTMiddleware, error) {
keys, err := getPemKeys(keysLocation)
if err != nil {
return nil, err
@ -67,13 +67,12 @@ func NewJwtMiddleware(issuer string, audience string, keysLocation string) (*JWT
func getPemKeys(keysLocation string) (*Jwks, error) {
resp, err := http.Get(keysLocation)
if err != nil {
return nil, err
}
defer resp.Body.Close()
var jwks = &Jwks{}
jwks := &Jwks{}
err = json.NewDecoder(resp.Body).Decode(jwks)
if err != nil {
return jwks, err

View File

@ -3,6 +3,8 @@ package http
import (
"encoding/json"
"fmt"
"net/http"
"github.com/gorilla/mux"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server"
@ -11,28 +13,28 @@ import (
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
log "github.com/sirupsen/logrus"
"net/http"
)
// Nameservers is the nameserver group handler of the account
type Nameservers struct {
jwtExtractor jwtclaims.ClaimsExtractor
accountManager server.AccountManager
authAudience string
accountManager server.AccountManager
claimsExtractor *jwtclaims.ClaimsExtractor
}
// NewNameservers returns a new instance of Nameservers handler
func NewNameservers(accountManager server.AccountManager, authAudience string) *Nameservers {
func NewNameservers(accountManager server.AccountManager, authCfg AuthCfg) *Nameservers {
return &Nameservers{
accountManager: accountManager,
authAudience: authAudience,
jwtExtractor: *jwtclaims.NewClaimsExtractor(nil),
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
}
}
// GetAllNameserversHandler returns the list of nameserver groups for the account
func (h *Nameservers) GetAllNameserversHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
claims := h.claimsExtractor.FromRequestContext(r)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
log.Error(err)
@ -56,7 +58,7 @@ func (h *Nameservers) GetAllNameserversHandler(w http.ResponseWriter, r *http.Re
// CreateNameserverGroupHandler handles nameserver group creation request
func (h *Nameservers) CreateNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
@ -89,7 +91,7 @@ func (h *Nameservers) CreateNameserverGroupHandler(w http.ResponseWriter, r *htt
// UpdateNameserverGroupHandler handles update to a nameserver group identified by a given ID
func (h *Nameservers) UpdateNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
@ -139,7 +141,7 @@ func (h *Nameservers) UpdateNameserverGroupHandler(w http.ResponseWriter, r *htt
// PatchNameserverGroupHandler handles patch updates to a nameserver group identified by a given ID
func (h *Nameservers) PatchNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
@ -221,7 +223,7 @@ func (h *Nameservers) PatchNameserverGroupHandler(w http.ResponseWriter, r *http
// DeleteNameserverGroupHandler handles nameserver group deletion request
func (h *Nameservers) DeleteNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
@ -245,7 +247,7 @@ func (h *Nameservers) DeleteNameserverGroupHandler(w http.ResponseWriter, r *htt
// GetNameserverGroupHandler handles a nameserver group Get request identified by ID
func (h *Nameservers) GetNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
claims := h.claimsExtractor.FromRequestContext(r)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
log.Error(err)
@ -268,7 +270,6 @@ func (h *Nameservers) GetNameserverGroupHandler(w http.ResponseWriter, r *http.R
resp := toNameserverGroupResponse(nsGroup)
util.WriteJSONObject(w, &resp)
}
func toServerNSList(apiNSList []api.Nameserver) ([]nbdns.NameServer, error) {

View File

@ -3,16 +3,17 @@ package http
import (
"bytes"
"encoding/json"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/status"
"github.com/stretchr/testify/assert"
"io"
"net/http"
"net/http/httptest"
"net/netip"
"testing"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/status"
"github.com/stretchr/testify/assert"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/jwtclaims"
@ -113,16 +114,15 @@ func initNameserversTestData() *Nameservers {
return testingNSAccount, testingAccount.Users["test_user"], nil
},
},
authAudience: "",
jwtExtractor: jwtclaims.ClaimsExtractor{
ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims {
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: testNSGroupAccountID,
}
},
},
}),
),
}
}

View File

@ -3,27 +3,29 @@ package http
import (
"encoding/json"
"fmt"
"net/http"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
"net/http"
)
// Peers is a handler that returns peers of the account
type Peers struct {
accountManager server.AccountManager
authAudience string
jwtExtractor jwtclaims.ClaimsExtractor
accountManager server.AccountManager
claimsExtractor *jwtclaims.ClaimsExtractor
}
func NewPeers(accountManager server.AccountManager, authAudience string) *Peers {
func NewPeers(accountManager server.AccountManager, authCfg AuthCfg) *Peers {
return &Peers{
accountManager: accountManager,
authAudience: authAudience,
jwtExtractor: *jwtclaims.NewClaimsExtractor(nil),
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
}
}
@ -55,7 +57,7 @@ func (h *Peers) deletePeer(accountID, userID string, peerID string, w http.Respo
}
func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
@ -78,13 +80,12 @@ func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
default:
util.WriteError(status.Errorf(status.NotFound, "unknown METHOD"), w)
}
}
func (h *Peers) GetPeers(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)

View File

@ -2,13 +2,14 @@ package http
import (
"encoding/json"
"github.com/netbirdio/netbird/management/server/http/api"
"io"
"net"
"net/http"
"net/http/httptest"
"testing"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/magiconair/properties/assert"
@ -36,16 +37,15 @@ func initTestMetaData(peers ...*server.Peer) *Peers {
}, user, nil
},
},
authAudience: "",
jwtExtractor: jwtclaims.ClaimsExtractor{
ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims {
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",
}
},
},
}),
),
}
}

View File

@ -2,6 +2,9 @@ package http
import (
"encoding/json"
"net/http"
"unicode/utf8"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api"
@ -9,29 +12,28 @@ import (
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/route"
"net/http"
"unicode/utf8"
)
// Routes is the routes handler of the account
type Routes struct {
jwtExtractor jwtclaims.ClaimsExtractor
accountManager server.AccountManager
authAudience string
accountManager server.AccountManager
claimsExtractor *jwtclaims.ClaimsExtractor
}
// NewRoutes returns a new instance of Routes handler
func NewRoutes(accountManager server.AccountManager, authAudience string) *Routes {
func NewRoutes(accountManager server.AccountManager, authCfg AuthCfg) *Routes {
return &Routes{
accountManager: accountManager,
authAudience: authAudience,
jwtExtractor: *jwtclaims.NewClaimsExtractor(nil),
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
}
}
// GetAllRoutesHandler returns the list of routes for the account
func (h *Routes) GetAllRoutesHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
@ -53,7 +55,7 @@ func (h *Routes) GetAllRoutesHandler(w http.ResponseWriter, r *http.Request) {
// CreateRouteHandler handles route creation request
func (h *Routes) CreateRouteHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
@ -92,7 +94,7 @@ func (h *Routes) CreateRouteHandler(w http.ResponseWriter, r *http.Request) {
// UpdateRouteHandler handles update to a route identified by a given ID
func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
@ -158,7 +160,7 @@ func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) {
// PatchRouteHandler handles patch updates to a route identified by a given ID
func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
@ -299,7 +301,7 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
// DeleteRouteHandler handles route deletion request
func (h *Routes) DeleteRouteHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
@ -323,7 +325,7 @@ func (h *Routes) DeleteRouteHandler(w http.ResponseWriter, r *http.Request) {
// GetRouteHandler handles a route Get request identified by ID
func (h *Routes) GetRouteHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)

View File

@ -4,9 +4,6 @@ import (
"bytes"
"encoding/json"
"fmt"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/route"
"io"
"net/http"
"net/http/httptest"
@ -14,6 +11,10 @@ import (
"strconv"
"testing"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/route"
"github.com/gorilla/mux"
"github.com/magiconair/properties/assert"
"github.com/netbirdio/netbird/management/server"
@ -142,16 +143,15 @@ func initRoutesTestData() *Routes {
return testingAccount, testingAccount.Users["test_user"], nil
},
},
authAudience: "",
jwtExtractor: jwtclaims.ClaimsExtractor{
ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims {
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: testAccountID,
}
},
},
}),
),
}
}

View File

@ -2,6 +2,8 @@ package http
import (
"encoding/json"
"net/http"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api"
@ -9,27 +11,27 @@ import (
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
"github.com/rs/xid"
"net/http"
)
// Rules is a handler that returns rules of the account
type Rules struct {
jwtExtractor jwtclaims.ClaimsExtractor
accountManager server.AccountManager
authAudience string
accountManager server.AccountManager
claimsExtractor *jwtclaims.ClaimsExtractor
}
func NewRules(accountManager server.AccountManager, authAudience string) *Rules {
func NewRules(accountManager server.AccountManager, authCfg AuthCfg) *Rules {
return &Rules{
accountManager: accountManager,
authAudience: authAudience,
jwtExtractor: *jwtclaims.NewClaimsExtractor(nil),
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
}
}
// GetAllRulesHandler list for the account
func (h *Rules) GetAllRulesHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
@ -51,7 +53,7 @@ func (h *Rules) GetAllRulesHandler(w http.ResponseWriter, r *http.Request) {
// UpdateRuleHandler handles update to a rule identified by a given ID
func (h *Rules) UpdateRuleHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
@ -122,7 +124,7 @@ func (h *Rules) UpdateRuleHandler(w http.ResponseWriter, r *http.Request) {
// PatchRuleHandler handles patch updates to a rule identified by a given ID
func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
claims := h.claimsExtractor.FromRequestContext(r)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
@ -253,7 +255,6 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
}
rule, err := h.accountManager.UpdateRule(account.Id, ruleID, operations)
if err != nil {
util.WriteError(err, w)
return
@ -266,7 +267,7 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
// CreateRuleHandler handles rule creation request
func (h *Rules) CreateRuleHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
@ -325,7 +326,7 @@ func (h *Rules) CreateRuleHandler(w http.ResponseWriter, r *http.Request) {
// DeleteRuleHandler handles rule deletion request
func (h *Rules) DeleteRuleHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
@ -350,7 +351,7 @@ func (h *Rules) DeleteRuleHandler(w http.ResponseWriter, r *http.Request) {
// GetRuleHandler handles a group Get request identified by ID
func (h *Rules) GetRuleHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)

View File

@ -4,13 +4,14 @@ import (
"bytes"
"encoding/json"
"fmt"
"github.com/netbirdio/netbird/management/server/http/api"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/jwtclaims"
@ -71,7 +72,7 @@ func initRulesTestData(rules ...*server.Rule) *Rules {
return &server.Account{
Id: claims.AccountId,
Domain: "hotmail.com",
Rules: map[string]*server.Rule{"id-existed": &server.Rule{ID: "id-existed"}},
Rules: map[string]*server.Rule{"id-existed": {ID: "id-existed"}},
Groups: map[string]*server.Group{
"F": {ID: "F"},
"G": {ID: "G"},
@ -82,16 +83,15 @@ func initRulesTestData(rules ...*server.Rule) *Rules {
}, user, nil
},
},
authAudience: "",
jwtExtractor: jwtclaims.ClaimsExtractor{
ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims {
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",
}
},
},
}),
),
}
}
@ -264,7 +264,8 @@ func TestRulesWriteRule(t *testing.T) {
Flow: server.TrafficFlowBidirectString,
Sources: []api.GroupMinimum{
{Id: "G"},
{Id: "F"}},
{Id: "F"},
},
},
},
}
@ -306,7 +307,6 @@ func TestRulesWriteRule(t *testing.T) {
}
assert.Equal(t, got, tc.expectedRule)
})
}
}

View File

@ -2,34 +2,36 @@ package http
import (
"encoding/json"
"net/http"
"time"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
"net/http"
"time"
)
// SetupKeys is a handler that returns a list of setup keys of the account
type SetupKeys struct {
accountManager server.AccountManager
jwtExtractor jwtclaims.ClaimsExtractor
authAudience string
accountManager server.AccountManager
claimsExtractor *jwtclaims.ClaimsExtractor
}
func NewSetupKeysHandler(accountManager server.AccountManager, authAudience string) *SetupKeys {
func NewSetupKeysHandler(accountManager server.AccountManager, authCfg AuthCfg) *SetupKeys {
return &SetupKeys{
accountManager: accountManager,
authAudience: authAudience,
jwtExtractor: *jwtclaims.NewClaimsExtractor(nil),
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
}
}
// CreateSetupKeyHandler is a POST requests that creates a new SetupKey
func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
@ -72,7 +74,7 @@ func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request
// GetSetupKeyHandler is a GET request to get a SetupKey by ID
func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
@ -97,7 +99,7 @@ func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
// UpdateSetupKeyHandler is a PUT request to update server.SetupKey
func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
@ -144,8 +146,7 @@ func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request
// GetAllSetupKeysHandler is a GET request that returns a list of SetupKey
func (h *SetupKeys) GetAllSetupKeysHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)

View File

@ -4,16 +4,17 @@ import (
"bytes"
"encoding/json"
"fmt"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/status"
"github.com/stretchr/testify/assert"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/status"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server"
@ -28,7 +29,8 @@ const (
)
func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.SetupKey, updatedSetupKey *server.SetupKey,
user *server.User) *SetupKeys {
user *server.User,
) *SetupKeys {
return &SetupKeys{
accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
@ -43,11 +45,13 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
},
Groups: map[string]*server.Group{
"group-1": {ID: "group-1", Peers: []string{"A", "B"}},
"id-all": {ID: "id-all", Name: "All"}},
"id-all": {ID: "id-all", Name: "All"},
},
}, user, nil
},
CreateSetupKeyFunc: func(_ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string,
_ int, _ string) (*server.SetupKey, error) {
_ int, _ string,
) (*server.SetupKey, error) {
if keyName == newKey.Name || typ != newKey.Type {
return newKey, nil
}
@ -75,16 +79,15 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
return []*server.SetupKey{defaultKey}, nil
},
},
authAudience: "",
jwtExtractor: jwtclaims.ClaimsExtractor{
ExtractClaimsFromRequestContext: func(r *http.Request, authAudience string) jwtclaims.AuthorizationClaims {
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: user.Id,
Domain: "hotmail.com",
AccountId: testAccountID,
}
},
},
}),
),
}
}
@ -209,7 +212,6 @@ func TestSetupKeysHandlers(t *testing.T) {
assertKeys(t, got[0], tc.expectedSetupKeys[0])
return
}
})
}
}

View File

@ -2,27 +2,29 @@ package http
import (
"encoding/json"
"net/http"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/status"
"net/http"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/jwtclaims"
)
type UserHandler struct {
accountManager server.AccountManager
authAudience string
jwtExtractor jwtclaims.ClaimsExtractor
accountManager server.AccountManager
claimsExtractor *jwtclaims.ClaimsExtractor
}
func NewUserHandler(accountManager server.AccountManager, authAudience string) *UserHandler {
func NewUserHandler(accountManager server.AccountManager, authCfg AuthCfg) *UserHandler {
return &UserHandler{
accountManager: accountManager,
authAudience: authAudience,
jwtExtractor: *jwtclaims.NewClaimsExtractor(nil),
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
}
}
@ -33,7 +35,7 @@ func (h *UserHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
return
}
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
@ -69,8 +71,7 @@ func (h *UserHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
util.WriteError(err, w)
return
}
util.WriteJSONObject(w, toUserResponse(newUser))
util.WriteJSONObject(w, toUserResponse(newUser, claims.UserId))
}
// CreateUserHandler creates a User in the system with a status "invited" (effectively this is a user invite).
@ -80,7 +81,7 @@ func (h *UserHandler) CreateUserHandler(w http.ResponseWriter, r *http.Request)
return
}
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
@ -109,7 +110,7 @@ func (h *UserHandler) CreateUserHandler(w http.ResponseWriter, r *http.Request)
util.WriteError(err, w)
return
}
util.WriteJSONObject(w, toUserResponse(newUser))
util.WriteJSONObject(w, toUserResponse(newUser, claims.UserId))
}
// GetUsers returns a list of users of the account this user belongs to.
@ -120,7 +121,7 @@ func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) {
return
}
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
@ -135,14 +136,13 @@ func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) {
users := make([]*api.User, 0)
for _, r := range data {
users = append(users, toUserResponse(r))
users = append(users, toUserResponse(r, claims.UserId))
}
util.WriteJSONObject(w, users)
}
func toUserResponse(user *server.UserInfo) *api.User {
func toUserResponse(user *server.UserInfo, currenUserID string) *api.User {
autoGroups := user.AutoGroups
if autoGroups == nil {
autoGroups = []string{}
@ -158,6 +158,7 @@ func toUserResponse(user *server.UserInfo) *api.User {
userStatus = api.UserStatusDisabled
}
isCurrent := user.ID == currenUserID
return &api.User{
Id: user.ID,
Name: user.Name,
@ -165,5 +166,6 @@ func toUserResponse(user *server.UserInfo) *api.User {
Role: user.Role,
AutoGroups: autoGroups,
Status: userStatus,
IsCurrent: &isCurrent,
}
}

View File

@ -40,16 +40,15 @@ func initUsers(user ...*server.User) *UserHandler {
return users, nil
},
},
authAudience: "",
jwtExtractor: jwtclaims.ClaimsExtractor{
ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims {
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: "1",
Domain: "hotmail.com",
AccountId: "test_id",
}
},
},
}),
),
}
}
@ -57,7 +56,7 @@ func TestGetUsers(t *testing.T) {
users := []*server.User{{Id: "1", Role: "admin"}, {Id: "2", Role: "user"}, {Id: "3", Role: "user"}}
userHandler := initUsers(users...)
var tt = []struct {
tt := []struct {
name string
expectedStatus int
requestType string

View File

@ -1,8 +1,9 @@
package jwtclaims
import (
"github.com/golang-jwt/jwt"
"net/http"
"github.com/golang-jwt/jwt"
)
const (
@ -14,51 +15,85 @@ const (
)
// Extract function type
type ExtractClaims func(r *http.Request, authAudiance string) AuthorizationClaims
type ExtractClaims func(r *http.Request) AuthorizationClaims
// ClaimsExtractor struct that holds the extract function
type ClaimsExtractor struct {
ExtractClaimsFromRequestContext ExtractClaims
authAudience string
userIDClaim string
FromRequestContext ExtractClaims
}
// ClaimsExtractorOption is a function that configures the ClaimsExtractor
type ClaimsExtractorOption func(*ClaimsExtractor)
// WithAudience sets the audience for the extractor
func WithAudience(audience string) ClaimsExtractorOption {
return func(c *ClaimsExtractor) {
c.authAudience = audience
}
}
// WithUserIDClaim sets the user id claim for the extractor
func WithUserIDClaim(userIDClaim string) ClaimsExtractorOption {
return func(c *ClaimsExtractor) {
c.userIDClaim = userIDClaim
}
}
// WithFromRequestContext sets the function that extracts claims from the request context
func WithFromRequestContext(ec ExtractClaims) ClaimsExtractorOption {
return func(c *ClaimsExtractor) {
c.FromRequestContext = ec
}
}
// NewClaimsExtractor returns an extractor, and if provided with a function with ExtractClaims signature,
// then it will use that logic. Uses ExtractClaimsFromRequestContext by default
func NewClaimsExtractor(e ExtractClaims) *ClaimsExtractor {
var extractFunc ExtractClaims
if extractFunc = e; extractFunc == nil {
extractFunc = ExtractClaimsFromRequestContext
func NewClaimsExtractor(options ...ClaimsExtractorOption) *ClaimsExtractor {
ce := &ClaimsExtractor{}
for _, option := range options {
option(ce)
}
return &ClaimsExtractor{
ExtractClaimsFromRequestContext: extractFunc,
if ce.FromRequestContext == nil {
ce.FromRequestContext = ce.fromRequestContext
}
if ce.userIDClaim == "" {
ce.userIDClaim = UserIDClaim
}
return ce
}
// ExtractClaimsFromRequestContext extracts claims from the request context previously filled by the JWT token (after auth)
func ExtractClaimsFromRequestContext(r *http.Request, authAudience string) AuthorizationClaims {
if r.Context().Value(TokenUserProperty) == nil {
return AuthorizationClaims{}
}
token := r.Context().Value(TokenUserProperty).(*jwt.Token)
return ExtractClaimsWithToken(token, authAudience)
}
// ExtractClaimsWithToken extracts claims from the token (after auth)
func ExtractClaimsWithToken(token *jwt.Token, authAudience string) AuthorizationClaims {
// FromToken extracts claims from the token (after auth)
func (c *ClaimsExtractor) FromToken(token *jwt.Token) AuthorizationClaims {
claims := token.Claims.(jwt.MapClaims)
jwtClaims := AuthorizationClaims{}
jwtClaims.UserId = claims[UserIDClaim].(string)
accountIdClaim, ok := claims[authAudience+AccountIDSuffix]
if ok {
jwtClaims.AccountId = accountIdClaim.(string)
userID, ok := claims[c.userIDClaim].(string)
if !ok {
return jwtClaims
}
domainClaim, ok := claims[authAudience+DomainIDSuffix]
jwtClaims.UserId = userID
accountIDClaim, ok := claims[c.authAudience+AccountIDSuffix]
if ok {
jwtClaims.AccountId = accountIDClaim.(string)
}
domainClaim, ok := claims[c.authAudience+DomainIDSuffix]
if ok {
jwtClaims.Domain = domainClaim.(string)
}
domainCategoryClaim, ok := claims[authAudience+DomainCategorySuffix]
domainCategoryClaim, ok := claims[c.authAudience+DomainCategorySuffix]
if ok {
jwtClaims.DomainCategory = domainCategoryClaim.(string)
}
return jwtClaims
}
// fromRequestContext extracts claims from the request context previously filled by the JWT token (after auth)
func (c *ClaimsExtractor) fromRequestContext(r *http.Request) AuthorizationClaims {
if r.Context().Value(TokenUserProperty) == nil {
return AuthorizationClaims{}
}
token := r.Context().Value(TokenUserProperty).(*jwt.Token)
return c.FromToken(token)
}

View File

@ -2,10 +2,11 @@ package jwtclaims
import (
"context"
"github.com/golang-jwt/jwt"
"github.com/stretchr/testify/require"
"net/http"
"testing"
"github.com/golang-jwt/jwt"
"github.com/stretchr/testify/require"
)
func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audiance string) *http.Request {
@ -31,7 +32,6 @@ func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audiance st
}
func TestExtractClaimsFromRequestContext(t *testing.T) {
type test struct {
name string
inputAuthorizationClaims AuthorizationClaims
@ -99,12 +99,84 @@ func TestExtractClaimsFromRequestContext(t *testing.T) {
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5} {
t.Run(testCase.name, func(t *testing.T) {
request := newTestRequestWithJWT(t, testCase.inputAuthorizationClaims, testCase.inputAudiance)
extractedClaims := ExtractClaimsFromRequestContext(request, testCase.inputAudiance)
extractor := NewClaimsExtractor(WithAudience(testCase.inputAudiance))
extractedClaims := extractor.FromRequestContext(request)
testCase.testingFunc(t, testCase.inputAuthorizationClaims, extractedClaims, testCase.expectedMSG)
})
}
}
func TestExtractClaimsSetOptions(t *testing.T) {
type test struct {
name string
extractor *ClaimsExtractor
check func(t *testing.T, c test)
}
testCase1 := test{
name: "No custom options",
extractor: NewClaimsExtractor(),
check: func(t *testing.T, c test) {
if c.extractor.authAudience != "" {
t.Error("audience should be empty")
return
}
if c.extractor.userIDClaim != UserIDClaim {
t.Errorf("user id claim should be default, expected %s, got %s", UserIDClaim, c.extractor.userIDClaim)
return
}
if c.extractor.FromRequestContext == nil {
t.Error("from request context should not be nil")
return
}
},
}
testCase2 := test{
name: "Custom audience",
extractor: NewClaimsExtractor(WithAudience("https://login/")),
check: func(t *testing.T, c test) {
if c.extractor.authAudience != "https://login/" {
t.Errorf("audience expected %s, got %s", "https://login/", c.extractor.authAudience)
return
}
},
}
testCase3 := test{
name: "Custom user id claim",
extractor: NewClaimsExtractor(WithUserIDClaim("customUserId")),
check: func(t *testing.T, c test) {
if c.extractor.userIDClaim != "customUserId" {
t.Errorf("user id claim expected %s, got %s", "customUserId", c.extractor.userIDClaim)
return
}
},
}
testCase4 := test{
name: "Custom extractor from request context",
extractor: NewClaimsExtractor(
WithFromRequestContext(func(r *http.Request) AuthorizationClaims {
return AuthorizationClaims{
UserId: "testCustomRequest",
}
})),
check: func(t *testing.T, c test) {
claims := c.extractor.FromRequestContext(&http.Request{})
if claims.UserId != "testCustomRequest" {
t.Errorf("user id claim expected %s, got %s", "testCustomRequest", claims.UserId)
return
}
},
}
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4} {
t.Run(testCase.name, func(t *testing.T) {
testCase.check(t, testCase)
})
}
}