mirror of
https://github.com/netbirdio/netbird.git
synced 2025-02-16 10:20:09 +01:00
Feature: add custom id claim (#667)
This feature allows using the custom claim in the JWT token as a user ID. Refactor claims extractor with options support Add is_current to the user API response
This commit is contained in:
parent
494e56d1be
commit
3ec8274b8e
@ -30,7 +30,7 @@ services:
|
||||
- $SIGNAL_VOLUMENAME:/var/lib/netbird
|
||||
ports:
|
||||
- 10000:80
|
||||
# # port and command for Let's Encrypt validation
|
||||
# # port and command for Let's Encrypt validation
|
||||
# - 443:443
|
||||
# command: ["--letsencrypt-domain", "$NETBIRD_DOMAIN", "--log-file", "console"]
|
||||
# Management
|
||||
@ -45,7 +45,7 @@ services:
|
||||
- ./management.json:/etc/netbird/management.json
|
||||
ports:
|
||||
- $NETBIRD_MGMT_API_PORT:443 #API port
|
||||
# # command for Let's Encrypt validation without dashboard container
|
||||
# # command for Let's Encrypt validation without dashboard container
|
||||
# command: ["--letsencrypt-domain", "$NETBIRD_DOMAIN", "--log-file", "console"]
|
||||
command: ["--port", "443", "--log-file", "console", "--disable-anonymous-metrics=$NETBIRD_DISABLE_ANONYMOUS_METRICS", "--single-account-mode-domain=$NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN", "--dns-domain=$NETBIRD_MGMT_DNS_DOMAIN"]
|
||||
# Coturn
|
||||
@ -60,7 +60,6 @@ services:
|
||||
network_mode: host
|
||||
command:
|
||||
- -c /etc/turnserver.conf
|
||||
|
||||
volumes:
|
||||
$MGMT_VOLUMENAME:
|
||||
$SIGNAL_VOLUMENAME:
|
||||
|
@ -32,6 +32,7 @@
|
||||
"AuthIssuer": "$NETBIRD_AUTH_AUTHORITY",
|
||||
"AuthAudience": "$NETBIRD_AUTH_AUDIENCE",
|
||||
"AuthKeysLocation": "$NETBIRD_AUTH_JWT_CERTS",
|
||||
"AuthUserIDClaim": "$NETBIRD_AUTH_USER_ID_CLAIM",
|
||||
"CertFile":"$NETBIRD_MGMT_API_CERT_FILE",
|
||||
"CertKey":"$NETBIRD_MGMT_API_CERT_KEY_FILE",
|
||||
"OIDCConfigEndpoint":"$NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT"
|
||||
@ -49,4 +50,4 @@
|
||||
"DeviceAuthEndpoint": "$NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -7,6 +7,8 @@ NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT=""
|
||||
NETBIRD_AUTH_AUDIENCE=""
|
||||
# e.g. netbird-client
|
||||
NETBIRD_AUTH_CLIENT_ID=""
|
||||
# if you want to use a custom claim for the user ID instead of 'sub', set it here
|
||||
# NETBIRD_AUTH_USER_ID_CLAIM=""
|
||||
# indicates whether to use Auth0 or not: true or false
|
||||
NETBIRD_USE_AUTH0="false"
|
||||
NETBIRD_AUTH_DEVICE_AUTH_PROVIDER="none"
|
||||
|
@ -7,15 +7,6 @@ import (
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"github.com/google/uuid"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/netbirdio/netbird/management/server/activity/sqlite"
|
||||
httpapi "github.com/netbirdio/netbird/management/server/http"
|
||||
"github.com/netbirdio/netbird/management/server/metrics"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"golang.org/x/crypto/acme/autocert"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/h2c"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net"
|
||||
@ -26,6 +17,16 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/netbirdio/netbird/management/server/activity/sqlite"
|
||||
httpapi "github.com/netbirdio/netbird/management/server/http"
|
||||
"github.com/netbirdio/netbird/management/server/metrics"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"golang.org/x/crypto/acme/autocert"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/h2c"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
@ -178,8 +179,13 @@ var (
|
||||
tlsEnabled = true
|
||||
}
|
||||
|
||||
httpAPIHandler, err := httpapi.APIHandler(accountManager, config.HttpConfig.AuthIssuer,
|
||||
config.HttpConfig.AuthAudience, config.HttpConfig.AuthKeysLocation, appMetrics)
|
||||
httpAPIAuthCfg := httpapi.AuthCfg{
|
||||
Issuer: config.HttpConfig.AuthIssuer,
|
||||
Audience: config.HttpConfig.AuthAudience,
|
||||
UserIDClaim: config.HttpConfig.AuthUserIDClaim,
|
||||
KeysLocation: config.HttpConfig.AuthKeysLocation,
|
||||
}
|
||||
httpAPIHandler, err := httpapi.APIHandler(accountManager, appMetrics, httpAPIAuthCfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed creating HTTP API handler: %v", err)
|
||||
}
|
||||
@ -415,7 +421,6 @@ type OIDCConfigResponse struct {
|
||||
|
||||
// fetchOIDCConfig fetches OIDC configuration from the IDP
|
||||
func fetchOIDCConfig(oidcEndpoint string) (OIDCConfigResponse, error) {
|
||||
|
||||
res, err := http.Get(oidcEndpoint)
|
||||
if err != nil {
|
||||
return OIDCConfigResponse{}, fmt.Errorf("failed fetching OIDC configuration fro mendpoint %s %v", oidcEndpoint, err)
|
||||
@ -445,7 +450,6 @@ func fetchOIDCConfig(oidcEndpoint string) (OIDCConfigResponse, error) {
|
||||
}
|
||||
|
||||
return config, nil
|
||||
|
||||
}
|
||||
|
||||
func loadTLSConfig(certFile string, certKey string) (*tls.Config, error) {
|
||||
|
@ -3,6 +3,15 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/eko/gocache/v3/cache"
|
||||
cacheStore "github.com/eko/gocache/v3/store"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
@ -14,14 +23,6 @@ import (
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -219,7 +220,6 @@ func (a *Account) getEnabledAndDisabledRoutesByPeer(peerID string) ([]*route.Rou
|
||||
|
||||
// GetRoutesByPrefix return list of routes by account and route prefix
|
||||
func (a *Account) GetRoutesByPrefix(prefix netip.Prefix) []*route.Route {
|
||||
|
||||
var routes []*route.Route
|
||||
for _, r := range a.Routes {
|
||||
if r.Network.String() == prefix.String() {
|
||||
@ -243,7 +243,6 @@ func (a *Account) GetPeerByIP(peerIP string) *Peer {
|
||||
|
||||
// GetPeerRules returns a list of source or destination rules of a given peer.
|
||||
func (a *Account) GetPeerRules(peerID string) (srcRules []*Rule, dstRules []*Rule) {
|
||||
|
||||
// Rules are group based so there is no direct access to peers.
|
||||
// First, find all groups that the given peer belongs to
|
||||
peerGroups := make(map[string]struct{})
|
||||
@ -490,7 +489,8 @@ func (a *Account) GetPeer(peerID string) *Peer {
|
||||
|
||||
// BuildManager creates a new DefaultAccountManager with a provided Store
|
||||
func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager,
|
||||
singleAccountModeDomain string, dnsDomain string, eventStore activity.Store) (*DefaultAccountManager, error) {
|
||||
singleAccountModeDomain string, dnsDomain string, eventStore activity.Store,
|
||||
) (*DefaultAccountManager, error) {
|
||||
am := &DefaultAccountManager{
|
||||
Store: store,
|
||||
peersUpdateManager: peersUpdateManager,
|
||||
@ -544,14 +544,13 @@ func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManage
|
||||
err := am.warmupIDPCache()
|
||||
if err != nil {
|
||||
log.Warnf("failed warming up cache due to error: %v", err)
|
||||
//todo retry?
|
||||
// todo retry?
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return am, nil
|
||||
|
||||
}
|
||||
|
||||
// newAccount creates a new Account with a generated ID and generated default setup keys.
|
||||
@ -669,7 +668,6 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(email string, accountI
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
|
||||
}
|
||||
|
||||
// lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil
|
||||
@ -768,7 +766,8 @@ func (am *DefaultAccountManager) lookupCache(accountUsers map[string]struct{}, a
|
||||
|
||||
// updateAccountDomainAttributes updates the account domain attributes and then, saves the account
|
||||
func (am *DefaultAccountManager) updateAccountDomainAttributes(account *Account, claims jwtclaims.AuthorizationClaims,
|
||||
primaryDomain bool) error {
|
||||
primaryDomain bool,
|
||||
) error {
|
||||
account.IsDomainPrimaryAccount = primaryDomain
|
||||
|
||||
lowerDomain := strings.ToLower(claims.Domain)
|
||||
@ -826,6 +825,9 @@ func (am *DefaultAccountManager) handleExistingUserAccount(
|
||||
// handleNewUserAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account,
|
||||
// otherwise it will create a new account and make it primary account for the domain.
|
||||
func (am *DefaultAccountManager) handleNewUserAccount(domainAcc *Account, claims jwtclaims.AuthorizationClaims) (*Account, error) {
|
||||
if claims.UserId == "" {
|
||||
return nil, fmt.Errorf("user ID is empty")
|
||||
}
|
||||
var (
|
||||
account *Account
|
||||
err error
|
||||
@ -897,7 +899,9 @@ func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) e
|
||||
|
||||
// GetAccountFromToken returns an account associated with this token
|
||||
func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error) {
|
||||
|
||||
if claims.UserId == "" {
|
||||
return nil, nil, fmt.Errorf("user ID is empty")
|
||||
}
|
||||
if am.singleAccountMode && am.singleAccountModeDomain != "" {
|
||||
// This section is mostly related to self-hosted installations.
|
||||
// We override incoming domain claims to group users under a single account.
|
||||
@ -943,6 +947,9 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat
|
||||
//
|
||||
// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain)
|
||||
func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtclaims.AuthorizationClaims) (*Account, error) {
|
||||
if claims.UserId == "" {
|
||||
return nil, fmt.Errorf("user ID is empty")
|
||||
}
|
||||
// if Account ID is part of the claims
|
||||
// it means that we've already classified the domain and user has an account
|
||||
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
|
||||
@ -995,7 +1002,6 @@ func isDomainValid(domain string) bool {
|
||||
|
||||
// AccountExists checks whether account exists (returns true) or not (returns false)
|
||||
func (am *DefaultAccountManager) AccountExists(accountID string) (*bool, error) {
|
||||
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
|
@ -7,8 +7,13 @@ import (
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
type Protocol string
|
||||
type Provider string
|
||||
type (
|
||||
// Protocol type
|
||||
Protocol string
|
||||
|
||||
// Provider authorization flow type
|
||||
Provider string
|
||||
)
|
||||
|
||||
const (
|
||||
UDP Protocol = "udp"
|
||||
@ -45,14 +50,16 @@ type TURNConfig struct {
|
||||
// HttpServerConfig is a config of the HTTP Management service server
|
||||
type HttpServerConfig struct {
|
||||
LetsEncryptDomain string
|
||||
//CertFile is the location of the certificate
|
||||
// CertFile is the location of the certificate
|
||||
CertFile string
|
||||
//CertKey is the location of the certificate private key
|
||||
// CertKey is the location of the certificate private key
|
||||
CertKey string
|
||||
// AuthAudience identifies the recipients that the JWT is intended for (aud in JWT)
|
||||
AuthAudience string
|
||||
// AuthIssuer identifies principal that issued the JWT.
|
||||
// AuthIssuer identifies principal that issued the JWT
|
||||
AuthIssuer string
|
||||
// AuthUserIDClaim is the name of the claim that used as user ID
|
||||
AuthUserIDClaim string
|
||||
// AuthKeysLocation is a location of JWT key set containing the public keys used to verify JWT
|
||||
AuthKeysLocation string
|
||||
// OIDCConfigEndpoint is the endpoint of an IDP manager to get OIDC configuration
|
||||
|
@ -3,11 +3,11 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
gPeer "google.golang.org/grpc/peer"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
|
||||
@ -31,12 +31,14 @@ type GRPCServer struct {
|
||||
config *Config
|
||||
turnCredentialsManager TURNCredentialsManager
|
||||
jwtMiddleware *middleware.JWTMiddleware
|
||||
jwtClaimsExtractor *jwtclaims.ClaimsExtractor
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// NewServer creates a new Management server
|
||||
func NewServer(config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager,
|
||||
turnCredentialsManager TURNCredentialsManager, appMetrics telemetry.AppMetrics) (*GRPCServer, error) {
|
||||
turnCredentialsManager TURNCredentialsManager, appMetrics telemetry.AppMetrics,
|
||||
) (*GRPCServer, error) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -66,6 +68,16 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
|
||||
}
|
||||
}
|
||||
|
||||
var audience, userIDClaim string
|
||||
if config.HttpConfig != nil {
|
||||
audience = config.HttpConfig.AuthAudience
|
||||
userIDClaim = config.HttpConfig.AuthUserIDClaim
|
||||
}
|
||||
jwtClaimsExtractor := jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(audience),
|
||||
jwtclaims.WithUserIDClaim(userIDClaim),
|
||||
)
|
||||
|
||||
return &GRPCServer{
|
||||
wgKey: key,
|
||||
// peerKey -> event channel
|
||||
@ -74,6 +86,7 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
|
||||
config: config,
|
||||
turnCredentialsManager: turnCredentialsManager,
|
||||
jwtMiddleware: jwtMiddleware,
|
||||
jwtClaimsExtractor: jwtClaimsExtractor,
|
||||
appMetrics: appMetrics,
|
||||
}, nil
|
||||
}
|
||||
@ -113,7 +126,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
|
||||
peer, err := s.accountManager.GetPeerByKey(peerKey.String())
|
||||
if err != nil {
|
||||
p, _ := gPeer.FromContext(srv.Context())
|
||||
p, _ := gRPCPeer.FromContext(srv.Context())
|
||||
msg := status.Errorf(codes.PermissionDenied, "provided peer with the key wgPubKey %s is not registered, remote addr is %s", peerKey.String(), p.Addr.String())
|
||||
log.Debug(msg)
|
||||
return msg
|
||||
@ -122,7 +135,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
syncReq := &proto.SyncRequest{}
|
||||
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, syncReq)
|
||||
if err != nil {
|
||||
p, _ := gPeer.FromContext(srv.Context())
|
||||
p, _ := gRPCPeer.FromContext(srv.Context())
|
||||
msg := status.Errorf(codes.InvalidArgument, "invalid request message from %s,remote addr is %s", peerKey.String(), p.Addr.String())
|
||||
log.Debug(msg)
|
||||
return msg
|
||||
@ -200,7 +213,7 @@ func (s *GRPCServer) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "invalid jwt token, err: %v", err)
|
||||
}
|
||||
claims := jwtclaims.ExtractClaimsWithToken(token, s.config.HttpConfig.AuthAudience)
|
||||
claims := s.jwtClaimsExtractor.FromToken(token)
|
||||
userID = claims.UserId
|
||||
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
|
||||
_, _, err = s.accountManager.GetAccountFromToken(claims)
|
||||
@ -305,7 +318,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
|
||||
// peer doesn't exist -> check if setup key was provided
|
||||
if loginReq.GetJwtToken() == "" && loginReq.GetSetupKey() == "" {
|
||||
// absent setup key or jwt -> permission denied
|
||||
p, _ := gPeer.FromContext(ctx)
|
||||
p, _ := gRPCPeer.FromContext(ctx)
|
||||
msg := status.Errorf(codes.PermissionDenied,
|
||||
"provided peer with the key wgPubKey %s is not registered and no setup key or jwt was provided,"+
|
||||
" remote addr is %s", peerKey.String(), p.Addr.String())
|
||||
|
@ -46,6 +46,10 @@ components:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
is_current:
|
||||
description: Is true if authenticated user is the same as this user
|
||||
type: boolean
|
||||
readOnly: true
|
||||
required:
|
||||
- id
|
||||
- email
|
||||
@ -1703,4 +1707,4 @@ paths:
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
|
@ -566,6 +566,9 @@ type User struct {
|
||||
// Id User ID
|
||||
Id string `json:"id"`
|
||||
|
||||
// IsCurrent Is true if authenticated user is the same as this user
|
||||
IsCurrent *bool `json:"is_current,omitempty"`
|
||||
|
||||
// Name User's name from idp provider
|
||||
Name string `json:"name"`
|
||||
|
||||
|
@ -2,33 +2,35 @@ package http
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// DNSSettings is a handler that returns the DNS settings of the account
|
||||
type DNSSettings struct {
|
||||
jwtExtractor jwtclaims.ClaimsExtractor
|
||||
accountManager server.AccountManager
|
||||
authAudience string
|
||||
accountManager server.AccountManager
|
||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||
}
|
||||
|
||||
// NewDNSSettings returns a new instance of DNSSettings handler
|
||||
func NewDNSSettings(accountManager server.AccountManager, authAudience string) *DNSSettings {
|
||||
func NewDNSSettings(accountManager server.AccountManager, authCfg AuthCfg) *DNSSettings {
|
||||
return &DNSSettings{
|
||||
accountManager: accountManager,
|
||||
authAudience: authAudience,
|
||||
jwtExtractor: *jwtclaims.NewClaimsExtractor(nil),
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(authCfg.Audience),
|
||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
// GetDNSSettings returns the DNS settings for the account
|
||||
func (h *DNSSettings) GetDNSSettings(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
@ -51,7 +53,7 @@ func (h *DNSSettings) GetDNSSettings(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// UpdateDNSSettings handles update to DNS settings of an account
|
||||
func (h *DNSSettings) UpdateDNSSettings(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
|
@ -3,14 +3,15 @@ package http
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
@ -52,16 +53,15 @@ func initDNSSettingsTestData() *DNSSettings {
|
||||
return testingDNSSettingsAccount, testingDNSSettingsAccount.Users[testDNSSettingsUserID], nil
|
||||
},
|
||||
},
|
||||
authAudience: "",
|
||||
jwtExtractor: jwtclaims.ClaimsExtractor{
|
||||
ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims {
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||
return jwtclaims.AuthorizationClaims{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: testDNSSettingsAccountID,
|
||||
}
|
||||
},
|
||||
},
|
||||
}),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2,34 +2,36 @@ package http
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Events HTTP handler
|
||||
type Events struct {
|
||||
accountManager server.AccountManager
|
||||
authAudience string
|
||||
jwtExtractor jwtclaims.ClaimsExtractor
|
||||
accountManager server.AccountManager
|
||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||
}
|
||||
|
||||
// NewEvents creates a new Events HTTP handler
|
||||
func NewEvents(accountManager server.AccountManager, authAudience string) *Events {
|
||||
func NewEvents(accountManager server.AccountManager, authCfg AuthCfg) *Events {
|
||||
return &Events{
|
||||
accountManager: accountManager,
|
||||
authAudience: authAudience,
|
||||
jwtExtractor: *jwtclaims.NewClaimsExtractor(nil),
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(authCfg.Audience),
|
||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
// GetEvents list of the given account
|
||||
func (h *Events) GetEvents(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
|
@ -2,6 +2,13 @@ package http
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
@ -9,12 +16,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func initEventsTestData(account string, user *server.User, events ...*activity.Event) *Events {
|
||||
@ -36,16 +37,15 @@ func initEventsTestData(account string, user *server.User, events ...*activity.E
|
||||
}, user, nil
|
||||
},
|
||||
},
|
||||
authAudience: "",
|
||||
jwtExtractor: jwtclaims.ClaimsExtractor{
|
||||
ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims {
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||
return jwtclaims.AuthorizationClaims{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_account",
|
||||
}
|
||||
},
|
||||
},
|
||||
}),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@ -244,7 +244,6 @@ func TestEvents_GetEvents(t *testing.T) {
|
||||
assert.Equal(t, expected.Meta["some"], event.Meta["some"])
|
||||
assert.True(t, expected.Timestamp.Equal(event.Timestamp))
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -2,10 +2,11 @@ package http
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"net/http"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
@ -17,22 +18,23 @@ import (
|
||||
|
||||
// Groups is a handler that returns groups of the account
|
||||
type Groups struct {
|
||||
jwtExtractor jwtclaims.ClaimsExtractor
|
||||
accountManager server.AccountManager
|
||||
authAudience string
|
||||
accountManager server.AccountManager
|
||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||
}
|
||||
|
||||
func NewGroups(accountManager server.AccountManager, authAudience string) *Groups {
|
||||
func NewGroups(accountManager server.AccountManager, authCfg AuthCfg) *Groups {
|
||||
return &Groups{
|
||||
accountManager: accountManager,
|
||||
authAudience: authAudience,
|
||||
jwtExtractor: *jwtclaims.NewClaimsExtractor(nil),
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(authCfg.Audience),
|
||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
// GetAllGroupsHandler list for the account
|
||||
func (h *Groups) GetAllGroupsHandler(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
@ -50,7 +52,7 @@ func (h *Groups) GetAllGroupsHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// UpdateGroupHandler handles update to a group identified by a given ID
|
||||
func (h *Groups) UpdateGroupHandler(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
@ -119,7 +121,7 @@ func (h *Groups) UpdateGroupHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// PatchGroupHandler handles patch updates to a group identified by a given ID
|
||||
func (h *Groups) PatchGroupHandler(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
@ -223,7 +225,7 @@ func (h *Groups) PatchGroupHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// CreateGroupHandler handles group creation request
|
||||
func (h *Groups) CreateGroupHandler(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
@ -265,7 +267,7 @@ func (h *Groups) CreateGroupHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// DeleteGroupHandler handles group deletion request
|
||||
func (h *Groups) DeleteGroupHandler(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
@ -301,7 +303,7 @@ func (h *Groups) DeleteGroupHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// GetGroupHandler returns a group
|
||||
func (h *Groups) GetGroupHandler(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
|
@ -4,8 +4,6 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
@ -13,6 +11,9 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
|
||||
@ -78,20 +79,20 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *Groups {
|
||||
},
|
||||
Groups: map[string]*server.Group{
|
||||
"id-existed": {ID: "id-existed", Peers: []string{"A", "B"}},
|
||||
"id-all": {ID: "id-all", Name: "All"}},
|
||||
"id-all": {ID: "id-all", Name: "All"},
|
||||
},
|
||||
}, user, nil
|
||||
},
|
||||
},
|
||||
authAudience: "",
|
||||
jwtExtractor: jwtclaims.ClaimsExtractor{
|
||||
ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims {
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||
return jwtclaims.AuthorizationClaims{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
}
|
||||
},
|
||||
},
|
||||
}),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@ -270,7 +271,8 @@ func TestWriteGroup(t *testing.T) {
|
||||
PeersCount: 2,
|
||||
Peers: []api.PeerMinimum{
|
||||
{Id: "peer-A-ID"},
|
||||
{Id: "peer-B-ID"}},
|
||||
{Id: "peer-B-ID"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
@ -1,22 +1,29 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
s "github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/rs/cors"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// AuthCfg contains parameters for authentication middleware
|
||||
type AuthCfg struct {
|
||||
Issuer string
|
||||
Audience string
|
||||
UserIDClaim string
|
||||
KeysLocation string
|
||||
}
|
||||
|
||||
// APIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
||||
func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience string, authKeysLocation string,
|
||||
appMetrics telemetry.AppMetrics) (http.Handler, error) {
|
||||
func APIHandler(accountManager s.AccountManager, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) {
|
||||
jwtMiddleware, err := middleware.NewJwtMiddleware(
|
||||
authIssuer,
|
||||
authAudience,
|
||||
authKeysLocation,
|
||||
)
|
||||
authCfg.Issuer,
|
||||
authCfg.Audience,
|
||||
authCfg.KeysLocation)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -24,7 +31,8 @@ func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience
|
||||
corsMiddleware := cors.AllowAll()
|
||||
|
||||
acMiddleware := middleware.NewAccessControl(
|
||||
authAudience,
|
||||
authCfg.Audience,
|
||||
authCfg.UserIDClaim,
|
||||
accountManager.IsUserAdmin)
|
||||
|
||||
rootRouter := mux.NewRouter()
|
||||
@ -33,15 +41,15 @@ func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience
|
||||
apiHandler := rootRouter.PathPrefix("/api").Subrouter()
|
||||
apiHandler.Use(metricsMiddleware.Handler, corsMiddleware.Handler, jwtMiddleware.Handler, acMiddleware.Handler)
|
||||
|
||||
groupsHandler := NewGroups(accountManager, authAudience)
|
||||
rulesHandler := NewRules(accountManager, authAudience)
|
||||
peersHandler := NewPeers(accountManager, authAudience)
|
||||
keysHandler := NewSetupKeysHandler(accountManager, authAudience)
|
||||
userHandler := NewUserHandler(accountManager, authAudience)
|
||||
routesHandler := NewRoutes(accountManager, authAudience)
|
||||
nameserversHandler := NewNameservers(accountManager, authAudience)
|
||||
eventsHandler := NewEvents(accountManager, authAudience)
|
||||
dnsSettingsHandler := NewDNSSettings(accountManager, authAudience)
|
||||
groupsHandler := NewGroups(accountManager, authCfg)
|
||||
rulesHandler := NewRules(accountManager, authCfg)
|
||||
peersHandler := NewPeers(accountManager, authCfg)
|
||||
keysHandler := NewSetupKeysHandler(accountManager, authCfg)
|
||||
userHandler := NewUserHandler(accountManager, authCfg)
|
||||
routesHandler := NewRoutes(accountManager, authCfg)
|
||||
nameserversHandler := NewNameservers(accountManager, authCfg)
|
||||
eventsHandler := NewEvents(accountManager, authCfg)
|
||||
dnsSettingsHandler := NewDNSSettings(accountManager, authCfg)
|
||||
|
||||
apiHandler.HandleFunc("/peers", peersHandler.GetPeers).Methods("GET", "OPTIONS")
|
||||
apiHandler.HandleFunc("/peers/{id}", peersHandler.HandlePeer).
|
||||
@ -88,7 +96,7 @@ func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience
|
||||
apiHandler.HandleFunc("/dns/settings", dnsSettingsHandler.GetDNSSettings).Methods("GET", "OPTIONS")
|
||||
apiHandler.HandleFunc("/dns/settings", dnsSettingsHandler.UpdateDNSSettings).Methods("PUT", "OPTIONS")
|
||||
|
||||
err = apiHandler.Walk(func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error {
|
||||
err = apiHandler.Walk(func(route *mux.Route, _ *mux.Router, _ []*mux.Route) error {
|
||||
methods, err := route.GetMethods()
|
||||
if err != nil {
|
||||
return err
|
||||
@ -110,5 +118,4 @@ func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience
|
||||
}
|
||||
|
||||
return rootRouter, nil
|
||||
|
||||
}
|
||||
|
@ -1,9 +1,10 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"net/http"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
)
|
||||
@ -12,17 +13,18 @@ type IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error)
|
||||
|
||||
// AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only
|
||||
type AccessControl struct {
|
||||
jwtExtractor jwtclaims.ClaimsExtractor
|
||||
isUserAdmin IsUserAdminFunc
|
||||
audience string
|
||||
isUserAdmin IsUserAdminFunc
|
||||
claimsExtract jwtclaims.ClaimsExtractor
|
||||
}
|
||||
|
||||
// NewAccessControl instance constructor
|
||||
func NewAccessControl(audience string, isUserAdmin IsUserAdminFunc) *AccessControl {
|
||||
func NewAccessControl(audience, userIDClaim string, isUserAdmin IsUserAdminFunc) *AccessControl {
|
||||
return &AccessControl{
|
||||
isUserAdmin: isUserAdmin,
|
||||
audience: audience,
|
||||
jwtExtractor: *jwtclaims.NewClaimsExtractor(nil),
|
||||
isUserAdmin: isUserAdmin,
|
||||
claimsExtract: *jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(audience),
|
||||
jwtclaims.WithUserIDClaim(userIDClaim),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@ -30,9 +32,9 @@ func NewAccessControl(audience string, isUserAdmin IsUserAdminFunc) *AccessContr
|
||||
// It also adds
|
||||
func (a *AccessControl) Handler(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
jwtClaims := a.jwtExtractor.ExtractClaimsFromRequestContext(r, a.audience)
|
||||
claims := a.claimsExtract.FromRequestContext(r)
|
||||
|
||||
ok, err := a.isUserAdmin(jwtClaims)
|
||||
ok, err := a.isUserAdmin(claims)
|
||||
if err != nil {
|
||||
util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w)
|
||||
return
|
||||
@ -40,7 +42,6 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
|
||||
|
||||
if !ok {
|
||||
switch r.Method {
|
||||
|
||||
case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut:
|
||||
util.WriteError(status.Errorf(status.PermissionDenied, "only admin can perform this operation"), w)
|
||||
return
|
||||
|
@ -10,10 +10,11 @@ import (
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/golang-jwt/jwt"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"math/big"
|
||||
"net/http"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation
|
||||
@ -33,7 +34,6 @@ type JSONWebKey struct {
|
||||
|
||||
// NewJwtMiddleware creates new middleware to verify the JWT token sent via Authorization header
|
||||
func NewJwtMiddleware(issuer string, audience string, keysLocation string) (*JWTMiddleware, error) {
|
||||
|
||||
keys, err := getPemKeys(keysLocation)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -67,13 +67,12 @@ func NewJwtMiddleware(issuer string, audience string, keysLocation string) (*JWT
|
||||
|
||||
func getPemKeys(keysLocation string) (*Jwks, error) {
|
||||
resp, err := http.Get(keysLocation)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var jwks = &Jwks{}
|
||||
jwks := &Jwks{}
|
||||
err = json.NewDecoder(resp.Body).Decode(jwks)
|
||||
if err != nil {
|
||||
return jwks, err
|
||||
|
@ -3,6 +3,8 @@ package http
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
@ -11,28 +13,28 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Nameservers is the nameserver group handler of the account
|
||||
type Nameservers struct {
|
||||
jwtExtractor jwtclaims.ClaimsExtractor
|
||||
accountManager server.AccountManager
|
||||
authAudience string
|
||||
accountManager server.AccountManager
|
||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||
}
|
||||
|
||||
// NewNameservers returns a new instance of Nameservers handler
|
||||
func NewNameservers(accountManager server.AccountManager, authAudience string) *Nameservers {
|
||||
func NewNameservers(accountManager server.AccountManager, authCfg AuthCfg) *Nameservers {
|
||||
return &Nameservers{
|
||||
accountManager: accountManager,
|
||||
authAudience: authAudience,
|
||||
jwtExtractor: *jwtclaims.NewClaimsExtractor(nil),
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(authCfg.Audience),
|
||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
// GetAllNameserversHandler returns the list of nameserver groups for the account
|
||||
func (h *Nameservers) GetAllNameserversHandler(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
@ -56,7 +58,7 @@ func (h *Nameservers) GetAllNameserversHandler(w http.ResponseWriter, r *http.Re
|
||||
|
||||
// CreateNameserverGroupHandler handles nameserver group creation request
|
||||
func (h *Nameservers) CreateNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
@ -89,7 +91,7 @@ func (h *Nameservers) CreateNameserverGroupHandler(w http.ResponseWriter, r *htt
|
||||
|
||||
// UpdateNameserverGroupHandler handles update to a nameserver group identified by a given ID
|
||||
func (h *Nameservers) UpdateNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
@ -139,7 +141,7 @@ func (h *Nameservers) UpdateNameserverGroupHandler(w http.ResponseWriter, r *htt
|
||||
|
||||
// PatchNameserverGroupHandler handles patch updates to a nameserver group identified by a given ID
|
||||
func (h *Nameservers) PatchNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
@ -221,7 +223,7 @@ func (h *Nameservers) PatchNameserverGroupHandler(w http.ResponseWriter, r *http
|
||||
|
||||
// DeleteNameserverGroupHandler handles nameserver group deletion request
|
||||
func (h *Nameservers) DeleteNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
@ -245,7 +247,7 @@ func (h *Nameservers) DeleteNameserverGroupHandler(w http.ResponseWriter, r *htt
|
||||
|
||||
// GetNameserverGroupHandler handles a nameserver group Get request identified by ID
|
||||
func (h *Nameservers) GetNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
@ -268,7 +270,6 @@ func (h *Nameservers) GetNameserverGroupHandler(w http.ResponseWriter, r *http.R
|
||||
resp := toNameserverGroupResponse(nsGroup)
|
||||
|
||||
util.WriteJSONObject(w, &resp)
|
||||
|
||||
}
|
||||
|
||||
func toServerNSList(apiNSList []api.Nameserver) ([]nbdns.NameServer, error) {
|
||||
|
@ -3,16 +3,17 @@ package http
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
@ -113,16 +114,15 @@ func initNameserversTestData() *Nameservers {
|
||||
return testingNSAccount, testingAccount.Users["test_user"], nil
|
||||
},
|
||||
},
|
||||
authAudience: "",
|
||||
jwtExtractor: jwtclaims.ClaimsExtractor{
|
||||
ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims {
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||
return jwtclaims.AuthorizationClaims{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: testNSGroupAccountID,
|
||||
}
|
||||
},
|
||||
},
|
||||
}),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -3,27 +3,29 @@ package http
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Peers is a handler that returns peers of the account
|
||||
type Peers struct {
|
||||
accountManager server.AccountManager
|
||||
authAudience string
|
||||
jwtExtractor jwtclaims.ClaimsExtractor
|
||||
accountManager server.AccountManager
|
||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||
}
|
||||
|
||||
func NewPeers(accountManager server.AccountManager, authAudience string) *Peers {
|
||||
func NewPeers(accountManager server.AccountManager, authCfg AuthCfg) *Peers {
|
||||
return &Peers{
|
||||
accountManager: accountManager,
|
||||
authAudience: authAudience,
|
||||
jwtExtractor: *jwtclaims.NewClaimsExtractor(nil),
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(authCfg.Audience),
|
||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@ -55,7 +57,7 @@ func (h *Peers) deletePeer(accountID, userID string, peerID string, w http.Respo
|
||||
}
|
||||
|
||||
func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
@ -78,13 +80,12 @@ func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
||||
default:
|
||||
util.WriteError(status.Errorf(status.NotFound, "unknown METHOD"), w)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (h *Peers) GetPeers(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
|
@ -2,13 +2,14 @@ package http
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
|
||||
"github.com/magiconair/properties/assert"
|
||||
@ -36,16 +37,15 @@ func initTestMetaData(peers ...*server.Peer) *Peers {
|
||||
}, user, nil
|
||||
},
|
||||
},
|
||||
authAudience: "",
|
||||
jwtExtractor: jwtclaims.ClaimsExtractor{
|
||||
ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims {
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||
return jwtclaims.AuthorizationClaims{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
}
|
||||
},
|
||||
},
|
||||
}),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2,6 +2,9 @@ package http
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
@ -9,29 +12,28 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"net/http"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// Routes is the routes handler of the account
|
||||
type Routes struct {
|
||||
jwtExtractor jwtclaims.ClaimsExtractor
|
||||
accountManager server.AccountManager
|
||||
authAudience string
|
||||
accountManager server.AccountManager
|
||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||
}
|
||||
|
||||
// NewRoutes returns a new instance of Routes handler
|
||||
func NewRoutes(accountManager server.AccountManager, authAudience string) *Routes {
|
||||
func NewRoutes(accountManager server.AccountManager, authCfg AuthCfg) *Routes {
|
||||
return &Routes{
|
||||
accountManager: accountManager,
|
||||
authAudience: authAudience,
|
||||
jwtExtractor: *jwtclaims.NewClaimsExtractor(nil),
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(authCfg.Audience),
|
||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
// GetAllRoutesHandler returns the list of routes for the account
|
||||
func (h *Routes) GetAllRoutesHandler(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
@ -53,7 +55,7 @@ func (h *Routes) GetAllRoutesHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// CreateRouteHandler handles route creation request
|
||||
func (h *Routes) CreateRouteHandler(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
@ -92,7 +94,7 @@ func (h *Routes) CreateRouteHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// UpdateRouteHandler handles update to a route identified by a given ID
|
||||
func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
@ -158,7 +160,7 @@ func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// PatchRouteHandler handles patch updates to a route identified by a given ID
|
||||
func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
@ -299,7 +301,7 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// DeleteRouteHandler handles route deletion request
|
||||
func (h *Routes) DeleteRouteHandler(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
@ -323,7 +325,7 @@ func (h *Routes) DeleteRouteHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// GetRouteHandler handles a route Get request identified by ID
|
||||
func (h *Routes) GetRouteHandler(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
|
@ -4,9 +4,6 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@ -14,6 +11,10 @@ import (
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/magiconair/properties/assert"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
@ -142,16 +143,15 @@ func initRoutesTestData() *Routes {
|
||||
return testingAccount, testingAccount.Users["test_user"], nil
|
||||
},
|
||||
},
|
||||
authAudience: "",
|
||||
jwtExtractor: jwtclaims.ClaimsExtractor{
|
||||
ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims {
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||
return jwtclaims.AuthorizationClaims{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: testAccountID,
|
||||
}
|
||||
},
|
||||
},
|
||||
}),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2,6 +2,8 @@ package http
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
@ -9,27 +11,27 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/rs/xid"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Rules is a handler that returns rules of the account
|
||||
type Rules struct {
|
||||
jwtExtractor jwtclaims.ClaimsExtractor
|
||||
accountManager server.AccountManager
|
||||
authAudience string
|
||||
accountManager server.AccountManager
|
||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||
}
|
||||
|
||||
func NewRules(accountManager server.AccountManager, authAudience string) *Rules {
|
||||
func NewRules(accountManager server.AccountManager, authCfg AuthCfg) *Rules {
|
||||
return &Rules{
|
||||
accountManager: accountManager,
|
||||
authAudience: authAudience,
|
||||
jwtExtractor: *jwtclaims.NewClaimsExtractor(nil),
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(authCfg.Audience),
|
||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
// GetAllRulesHandler list for the account
|
||||
func (h *Rules) GetAllRulesHandler(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
@ -51,7 +53,7 @@ func (h *Rules) GetAllRulesHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// UpdateRuleHandler handles update to a rule identified by a given ID
|
||||
func (h *Rules) UpdateRuleHandler(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
@ -122,7 +124,7 @@ func (h *Rules) UpdateRuleHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// PatchRuleHandler handles patch updates to a rule identified by a given ID
|
||||
func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
@ -253,7 +255,6 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
rule, err := h.accountManager.UpdateRule(account.Id, ruleID, operations)
|
||||
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
@ -266,7 +267,7 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// CreateRuleHandler handles rule creation request
|
||||
func (h *Rules) CreateRuleHandler(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
@ -325,7 +326,7 @@ func (h *Rules) CreateRuleHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// DeleteRuleHandler handles rule deletion request
|
||||
func (h *Rules) DeleteRuleHandler(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
@ -350,7 +351,7 @@ func (h *Rules) DeleteRuleHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// GetRuleHandler handles a group Get request identified by ID
|
||||
func (h *Rules) GetRuleHandler(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
|
@ -4,13 +4,14 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
|
||||
@ -71,7 +72,7 @@ func initRulesTestData(rules ...*server.Rule) *Rules {
|
||||
return &server.Account{
|
||||
Id: claims.AccountId,
|
||||
Domain: "hotmail.com",
|
||||
Rules: map[string]*server.Rule{"id-existed": &server.Rule{ID: "id-existed"}},
|
||||
Rules: map[string]*server.Rule{"id-existed": {ID: "id-existed"}},
|
||||
Groups: map[string]*server.Group{
|
||||
"F": {ID: "F"},
|
||||
"G": {ID: "G"},
|
||||
@ -82,16 +83,15 @@ func initRulesTestData(rules ...*server.Rule) *Rules {
|
||||
}, user, nil
|
||||
},
|
||||
},
|
||||
authAudience: "",
|
||||
jwtExtractor: jwtclaims.ClaimsExtractor{
|
||||
ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims {
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||
return jwtclaims.AuthorizationClaims{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
}
|
||||
},
|
||||
},
|
||||
}),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@ -264,7 +264,8 @@ func TestRulesWriteRule(t *testing.T) {
|
||||
Flow: server.TrafficFlowBidirectString,
|
||||
Sources: []api.GroupMinimum{
|
||||
{Id: "G"},
|
||||
{Id: "F"}},
|
||||
{Id: "F"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
@ -306,7 +307,6 @@ func TestRulesWriteRule(t *testing.T) {
|
||||
}
|
||||
|
||||
assert.Equal(t, got, tc.expectedRule)
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -2,34 +2,36 @@ package http
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SetupKeys is a handler that returns a list of setup keys of the account
|
||||
type SetupKeys struct {
|
||||
accountManager server.AccountManager
|
||||
jwtExtractor jwtclaims.ClaimsExtractor
|
||||
authAudience string
|
||||
accountManager server.AccountManager
|
||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||
}
|
||||
|
||||
func NewSetupKeysHandler(accountManager server.AccountManager, authAudience string) *SetupKeys {
|
||||
func NewSetupKeysHandler(accountManager server.AccountManager, authCfg AuthCfg) *SetupKeys {
|
||||
return &SetupKeys{
|
||||
accountManager: accountManager,
|
||||
authAudience: authAudience,
|
||||
jwtExtractor: *jwtclaims.NewClaimsExtractor(nil),
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(authCfg.Audience),
|
||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
// CreateSetupKeyHandler is a POST requests that creates a new SetupKey
|
||||
func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
@ -72,7 +74,7 @@ func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request
|
||||
|
||||
// GetSetupKeyHandler is a GET request to get a SetupKey by ID
|
||||
func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
@ -97,7 +99,7 @@ func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// UpdateSetupKeyHandler is a PUT request to update server.SetupKey
|
||||
func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
@ -144,8 +146,7 @@ func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request
|
||||
|
||||
// GetAllSetupKeysHandler is a GET request that returns a list of SetupKey
|
||||
func (h *SetupKeys) GetAllSetupKeysHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
|
@ -4,16 +4,17 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
@ -28,7 +29,8 @@ const (
|
||||
)
|
||||
|
||||
func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.SetupKey, updatedSetupKey *server.SetupKey,
|
||||
user *server.User) *SetupKeys {
|
||||
user *server.User,
|
||||
) *SetupKeys {
|
||||
return &SetupKeys{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
@ -43,11 +45,13 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
|
||||
},
|
||||
Groups: map[string]*server.Group{
|
||||
"group-1": {ID: "group-1", Peers: []string{"A", "B"}},
|
||||
"id-all": {ID: "id-all", Name: "All"}},
|
||||
"id-all": {ID: "id-all", Name: "All"},
|
||||
},
|
||||
}, user, nil
|
||||
},
|
||||
CreateSetupKeyFunc: func(_ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string,
|
||||
_ int, _ string) (*server.SetupKey, error) {
|
||||
_ int, _ string,
|
||||
) (*server.SetupKey, error) {
|
||||
if keyName == newKey.Name || typ != newKey.Type {
|
||||
return newKey, nil
|
||||
}
|
||||
@ -75,16 +79,15 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
|
||||
return []*server.SetupKey{defaultKey}, nil
|
||||
},
|
||||
},
|
||||
authAudience: "",
|
||||
jwtExtractor: jwtclaims.ClaimsExtractor{
|
||||
ExtractClaimsFromRequestContext: func(r *http.Request, authAudience string) jwtclaims.AuthorizationClaims {
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||
return jwtclaims.AuthorizationClaims{
|
||||
UserId: user.Id,
|
||||
Domain: "hotmail.com",
|
||||
AccountId: testAccountID,
|
||||
}
|
||||
},
|
||||
},
|
||||
}),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@ -209,7 +212,6 @@ func TestSetupKeysHandlers(t *testing.T) {
|
||||
assertKeys(t, got[0], tc.expectedSetupKeys[0])
|
||||
return
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -2,27 +2,29 @@ package http
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"net/http"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
)
|
||||
|
||||
type UserHandler struct {
|
||||
accountManager server.AccountManager
|
||||
authAudience string
|
||||
jwtExtractor jwtclaims.ClaimsExtractor
|
||||
accountManager server.AccountManager
|
||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||
}
|
||||
|
||||
func NewUserHandler(accountManager server.AccountManager, authAudience string) *UserHandler {
|
||||
func NewUserHandler(accountManager server.AccountManager, authCfg AuthCfg) *UserHandler {
|
||||
return &UserHandler{
|
||||
accountManager: accountManager,
|
||||
authAudience: authAudience,
|
||||
jwtExtractor: *jwtclaims.NewClaimsExtractor(nil),
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(authCfg.Audience),
|
||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@ -33,7 +35,7 @@ func (h *UserHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
@ -69,8 +71,7 @@ func (h *UserHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(w, toUserResponse(newUser))
|
||||
|
||||
util.WriteJSONObject(w, toUserResponse(newUser, claims.UserId))
|
||||
}
|
||||
|
||||
// CreateUserHandler creates a User in the system with a status "invited" (effectively this is a user invite).
|
||||
@ -80,7 +81,7 @@ func (h *UserHandler) CreateUserHandler(w http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
@ -109,7 +110,7 @@ func (h *UserHandler) CreateUserHandler(w http.ResponseWriter, r *http.Request)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(w, toUserResponse(newUser))
|
||||
util.WriteJSONObject(w, toUserResponse(newUser, claims.UserId))
|
||||
}
|
||||
|
||||
// GetUsers returns a list of users of the account this user belongs to.
|
||||
@ -120,7 +121,7 @@ func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
@ -135,14 +136,13 @@ func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
users := make([]*api.User, 0)
|
||||
for _, r := range data {
|
||||
users = append(users, toUserResponse(r))
|
||||
users = append(users, toUserResponse(r, claims.UserId))
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, users)
|
||||
}
|
||||
|
||||
func toUserResponse(user *server.UserInfo) *api.User {
|
||||
|
||||
func toUserResponse(user *server.UserInfo, currenUserID string) *api.User {
|
||||
autoGroups := user.AutoGroups
|
||||
if autoGroups == nil {
|
||||
autoGroups = []string{}
|
||||
@ -158,6 +158,7 @@ func toUserResponse(user *server.UserInfo) *api.User {
|
||||
userStatus = api.UserStatusDisabled
|
||||
}
|
||||
|
||||
isCurrent := user.ID == currenUserID
|
||||
return &api.User{
|
||||
Id: user.ID,
|
||||
Name: user.Name,
|
||||
@ -165,5 +166,6 @@ func toUserResponse(user *server.UserInfo) *api.User {
|
||||
Role: user.Role,
|
||||
AutoGroups: autoGroups,
|
||||
Status: userStatus,
|
||||
IsCurrent: &isCurrent,
|
||||
}
|
||||
}
|
||||
|
@ -40,16 +40,15 @@ func initUsers(user ...*server.User) *UserHandler {
|
||||
return users, nil
|
||||
},
|
||||
},
|
||||
authAudience: "",
|
||||
jwtExtractor: jwtclaims.ClaimsExtractor{
|
||||
ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims {
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||
return jwtclaims.AuthorizationClaims{
|
||||
UserId: "1",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
}
|
||||
},
|
||||
},
|
||||
}),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@ -57,7 +56,7 @@ func TestGetUsers(t *testing.T) {
|
||||
users := []*server.User{{Id: "1", Role: "admin"}, {Id: "2", Role: "user"}, {Id: "3", Role: "user"}}
|
||||
userHandler := initUsers(users...)
|
||||
|
||||
var tt = []struct {
|
||||
tt := []struct {
|
||||
name string
|
||||
expectedStatus int
|
||||
requestType string
|
||||
|
@ -1,8 +1,9 @@
|
||||
package jwtclaims
|
||||
|
||||
import (
|
||||
"github.com/golang-jwt/jwt"
|
||||
"net/http"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -14,51 +15,85 @@ const (
|
||||
)
|
||||
|
||||
// Extract function type
|
||||
type ExtractClaims func(r *http.Request, authAudiance string) AuthorizationClaims
|
||||
type ExtractClaims func(r *http.Request) AuthorizationClaims
|
||||
|
||||
// ClaimsExtractor struct that holds the extract function
|
||||
type ClaimsExtractor struct {
|
||||
ExtractClaimsFromRequestContext ExtractClaims
|
||||
authAudience string
|
||||
userIDClaim string
|
||||
|
||||
FromRequestContext ExtractClaims
|
||||
}
|
||||
|
||||
// ClaimsExtractorOption is a function that configures the ClaimsExtractor
|
||||
type ClaimsExtractorOption func(*ClaimsExtractor)
|
||||
|
||||
// WithAudience sets the audience for the extractor
|
||||
func WithAudience(audience string) ClaimsExtractorOption {
|
||||
return func(c *ClaimsExtractor) {
|
||||
c.authAudience = audience
|
||||
}
|
||||
}
|
||||
|
||||
// WithUserIDClaim sets the user id claim for the extractor
|
||||
func WithUserIDClaim(userIDClaim string) ClaimsExtractorOption {
|
||||
return func(c *ClaimsExtractor) {
|
||||
c.userIDClaim = userIDClaim
|
||||
}
|
||||
}
|
||||
|
||||
// WithFromRequestContext sets the function that extracts claims from the request context
|
||||
func WithFromRequestContext(ec ExtractClaims) ClaimsExtractorOption {
|
||||
return func(c *ClaimsExtractor) {
|
||||
c.FromRequestContext = ec
|
||||
}
|
||||
}
|
||||
|
||||
// NewClaimsExtractor returns an extractor, and if provided with a function with ExtractClaims signature,
|
||||
// then it will use that logic. Uses ExtractClaimsFromRequestContext by default
|
||||
func NewClaimsExtractor(e ExtractClaims) *ClaimsExtractor {
|
||||
var extractFunc ExtractClaims
|
||||
if extractFunc = e; extractFunc == nil {
|
||||
extractFunc = ExtractClaimsFromRequestContext
|
||||
func NewClaimsExtractor(options ...ClaimsExtractorOption) *ClaimsExtractor {
|
||||
ce := &ClaimsExtractor{}
|
||||
for _, option := range options {
|
||||
option(ce)
|
||||
}
|
||||
|
||||
return &ClaimsExtractor{
|
||||
ExtractClaimsFromRequestContext: extractFunc,
|
||||
if ce.FromRequestContext == nil {
|
||||
ce.FromRequestContext = ce.fromRequestContext
|
||||
}
|
||||
if ce.userIDClaim == "" {
|
||||
ce.userIDClaim = UserIDClaim
|
||||
}
|
||||
return ce
|
||||
}
|
||||
|
||||
// ExtractClaimsFromRequestContext extracts claims from the request context previously filled by the JWT token (after auth)
|
||||
func ExtractClaimsFromRequestContext(r *http.Request, authAudience string) AuthorizationClaims {
|
||||
if r.Context().Value(TokenUserProperty) == nil {
|
||||
return AuthorizationClaims{}
|
||||
}
|
||||
token := r.Context().Value(TokenUserProperty).(*jwt.Token)
|
||||
return ExtractClaimsWithToken(token, authAudience)
|
||||
}
|
||||
|
||||
// ExtractClaimsWithToken extracts claims from the token (after auth)
|
||||
func ExtractClaimsWithToken(token *jwt.Token, authAudience string) AuthorizationClaims {
|
||||
// FromToken extracts claims from the token (after auth)
|
||||
func (c *ClaimsExtractor) FromToken(token *jwt.Token) AuthorizationClaims {
|
||||
claims := token.Claims.(jwt.MapClaims)
|
||||
jwtClaims := AuthorizationClaims{}
|
||||
jwtClaims.UserId = claims[UserIDClaim].(string)
|
||||
accountIdClaim, ok := claims[authAudience+AccountIDSuffix]
|
||||
if ok {
|
||||
jwtClaims.AccountId = accountIdClaim.(string)
|
||||
userID, ok := claims[c.userIDClaim].(string)
|
||||
if !ok {
|
||||
return jwtClaims
|
||||
}
|
||||
domainClaim, ok := claims[authAudience+DomainIDSuffix]
|
||||
jwtClaims.UserId = userID
|
||||
accountIDClaim, ok := claims[c.authAudience+AccountIDSuffix]
|
||||
if ok {
|
||||
jwtClaims.AccountId = accountIDClaim.(string)
|
||||
}
|
||||
domainClaim, ok := claims[c.authAudience+DomainIDSuffix]
|
||||
if ok {
|
||||
jwtClaims.Domain = domainClaim.(string)
|
||||
}
|
||||
domainCategoryClaim, ok := claims[authAudience+DomainCategorySuffix]
|
||||
domainCategoryClaim, ok := claims[c.authAudience+DomainCategorySuffix]
|
||||
if ok {
|
||||
jwtClaims.DomainCategory = domainCategoryClaim.(string)
|
||||
}
|
||||
return jwtClaims
|
||||
}
|
||||
|
||||
// fromRequestContext extracts claims from the request context previously filled by the JWT token (after auth)
|
||||
func (c *ClaimsExtractor) fromRequestContext(r *http.Request) AuthorizationClaims {
|
||||
if r.Context().Value(TokenUserProperty) == nil {
|
||||
return AuthorizationClaims{}
|
||||
}
|
||||
token := r.Context().Value(TokenUserProperty).(*jwt.Token)
|
||||
return c.FromToken(token)
|
||||
}
|
||||
|
@ -2,10 +2,11 @@ package jwtclaims
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/stretchr/testify/require"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audiance string) *http.Request {
|
||||
@ -31,7 +32,6 @@ func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audiance st
|
||||
}
|
||||
|
||||
func TestExtractClaimsFromRequestContext(t *testing.T) {
|
||||
|
||||
type test struct {
|
||||
name string
|
||||
inputAuthorizationClaims AuthorizationClaims
|
||||
@ -99,12 +99,84 @@ func TestExtractClaimsFromRequestContext(t *testing.T) {
|
||||
|
||||
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
|
||||
request := newTestRequestWithJWT(t, testCase.inputAuthorizationClaims, testCase.inputAudiance)
|
||||
|
||||
extractedClaims := ExtractClaimsFromRequestContext(request, testCase.inputAudiance)
|
||||
extractor := NewClaimsExtractor(WithAudience(testCase.inputAudiance))
|
||||
extractedClaims := extractor.FromRequestContext(request)
|
||||
|
||||
testCase.testingFunc(t, testCase.inputAuthorizationClaims, extractedClaims, testCase.expectedMSG)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractClaimsSetOptions(t *testing.T) {
|
||||
type test struct {
|
||||
name string
|
||||
extractor *ClaimsExtractor
|
||||
check func(t *testing.T, c test)
|
||||
}
|
||||
|
||||
testCase1 := test{
|
||||
name: "No custom options",
|
||||
extractor: NewClaimsExtractor(),
|
||||
check: func(t *testing.T, c test) {
|
||||
if c.extractor.authAudience != "" {
|
||||
t.Error("audience should be empty")
|
||||
return
|
||||
}
|
||||
if c.extractor.userIDClaim != UserIDClaim {
|
||||
t.Errorf("user id claim should be default, expected %s, got %s", UserIDClaim, c.extractor.userIDClaim)
|
||||
return
|
||||
}
|
||||
if c.extractor.FromRequestContext == nil {
|
||||
t.Error("from request context should not be nil")
|
||||
return
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
testCase2 := test{
|
||||
name: "Custom audience",
|
||||
extractor: NewClaimsExtractor(WithAudience("https://login/")),
|
||||
check: func(t *testing.T, c test) {
|
||||
if c.extractor.authAudience != "https://login/" {
|
||||
t.Errorf("audience expected %s, got %s", "https://login/", c.extractor.authAudience)
|
||||
return
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
testCase3 := test{
|
||||
name: "Custom user id claim",
|
||||
extractor: NewClaimsExtractor(WithUserIDClaim("customUserId")),
|
||||
check: func(t *testing.T, c test) {
|
||||
if c.extractor.userIDClaim != "customUserId" {
|
||||
t.Errorf("user id claim expected %s, got %s", "customUserId", c.extractor.userIDClaim)
|
||||
return
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
testCase4 := test{
|
||||
name: "Custom extractor from request context",
|
||||
extractor: NewClaimsExtractor(
|
||||
WithFromRequestContext(func(r *http.Request) AuthorizationClaims {
|
||||
return AuthorizationClaims{
|
||||
UserId: "testCustomRequest",
|
||||
}
|
||||
})),
|
||||
check: func(t *testing.T, c test) {
|
||||
claims := c.extractor.FromRequestContext(&http.Request{})
|
||||
if claims.UserId != "testCustomRequest" {
|
||||
t.Errorf("user id claim expected %s, got %s", "testCustomRequest", claims.UserId)
|
||||
return
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
testCase.check(t, testCase)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user