Feature: add custom id claim (#667)

This feature allows using the custom claim in the JWT token as a user ID.

Refactor claims extractor with options support

Add is_current to the user API response
This commit is contained in:
Givi Khojanashvili 2023-02-04 00:47:20 +04:00 committed by GitHub
parent 494e56d1be
commit 3ec8274b8e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
32 changed files with 474 additions and 305 deletions

View File

@ -60,7 +60,6 @@ services:
network_mode: host network_mode: host
command: command:
- -c /etc/turnserver.conf - -c /etc/turnserver.conf
volumes: volumes:
$MGMT_VOLUMENAME: $MGMT_VOLUMENAME:
$SIGNAL_VOLUMENAME: $SIGNAL_VOLUMENAME:

View File

@ -32,6 +32,7 @@
"AuthIssuer": "$NETBIRD_AUTH_AUTHORITY", "AuthIssuer": "$NETBIRD_AUTH_AUTHORITY",
"AuthAudience": "$NETBIRD_AUTH_AUDIENCE", "AuthAudience": "$NETBIRD_AUTH_AUDIENCE",
"AuthKeysLocation": "$NETBIRD_AUTH_JWT_CERTS", "AuthKeysLocation": "$NETBIRD_AUTH_JWT_CERTS",
"AuthUserIDClaim": "$NETBIRD_AUTH_USER_ID_CLAIM",
"CertFile":"$NETBIRD_MGMT_API_CERT_FILE", "CertFile":"$NETBIRD_MGMT_API_CERT_FILE",
"CertKey":"$NETBIRD_MGMT_API_CERT_KEY_FILE", "CertKey":"$NETBIRD_MGMT_API_CERT_KEY_FILE",
"OIDCConfigEndpoint":"$NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT" "OIDCConfigEndpoint":"$NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT"

View File

@ -7,6 +7,8 @@ NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT=""
NETBIRD_AUTH_AUDIENCE="" NETBIRD_AUTH_AUDIENCE=""
# e.g. netbird-client # e.g. netbird-client
NETBIRD_AUTH_CLIENT_ID="" NETBIRD_AUTH_CLIENT_ID=""
# if you want to use a custom claim for the user ID instead of 'sub', set it here
# NETBIRD_AUTH_USER_ID_CLAIM=""
# indicates whether to use Auth0 or not: true or false # indicates whether to use Auth0 or not: true or false
NETBIRD_USE_AUTH0="false" NETBIRD_USE_AUTH0="false"
NETBIRD_AUTH_DEVICE_AUTH_PROVIDER="none" NETBIRD_AUTH_DEVICE_AUTH_PROVIDER="none"

View File

@ -7,15 +7,6 @@ import (
"errors" "errors"
"flag" "flag"
"fmt" "fmt"
"github.com/google/uuid"
"github.com/miekg/dns"
"github.com/netbirdio/netbird/management/server/activity/sqlite"
httpapi "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/metrics"
"github.com/netbirdio/netbird/management/server/telemetry"
"golang.org/x/crypto/acme/autocert"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
"io" "io"
"io/fs" "io/fs"
"net" "net"
@ -26,6 +17,16 @@ import (
"strings" "strings"
"time" "time"
"github.com/google/uuid"
"github.com/miekg/dns"
"github.com/netbirdio/netbird/management/server/activity/sqlite"
httpapi "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/metrics"
"github.com/netbirdio/netbird/management/server/telemetry"
"golang.org/x/crypto/acme/autocert"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
@ -178,8 +179,13 @@ var (
tlsEnabled = true tlsEnabled = true
} }
httpAPIHandler, err := httpapi.APIHandler(accountManager, config.HttpConfig.AuthIssuer, httpAPIAuthCfg := httpapi.AuthCfg{
config.HttpConfig.AuthAudience, config.HttpConfig.AuthKeysLocation, appMetrics) Issuer: config.HttpConfig.AuthIssuer,
Audience: config.HttpConfig.AuthAudience,
UserIDClaim: config.HttpConfig.AuthUserIDClaim,
KeysLocation: config.HttpConfig.AuthKeysLocation,
}
httpAPIHandler, err := httpapi.APIHandler(accountManager, appMetrics, httpAPIAuthCfg)
if err != nil { if err != nil {
return fmt.Errorf("failed creating HTTP API handler: %v", err) return fmt.Errorf("failed creating HTTP API handler: %v", err)
} }
@ -415,7 +421,6 @@ type OIDCConfigResponse struct {
// fetchOIDCConfig fetches OIDC configuration from the IDP // fetchOIDCConfig fetches OIDC configuration from the IDP
func fetchOIDCConfig(oidcEndpoint string) (OIDCConfigResponse, error) { func fetchOIDCConfig(oidcEndpoint string) (OIDCConfigResponse, error) {
res, err := http.Get(oidcEndpoint) res, err := http.Get(oidcEndpoint)
if err != nil { if err != nil {
return OIDCConfigResponse{}, fmt.Errorf("failed fetching OIDC configuration fro mendpoint %s %v", oidcEndpoint, err) return OIDCConfigResponse{}, fmt.Errorf("failed fetching OIDC configuration fro mendpoint %s %v", oidcEndpoint, err)
@ -445,7 +450,6 @@ func fetchOIDCConfig(oidcEndpoint string) (OIDCConfigResponse, error) {
} }
return config, nil return config, nil
} }
func loadTLSConfig(certFile string, certKey string) (*tls.Config, error) { func loadTLSConfig(certFile string, certKey string) (*tls.Config, error) {

View File

@ -3,6 +3,15 @@ package server
import ( import (
"context" "context"
"fmt" "fmt"
"math/rand"
"net"
"net/netip"
"reflect"
"regexp"
"strings"
"sync"
"time"
"github.com/eko/gocache/v3/cache" "github.com/eko/gocache/v3/cache"
cacheStore "github.com/eko/gocache/v3/store" cacheStore "github.com/eko/gocache/v3/store"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
@ -14,14 +23,6 @@ import (
gocache "github.com/patrickmn/go-cache" gocache "github.com/patrickmn/go-cache"
"github.com/rs/xid" "github.com/rs/xid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"math/rand"
"net"
"net/netip"
"reflect"
"regexp"
"strings"
"sync"
"time"
) )
const ( const (
@ -219,7 +220,6 @@ func (a *Account) getEnabledAndDisabledRoutesByPeer(peerID string) ([]*route.Rou
// GetRoutesByPrefix return list of routes by account and route prefix // GetRoutesByPrefix return list of routes by account and route prefix
func (a *Account) GetRoutesByPrefix(prefix netip.Prefix) []*route.Route { func (a *Account) GetRoutesByPrefix(prefix netip.Prefix) []*route.Route {
var routes []*route.Route var routes []*route.Route
for _, r := range a.Routes { for _, r := range a.Routes {
if r.Network.String() == prefix.String() { if r.Network.String() == prefix.String() {
@ -243,7 +243,6 @@ func (a *Account) GetPeerByIP(peerIP string) *Peer {
// GetPeerRules returns a list of source or destination rules of a given peer. // GetPeerRules returns a list of source or destination rules of a given peer.
func (a *Account) GetPeerRules(peerID string) (srcRules []*Rule, dstRules []*Rule) { func (a *Account) GetPeerRules(peerID string) (srcRules []*Rule, dstRules []*Rule) {
// Rules are group based so there is no direct access to peers. // Rules are group based so there is no direct access to peers.
// First, find all groups that the given peer belongs to // First, find all groups that the given peer belongs to
peerGroups := make(map[string]struct{}) peerGroups := make(map[string]struct{})
@ -490,7 +489,8 @@ func (a *Account) GetPeer(peerID string) *Peer {
// BuildManager creates a new DefaultAccountManager with a provided Store // BuildManager creates a new DefaultAccountManager with a provided Store
func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager, func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager,
singleAccountModeDomain string, dnsDomain string, eventStore activity.Store) (*DefaultAccountManager, error) { singleAccountModeDomain string, dnsDomain string, eventStore activity.Store,
) (*DefaultAccountManager, error) {
am := &DefaultAccountManager{ am := &DefaultAccountManager{
Store: store, Store: store,
peersUpdateManager: peersUpdateManager, peersUpdateManager: peersUpdateManager,
@ -551,7 +551,6 @@ func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManage
} }
return am, nil return am, nil
} }
// newAccount creates a new Account with a generated ID and generated default setup keys. // newAccount creates a new Account with a generated ID and generated default setup keys.
@ -669,7 +668,6 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(email string, accountI
} }
return nil, nil return nil, nil
} }
// lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil // lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil
@ -768,7 +766,8 @@ func (am *DefaultAccountManager) lookupCache(accountUsers map[string]struct{}, a
// updateAccountDomainAttributes updates the account domain attributes and then, saves the account // updateAccountDomainAttributes updates the account domain attributes and then, saves the account
func (am *DefaultAccountManager) updateAccountDomainAttributes(account *Account, claims jwtclaims.AuthorizationClaims, func (am *DefaultAccountManager) updateAccountDomainAttributes(account *Account, claims jwtclaims.AuthorizationClaims,
primaryDomain bool) error { primaryDomain bool,
) error {
account.IsDomainPrimaryAccount = primaryDomain account.IsDomainPrimaryAccount = primaryDomain
lowerDomain := strings.ToLower(claims.Domain) lowerDomain := strings.ToLower(claims.Domain)
@ -826,6 +825,9 @@ func (am *DefaultAccountManager) handleExistingUserAccount(
// handleNewUserAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account, // handleNewUserAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account,
// otherwise it will create a new account and make it primary account for the domain. // otherwise it will create a new account and make it primary account for the domain.
func (am *DefaultAccountManager) handleNewUserAccount(domainAcc *Account, claims jwtclaims.AuthorizationClaims) (*Account, error) { func (am *DefaultAccountManager) handleNewUserAccount(domainAcc *Account, claims jwtclaims.AuthorizationClaims) (*Account, error) {
if claims.UserId == "" {
return nil, fmt.Errorf("user ID is empty")
}
var ( var (
account *Account account *Account
err error err error
@ -897,7 +899,9 @@ func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) e
// GetAccountFromToken returns an account associated with this token // GetAccountFromToken returns an account associated with this token
func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error) { func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error) {
if claims.UserId == "" {
return nil, nil, fmt.Errorf("user ID is empty")
}
if am.singleAccountMode && am.singleAccountModeDomain != "" { if am.singleAccountMode && am.singleAccountModeDomain != "" {
// This section is mostly related to self-hosted installations. // This section is mostly related to self-hosted installations.
// We override incoming domain claims to group users under a single account. // We override incoming domain claims to group users under a single account.
@ -943,6 +947,9 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat
// //
// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain) // Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain)
func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtclaims.AuthorizationClaims) (*Account, error) { func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtclaims.AuthorizationClaims) (*Account, error) {
if claims.UserId == "" {
return nil, fmt.Errorf("user ID is empty")
}
// if Account ID is part of the claims // if Account ID is part of the claims
// it means that we've already classified the domain and user has an account // it means that we've already classified the domain and user has an account
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) { if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
@ -995,7 +1002,6 @@ func isDomainValid(domain string) bool {
// AccountExists checks whether account exists (returns true) or not (returns false) // AccountExists checks whether account exists (returns true) or not (returns false)
func (am *DefaultAccountManager) AccountExists(accountID string) (*bool, error) { func (am *DefaultAccountManager) AccountExists(accountID string) (*bool, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountLock(accountID)
defer unlock() defer unlock()

View File

@ -7,8 +7,13 @@ import (
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
type Protocol string type (
type Provider string // Protocol type
Protocol string
// Provider authorization flow type
Provider string
)
const ( const (
UDP Protocol = "udp" UDP Protocol = "udp"
@ -51,8 +56,10 @@ type HttpServerConfig struct {
CertKey string CertKey string
// AuthAudience identifies the recipients that the JWT is intended for (aud in JWT) // AuthAudience identifies the recipients that the JWT is intended for (aud in JWT)
AuthAudience string AuthAudience string
// AuthIssuer identifies principal that issued the JWT. // AuthIssuer identifies principal that issued the JWT
AuthIssuer string AuthIssuer string
// AuthUserIDClaim is the name of the claim that used as user ID
AuthUserIDClaim string
// AuthKeysLocation is a location of JWT key set containing the public keys used to verify JWT // AuthKeysLocation is a location of JWT key set containing the public keys used to verify JWT
AuthKeysLocation string AuthKeysLocation string
// OIDCConfigEndpoint is the endpoint of an IDP manager to get OIDC configuration // OIDCConfigEndpoint is the endpoint of an IDP manager to get OIDC configuration

View File

@ -3,11 +3,11 @@ package server
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/netbirdio/netbird/management/server/telemetry"
gPeer "google.golang.org/grpc/peer"
"strings" "strings"
"time" "time"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/http/middleware" "github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
@ -31,12 +31,14 @@ type GRPCServer struct {
config *Config config *Config
turnCredentialsManager TURNCredentialsManager turnCredentialsManager TURNCredentialsManager
jwtMiddleware *middleware.JWTMiddleware jwtMiddleware *middleware.JWTMiddleware
jwtClaimsExtractor *jwtclaims.ClaimsExtractor
appMetrics telemetry.AppMetrics appMetrics telemetry.AppMetrics
} }
// NewServer creates a new Management server // NewServer creates a new Management server
func NewServer(config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager, func NewServer(config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager,
turnCredentialsManager TURNCredentialsManager, appMetrics telemetry.AppMetrics) (*GRPCServer, error) { turnCredentialsManager TURNCredentialsManager, appMetrics telemetry.AppMetrics,
) (*GRPCServer, error) {
key, err := wgtypes.GeneratePrivateKey() key, err := wgtypes.GeneratePrivateKey()
if err != nil { if err != nil {
return nil, err return nil, err
@ -66,6 +68,16 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
} }
} }
var audience, userIDClaim string
if config.HttpConfig != nil {
audience = config.HttpConfig.AuthAudience
userIDClaim = config.HttpConfig.AuthUserIDClaim
}
jwtClaimsExtractor := jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(audience),
jwtclaims.WithUserIDClaim(userIDClaim),
)
return &GRPCServer{ return &GRPCServer{
wgKey: key, wgKey: key,
// peerKey -> event channel // peerKey -> event channel
@ -74,6 +86,7 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
config: config, config: config,
turnCredentialsManager: turnCredentialsManager, turnCredentialsManager: turnCredentialsManager,
jwtMiddleware: jwtMiddleware, jwtMiddleware: jwtMiddleware,
jwtClaimsExtractor: jwtClaimsExtractor,
appMetrics: appMetrics, appMetrics: appMetrics,
}, nil }, nil
} }
@ -113,7 +126,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
peer, err := s.accountManager.GetPeerByKey(peerKey.String()) peer, err := s.accountManager.GetPeerByKey(peerKey.String())
if err != nil { if err != nil {
p, _ := gPeer.FromContext(srv.Context()) p, _ := gRPCPeer.FromContext(srv.Context())
msg := status.Errorf(codes.PermissionDenied, "provided peer with the key wgPubKey %s is not registered, remote addr is %s", peerKey.String(), p.Addr.String()) msg := status.Errorf(codes.PermissionDenied, "provided peer with the key wgPubKey %s is not registered, remote addr is %s", peerKey.String(), p.Addr.String())
log.Debug(msg) log.Debug(msg)
return msg return msg
@ -122,7 +135,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
syncReq := &proto.SyncRequest{} syncReq := &proto.SyncRequest{}
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, syncReq) err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, syncReq)
if err != nil { if err != nil {
p, _ := gPeer.FromContext(srv.Context()) p, _ := gRPCPeer.FromContext(srv.Context())
msg := status.Errorf(codes.InvalidArgument, "invalid request message from %s,remote addr is %s", peerKey.String(), p.Addr.String()) msg := status.Errorf(codes.InvalidArgument, "invalid request message from %s,remote addr is %s", peerKey.String(), p.Addr.String())
log.Debug(msg) log.Debug(msg)
return msg return msg
@ -200,7 +213,7 @@ func (s *GRPCServer) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest)
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "invalid jwt token, err: %v", err) return nil, status.Errorf(codes.Internal, "invalid jwt token, err: %v", err)
} }
claims := jwtclaims.ExtractClaimsWithToken(token, s.config.HttpConfig.AuthAudience) claims := s.jwtClaimsExtractor.FromToken(token)
userID = claims.UserId userID = claims.UserId
// we need to call this method because if user is new, we will automatically add it to existing or create a new account // we need to call this method because if user is new, we will automatically add it to existing or create a new account
_, _, err = s.accountManager.GetAccountFromToken(claims) _, _, err = s.accountManager.GetAccountFromToken(claims)
@ -305,7 +318,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
// peer doesn't exist -> check if setup key was provided // peer doesn't exist -> check if setup key was provided
if loginReq.GetJwtToken() == "" && loginReq.GetSetupKey() == "" { if loginReq.GetJwtToken() == "" && loginReq.GetSetupKey() == "" {
// absent setup key or jwt -> permission denied // absent setup key or jwt -> permission denied
p, _ := gPeer.FromContext(ctx) p, _ := gRPCPeer.FromContext(ctx)
msg := status.Errorf(codes.PermissionDenied, msg := status.Errorf(codes.PermissionDenied,
"provided peer with the key wgPubKey %s is not registered and no setup key or jwt was provided,"+ "provided peer with the key wgPubKey %s is not registered and no setup key or jwt was provided,"+
" remote addr is %s", peerKey.String(), p.Addr.String()) " remote addr is %s", peerKey.String(), p.Addr.String())

View File

@ -46,6 +46,10 @@ components:
type: array type: array
items: items:
type: string type: string
is_current:
description: Is true if authenticated user is the same as this user
type: boolean
readOnly: true
required: required:
- id - id
- email - email

View File

@ -566,6 +566,9 @@ type User struct {
// Id User ID // Id User ID
Id string `json:"id"` Id string `json:"id"`
// IsCurrent Is true if authenticated user is the same as this user
IsCurrent *bool `json:"is_current,omitempty"`
// Name User's name from idp provider // Name User's name from idp provider
Name string `json:"name"` Name string `json:"name"`

View File

@ -2,33 +2,35 @@ package http
import ( import (
"encoding/json" "encoding/json"
"net/http"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"net/http"
) )
// DNSSettings is a handler that returns the DNS settings of the account // DNSSettings is a handler that returns the DNS settings of the account
type DNSSettings struct { type DNSSettings struct {
jwtExtractor jwtclaims.ClaimsExtractor
accountManager server.AccountManager accountManager server.AccountManager
authAudience string claimsExtractor *jwtclaims.ClaimsExtractor
} }
// NewDNSSettings returns a new instance of DNSSettings handler // NewDNSSettings returns a new instance of DNSSettings handler
func NewDNSSettings(accountManager server.AccountManager, authAudience string) *DNSSettings { func NewDNSSettings(accountManager server.AccountManager, authCfg AuthCfg) *DNSSettings {
return &DNSSettings{ return &DNSSettings{
accountManager: accountManager, accountManager: accountManager,
authAudience: authAudience, claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtExtractor: *jwtclaims.NewClaimsExtractor(nil), jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
} }
} }
// GetDNSSettings returns the DNS settings for the account // GetDNSSettings returns the DNS settings for the account
func (h *DNSSettings) GetDNSSettings(w http.ResponseWriter, r *http.Request) { func (h *DNSSettings) GetDNSSettings(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
@ -51,7 +53,7 @@ func (h *DNSSettings) GetDNSSettings(w http.ResponseWriter, r *http.Request) {
// UpdateDNSSettings handles update to DNS settings of an account // UpdateDNSSettings handles update to DNS settings of an account
func (h *DNSSettings) UpdateDNSSettings(w http.ResponseWriter, r *http.Request) { func (h *DNSSettings) UpdateDNSSettings(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)

View File

@ -3,14 +3,15 @@ package http
import ( import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/status"
"github.com/stretchr/testify/assert"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/status"
"github.com/stretchr/testify/assert"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
@ -52,16 +53,15 @@ func initDNSSettingsTestData() *DNSSettings {
return testingDNSSettingsAccount, testingDNSSettingsAccount.Users[testDNSSettingsUserID], nil return testingDNSSettingsAccount, testingDNSSettingsAccount.Users[testDNSSettingsUserID], nil
}, },
}, },
authAudience: "", claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtExtractor: jwtclaims.ClaimsExtractor{ jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{ return jwtclaims.AuthorizationClaims{
UserId: "test_user", UserId: "test_user",
Domain: "hotmail.com", Domain: "hotmail.com",
AccountId: testDNSSettingsAccountID, AccountId: testDNSSettingsAccountID,
} }
}, }),
}, ),
} }
} }

View File

@ -2,34 +2,36 @@ package http
import ( import (
"fmt" "fmt"
"net/http"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"net/http"
) )
// Events HTTP handler // Events HTTP handler
type Events struct { type Events struct {
accountManager server.AccountManager accountManager server.AccountManager
authAudience string claimsExtractor *jwtclaims.ClaimsExtractor
jwtExtractor jwtclaims.ClaimsExtractor
} }
// NewEvents creates a new Events HTTP handler // NewEvents creates a new Events HTTP handler
func NewEvents(accountManager server.AccountManager, authAudience string) *Events { func NewEvents(accountManager server.AccountManager, authCfg AuthCfg) *Events {
return &Events{ return &Events{
accountManager: accountManager, accountManager: accountManager,
authAudience: authAudience, claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtExtractor: *jwtclaims.NewClaimsExtractor(nil), jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
} }
} }
// GetEvents list of the given account // GetEvents list of the given account
func (h *Events) GetEvents(w http.ResponseWriter, r *http.Request) { func (h *Events) GetEvents(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
log.Error(err) log.Error(err)

View File

@ -2,6 +2,13 @@ package http
import ( import (
"encoding/json" "encoding/json"
"io"
"net/http"
"net/http/httptest"
"strconv"
"testing"
"time"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
@ -9,12 +16,6 @@ import (
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"io"
"net/http"
"net/http/httptest"
"strconv"
"testing"
"time"
) )
func initEventsTestData(account string, user *server.User, events ...*activity.Event) *Events { func initEventsTestData(account string, user *server.User, events ...*activity.Event) *Events {
@ -36,16 +37,15 @@ func initEventsTestData(account string, user *server.User, events ...*activity.E
}, user, nil }, user, nil
}, },
}, },
authAudience: "", claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtExtractor: jwtclaims.ClaimsExtractor{ jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{ return jwtclaims.AuthorizationClaims{
UserId: "test_user", UserId: "test_user",
Domain: "hotmail.com", Domain: "hotmail.com",
AccountId: "test_account", AccountId: "test_account",
} }
}, }),
}, ),
} }
} }
@ -244,7 +244,6 @@ func TestEvents_GetEvents(t *testing.T) {
assert.Equal(t, expected.Meta["some"], event.Meta["some"]) assert.Equal(t, expected.Meta["some"], event.Meta["some"])
assert.True(t, expected.Timestamp.Equal(event.Timestamp)) assert.True(t, expected.Timestamp.Equal(event.Timestamp))
} }
}) })
} }
} }

View File

@ -2,10 +2,11 @@ package http
import ( import (
"encoding/json" "encoding/json"
"net/http"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
"net/http"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
@ -17,22 +18,23 @@ import (
// Groups is a handler that returns groups of the account // Groups is a handler that returns groups of the account
type Groups struct { type Groups struct {
jwtExtractor jwtclaims.ClaimsExtractor
accountManager server.AccountManager accountManager server.AccountManager
authAudience string claimsExtractor *jwtclaims.ClaimsExtractor
} }
func NewGroups(accountManager server.AccountManager, authAudience string) *Groups { func NewGroups(accountManager server.AccountManager, authCfg AuthCfg) *Groups {
return &Groups{ return &Groups{
accountManager: accountManager, accountManager: accountManager,
authAudience: authAudience, claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtExtractor: *jwtclaims.NewClaimsExtractor(nil), jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
} }
} }
// GetAllGroupsHandler list for the account // GetAllGroupsHandler list for the account
func (h *Groups) GetAllGroupsHandler(w http.ResponseWriter, r *http.Request) { func (h *Groups) GetAllGroupsHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) claims := h.claimsExtractor.FromRequestContext(r)
account, _, err := h.accountManager.GetAccountFromToken(claims) account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
@ -50,7 +52,7 @@ func (h *Groups) GetAllGroupsHandler(w http.ResponseWriter, r *http.Request) {
// UpdateGroupHandler handles update to a group identified by a given ID // UpdateGroupHandler handles update to a group identified by a given ID
func (h *Groups) UpdateGroupHandler(w http.ResponseWriter, r *http.Request) { func (h *Groups) UpdateGroupHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)
@ -119,7 +121,7 @@ func (h *Groups) UpdateGroupHandler(w http.ResponseWriter, r *http.Request) {
// PatchGroupHandler handles patch updates to a group identified by a given ID // PatchGroupHandler handles patch updates to a group identified by a given ID
func (h *Groups) PatchGroupHandler(w http.ResponseWriter, r *http.Request) { func (h *Groups) PatchGroupHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) claims := h.claimsExtractor.FromRequestContext(r)
account, _, err := h.accountManager.GetAccountFromToken(claims) account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)
@ -223,7 +225,7 @@ func (h *Groups) PatchGroupHandler(w http.ResponseWriter, r *http.Request) {
// CreateGroupHandler handles group creation request // CreateGroupHandler handles group creation request
func (h *Groups) CreateGroupHandler(w http.ResponseWriter, r *http.Request) { func (h *Groups) CreateGroupHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)
@ -265,7 +267,7 @@ func (h *Groups) CreateGroupHandler(w http.ResponseWriter, r *http.Request) {
// DeleteGroupHandler handles group deletion request // DeleteGroupHandler handles group deletion request
func (h *Groups) DeleteGroupHandler(w http.ResponseWriter, r *http.Request) { func (h *Groups) DeleteGroupHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) claims := h.claimsExtractor.FromRequestContext(r)
account, _, err := h.accountManager.GetAccountFromToken(claims) account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)
@ -301,7 +303,7 @@ func (h *Groups) DeleteGroupHandler(w http.ResponseWriter, r *http.Request) {
// GetGroupHandler returns a group // GetGroupHandler returns a group
func (h *Groups) GetGroupHandler(w http.ResponseWriter, r *http.Request) { func (h *Groups) GetGroupHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) claims := h.claimsExtractor.FromRequestContext(r)
account, _, err := h.accountManager.GetAccountFromToken(claims) account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)

View File

@ -4,8 +4,6 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/status"
"io" "io"
"net" "net"
"net/http" "net/http"
@ -13,6 +11,9 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/status"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
@ -78,20 +79,20 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *Groups {
}, },
Groups: map[string]*server.Group{ Groups: map[string]*server.Group{
"id-existed": {ID: "id-existed", Peers: []string{"A", "B"}}, "id-existed": {ID: "id-existed", Peers: []string{"A", "B"}},
"id-all": {ID: "id-all", Name: "All"}}, "id-all": {ID: "id-all", Name: "All"},
},
}, user, nil }, user, nil
}, },
}, },
authAudience: "", claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtExtractor: jwtclaims.ClaimsExtractor{ jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{ return jwtclaims.AuthorizationClaims{
UserId: "test_user", UserId: "test_user",
Domain: "hotmail.com", Domain: "hotmail.com",
AccountId: "test_id", AccountId: "test_id",
} }
}, }),
}, ),
} }
} }
@ -270,7 +271,8 @@ func TestWriteGroup(t *testing.T) {
PeersCount: 2, PeersCount: 2,
Peers: []api.PeerMinimum{ Peers: []api.PeerMinimum{
{Id: "peer-A-ID"}, {Id: "peer-A-ID"},
{Id: "peer-B-ID"}}, {Id: "peer-B-ID"},
},
}, },
}, },
} }

View File

@ -1,22 +1,29 @@
package http package http
import ( import (
"net/http"
"github.com/gorilla/mux" "github.com/gorilla/mux"
s "github.com/netbirdio/netbird/management/server" s "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/middleware" "github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/rs/cors" "github.com/rs/cors"
"net/http"
) )
// AuthCfg contains parameters for authentication middleware
type AuthCfg struct {
Issuer string
Audience string
UserIDClaim string
KeysLocation string
}
// APIHandler creates the Management service HTTP API handler registering all the available endpoints. // APIHandler creates the Management service HTTP API handler registering all the available endpoints.
func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience string, authKeysLocation string, func APIHandler(accountManager s.AccountManager, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) {
appMetrics telemetry.AppMetrics) (http.Handler, error) {
jwtMiddleware, err := middleware.NewJwtMiddleware( jwtMiddleware, err := middleware.NewJwtMiddleware(
authIssuer, authCfg.Issuer,
authAudience, authCfg.Audience,
authKeysLocation, authCfg.KeysLocation)
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -24,7 +31,8 @@ func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience
corsMiddleware := cors.AllowAll() corsMiddleware := cors.AllowAll()
acMiddleware := middleware.NewAccessControl( acMiddleware := middleware.NewAccessControl(
authAudience, authCfg.Audience,
authCfg.UserIDClaim,
accountManager.IsUserAdmin) accountManager.IsUserAdmin)
rootRouter := mux.NewRouter() rootRouter := mux.NewRouter()
@ -33,15 +41,15 @@ func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience
apiHandler := rootRouter.PathPrefix("/api").Subrouter() apiHandler := rootRouter.PathPrefix("/api").Subrouter()
apiHandler.Use(metricsMiddleware.Handler, corsMiddleware.Handler, jwtMiddleware.Handler, acMiddleware.Handler) apiHandler.Use(metricsMiddleware.Handler, corsMiddleware.Handler, jwtMiddleware.Handler, acMiddleware.Handler)
groupsHandler := NewGroups(accountManager, authAudience) groupsHandler := NewGroups(accountManager, authCfg)
rulesHandler := NewRules(accountManager, authAudience) rulesHandler := NewRules(accountManager, authCfg)
peersHandler := NewPeers(accountManager, authAudience) peersHandler := NewPeers(accountManager, authCfg)
keysHandler := NewSetupKeysHandler(accountManager, authAudience) keysHandler := NewSetupKeysHandler(accountManager, authCfg)
userHandler := NewUserHandler(accountManager, authAudience) userHandler := NewUserHandler(accountManager, authCfg)
routesHandler := NewRoutes(accountManager, authAudience) routesHandler := NewRoutes(accountManager, authCfg)
nameserversHandler := NewNameservers(accountManager, authAudience) nameserversHandler := NewNameservers(accountManager, authCfg)
eventsHandler := NewEvents(accountManager, authAudience) eventsHandler := NewEvents(accountManager, authCfg)
dnsSettingsHandler := NewDNSSettings(accountManager, authAudience) dnsSettingsHandler := NewDNSSettings(accountManager, authCfg)
apiHandler.HandleFunc("/peers", peersHandler.GetPeers).Methods("GET", "OPTIONS") apiHandler.HandleFunc("/peers", peersHandler.GetPeers).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/peers/{id}", peersHandler.HandlePeer). apiHandler.HandleFunc("/peers/{id}", peersHandler.HandlePeer).
@ -88,7 +96,7 @@ func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience
apiHandler.HandleFunc("/dns/settings", dnsSettingsHandler.GetDNSSettings).Methods("GET", "OPTIONS") apiHandler.HandleFunc("/dns/settings", dnsSettingsHandler.GetDNSSettings).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/dns/settings", dnsSettingsHandler.UpdateDNSSettings).Methods("PUT", "OPTIONS") apiHandler.HandleFunc("/dns/settings", dnsSettingsHandler.UpdateDNSSettings).Methods("PUT", "OPTIONS")
err = apiHandler.Walk(func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error { err = apiHandler.Walk(func(route *mux.Route, _ *mux.Router, _ []*mux.Route) error {
methods, err := route.GetMethods() methods, err := route.GetMethods()
if err != nil { if err != nil {
return err return err
@ -110,5 +118,4 @@ func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience
} }
return rootRouter, nil return rootRouter, nil
} }

View File

@ -1,9 +1,10 @@
package middleware package middleware
import ( import (
"net/http"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
"net/http"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
) )
@ -12,17 +13,18 @@ type IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error)
// AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only // AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only
type AccessControl struct { type AccessControl struct {
jwtExtractor jwtclaims.ClaimsExtractor
isUserAdmin IsUserAdminFunc isUserAdmin IsUserAdminFunc
audience string claimsExtract jwtclaims.ClaimsExtractor
} }
// NewAccessControl instance constructor // NewAccessControl instance constructor
func NewAccessControl(audience string, isUserAdmin IsUserAdminFunc) *AccessControl { func NewAccessControl(audience, userIDClaim string, isUserAdmin IsUserAdminFunc) *AccessControl {
return &AccessControl{ return &AccessControl{
isUserAdmin: isUserAdmin, isUserAdmin: isUserAdmin,
audience: audience, claimsExtract: *jwtclaims.NewClaimsExtractor(
jwtExtractor: *jwtclaims.NewClaimsExtractor(nil), jwtclaims.WithAudience(audience),
jwtclaims.WithUserIDClaim(userIDClaim),
),
} }
} }
@ -30,9 +32,9 @@ func NewAccessControl(audience string, isUserAdmin IsUserAdminFunc) *AccessContr
// It also adds // It also adds
func (a *AccessControl) Handler(h http.Handler) http.Handler { func (a *AccessControl) Handler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
jwtClaims := a.jwtExtractor.ExtractClaimsFromRequestContext(r, a.audience) claims := a.claimsExtract.FromRequestContext(r)
ok, err := a.isUserAdmin(jwtClaims) ok, err := a.isUserAdmin(claims)
if err != nil { if err != nil {
util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w) util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w)
return return
@ -40,7 +42,6 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
if !ok { if !ok {
switch r.Method { switch r.Method {
case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut: case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut:
util.WriteError(status.Errorf(status.PermissionDenied, "only admin can perform this operation"), w) util.WriteError(status.Errorf(status.PermissionDenied, "only admin can perform this operation"), w)
return return

View File

@ -10,10 +10,11 @@ import (
"encoding/pem" "encoding/pem"
"errors" "errors"
"fmt" "fmt"
"github.com/golang-jwt/jwt"
log "github.com/sirupsen/logrus"
"math/big" "math/big"
"net/http" "net/http"
"github.com/golang-jwt/jwt"
log "github.com/sirupsen/logrus"
) )
// Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation // Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation
@ -33,7 +34,6 @@ type JSONWebKey struct {
// NewJwtMiddleware creates new middleware to verify the JWT token sent via Authorization header // NewJwtMiddleware creates new middleware to verify the JWT token sent via Authorization header
func NewJwtMiddleware(issuer string, audience string, keysLocation string) (*JWTMiddleware, error) { func NewJwtMiddleware(issuer string, audience string, keysLocation string) (*JWTMiddleware, error) {
keys, err := getPemKeys(keysLocation) keys, err := getPemKeys(keysLocation)
if err != nil { if err != nil {
return nil, err return nil, err
@ -67,13 +67,12 @@ func NewJwtMiddleware(issuer string, audience string, keysLocation string) (*JWT
func getPemKeys(keysLocation string) (*Jwks, error) { func getPemKeys(keysLocation string) (*Jwks, error) {
resp, err := http.Get(keysLocation) resp, err := http.Get(keysLocation)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer resp.Body.Close() defer resp.Body.Close()
var jwks = &Jwks{} jwks := &Jwks{}
err = json.NewDecoder(resp.Body).Decode(jwks) err = json.NewDecoder(resp.Body).Decode(jwks)
if err != nil { if err != nil {
return jwks, err return jwks, err

View File

@ -3,6 +3,8 @@ package http
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http"
"github.com/gorilla/mux" "github.com/gorilla/mux"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
@ -11,28 +13,28 @@ import (
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"net/http"
) )
// Nameservers is the nameserver group handler of the account // Nameservers is the nameserver group handler of the account
type Nameservers struct { type Nameservers struct {
jwtExtractor jwtclaims.ClaimsExtractor
accountManager server.AccountManager accountManager server.AccountManager
authAudience string claimsExtractor *jwtclaims.ClaimsExtractor
} }
// NewNameservers returns a new instance of Nameservers handler // NewNameservers returns a new instance of Nameservers handler
func NewNameservers(accountManager server.AccountManager, authAudience string) *Nameservers { func NewNameservers(accountManager server.AccountManager, authCfg AuthCfg) *Nameservers {
return &Nameservers{ return &Nameservers{
accountManager: accountManager, accountManager: accountManager,
authAudience: authAudience, claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtExtractor: *jwtclaims.NewClaimsExtractor(nil), jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
} }
} }
// GetAllNameserversHandler returns the list of nameserver groups for the account // GetAllNameserversHandler returns the list of nameserver groups for the account
func (h *Nameservers) GetAllNameserversHandler(w http.ResponseWriter, r *http.Request) { func (h *Nameservers) GetAllNameserversHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) claims := h.claimsExtractor.FromRequestContext(r)
account, _, err := h.accountManager.GetAccountFromToken(claims) account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
@ -56,7 +58,7 @@ func (h *Nameservers) GetAllNameserversHandler(w http.ResponseWriter, r *http.Re
// CreateNameserverGroupHandler handles nameserver group creation request // CreateNameserverGroupHandler handles nameserver group creation request
func (h *Nameservers) CreateNameserverGroupHandler(w http.ResponseWriter, r *http.Request) { func (h *Nameservers) CreateNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)
@ -89,7 +91,7 @@ func (h *Nameservers) CreateNameserverGroupHandler(w http.ResponseWriter, r *htt
// UpdateNameserverGroupHandler handles update to a nameserver group identified by a given ID // UpdateNameserverGroupHandler handles update to a nameserver group identified by a given ID
func (h *Nameservers) UpdateNameserverGroupHandler(w http.ResponseWriter, r *http.Request) { func (h *Nameservers) UpdateNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)
@ -139,7 +141,7 @@ func (h *Nameservers) UpdateNameserverGroupHandler(w http.ResponseWriter, r *htt
// PatchNameserverGroupHandler handles patch updates to a nameserver group identified by a given ID // PatchNameserverGroupHandler handles patch updates to a nameserver group identified by a given ID
func (h *Nameservers) PatchNameserverGroupHandler(w http.ResponseWriter, r *http.Request) { func (h *Nameservers) PatchNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)
@ -221,7 +223,7 @@ func (h *Nameservers) PatchNameserverGroupHandler(w http.ResponseWriter, r *http
// DeleteNameserverGroupHandler handles nameserver group deletion request // DeleteNameserverGroupHandler handles nameserver group deletion request
func (h *Nameservers) DeleteNameserverGroupHandler(w http.ResponseWriter, r *http.Request) { func (h *Nameservers) DeleteNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)
@ -245,7 +247,7 @@ func (h *Nameservers) DeleteNameserverGroupHandler(w http.ResponseWriter, r *htt
// GetNameserverGroupHandler handles a nameserver group Get request identified by ID // GetNameserverGroupHandler handles a nameserver group Get request identified by ID
func (h *Nameservers) GetNameserverGroupHandler(w http.ResponseWriter, r *http.Request) { func (h *Nameservers) GetNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) claims := h.claimsExtractor.FromRequestContext(r)
account, _, err := h.accountManager.GetAccountFromToken(claims) account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
@ -268,7 +270,6 @@ func (h *Nameservers) GetNameserverGroupHandler(w http.ResponseWriter, r *http.R
resp := toNameserverGroupResponse(nsGroup) resp := toNameserverGroupResponse(nsGroup)
util.WriteJSONObject(w, &resp) util.WriteJSONObject(w, &resp)
} }
func toServerNSList(apiNSList []api.Nameserver) ([]nbdns.NameServer, error) { func toServerNSList(apiNSList []api.Nameserver) ([]nbdns.NameServer, error) {

View File

@ -3,16 +3,17 @@ package http
import ( import (
"bytes" "bytes"
"encoding/json" "encoding/json"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/status"
"github.com/stretchr/testify/assert"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/netip" "net/netip"
"testing" "testing"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/status"
"github.com/stretchr/testify/assert"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
@ -113,16 +114,15 @@ func initNameserversTestData() *Nameservers {
return testingNSAccount, testingAccount.Users["test_user"], nil return testingNSAccount, testingAccount.Users["test_user"], nil
}, },
}, },
authAudience: "", claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtExtractor: jwtclaims.ClaimsExtractor{ jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{ return jwtclaims.AuthorizationClaims{
UserId: "test_user", UserId: "test_user",
Domain: "hotmail.com", Domain: "hotmail.com",
AccountId: testNSGroupAccountID, AccountId: testNSGroupAccountID,
} }
}, }),
}, ),
} }
} }

View File

@ -3,27 +3,29 @@ package http
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
"net/http"
) )
// Peers is a handler that returns peers of the account // Peers is a handler that returns peers of the account
type Peers struct { type Peers struct {
accountManager server.AccountManager accountManager server.AccountManager
authAudience string claimsExtractor *jwtclaims.ClaimsExtractor
jwtExtractor jwtclaims.ClaimsExtractor
} }
func NewPeers(accountManager server.AccountManager, authAudience string) *Peers { func NewPeers(accountManager server.AccountManager, authCfg AuthCfg) *Peers {
return &Peers{ return &Peers{
accountManager: accountManager, accountManager: accountManager,
authAudience: authAudience, claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtExtractor: *jwtclaims.NewClaimsExtractor(nil), jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
} }
} }
@ -55,7 +57,7 @@ func (h *Peers) deletePeer(accountID, userID string, peerID string, w http.Respo
} }
func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) { func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)
@ -78,13 +80,12 @@ func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
default: default:
util.WriteError(status.Errorf(status.NotFound, "unknown METHOD"), w) util.WriteError(status.Errorf(status.NotFound, "unknown METHOD"), w)
} }
} }
func (h *Peers) GetPeers(w http.ResponseWriter, r *http.Request) { func (h *Peers) GetPeers(w http.ResponseWriter, r *http.Request) {
switch r.Method { switch r.Method {
case http.MethodGet: case http.MethodGet:
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)

View File

@ -2,13 +2,14 @@ package http
import ( import (
"encoding/json" "encoding/json"
"github.com/netbirdio/netbird/management/server/http/api"
"io" "io"
"net" "net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/magiconair/properties/assert" "github.com/magiconair/properties/assert"
@ -36,16 +37,15 @@ func initTestMetaData(peers ...*server.Peer) *Peers {
}, user, nil }, user, nil
}, },
}, },
authAudience: "", claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtExtractor: jwtclaims.ClaimsExtractor{ jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{ return jwtclaims.AuthorizationClaims{
UserId: "test_user", UserId: "test_user",
Domain: "hotmail.com", Domain: "hotmail.com",
AccountId: "test_id", AccountId: "test_id",
} }
}, }),
}, ),
} }
} }

View File

@ -2,6 +2,9 @@ package http
import ( import (
"encoding/json" "encoding/json"
"net/http"
"unicode/utf8"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
@ -9,29 +12,28 @@ import (
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"net/http"
"unicode/utf8"
) )
// Routes is the routes handler of the account // Routes is the routes handler of the account
type Routes struct { type Routes struct {
jwtExtractor jwtclaims.ClaimsExtractor
accountManager server.AccountManager accountManager server.AccountManager
authAudience string claimsExtractor *jwtclaims.ClaimsExtractor
} }
// NewRoutes returns a new instance of Routes handler // NewRoutes returns a new instance of Routes handler
func NewRoutes(accountManager server.AccountManager, authAudience string) *Routes { func NewRoutes(accountManager server.AccountManager, authCfg AuthCfg) *Routes {
return &Routes{ return &Routes{
accountManager: accountManager, accountManager: accountManager,
authAudience: authAudience, claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtExtractor: *jwtclaims.NewClaimsExtractor(nil), jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
} }
} }
// GetAllRoutesHandler returns the list of routes for the account // GetAllRoutesHandler returns the list of routes for the account
func (h *Routes) GetAllRoutesHandler(w http.ResponseWriter, r *http.Request) { func (h *Routes) GetAllRoutesHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)
@ -53,7 +55,7 @@ func (h *Routes) GetAllRoutesHandler(w http.ResponseWriter, r *http.Request) {
// CreateRouteHandler handles route creation request // CreateRouteHandler handles route creation request
func (h *Routes) CreateRouteHandler(w http.ResponseWriter, r *http.Request) { func (h *Routes) CreateRouteHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)
@ -92,7 +94,7 @@ func (h *Routes) CreateRouteHandler(w http.ResponseWriter, r *http.Request) {
// UpdateRouteHandler handles update to a route identified by a given ID // UpdateRouteHandler handles update to a route identified by a given ID
func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) { func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)
@ -158,7 +160,7 @@ func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) {
// PatchRouteHandler handles patch updates to a route identified by a given ID // PatchRouteHandler handles patch updates to a route identified by a given ID
func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) { func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)
@ -299,7 +301,7 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
// DeleteRouteHandler handles route deletion request // DeleteRouteHandler handles route deletion request
func (h *Routes) DeleteRouteHandler(w http.ResponseWriter, r *http.Request) { func (h *Routes) DeleteRouteHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)
@ -323,7 +325,7 @@ func (h *Routes) DeleteRouteHandler(w http.ResponseWriter, r *http.Request) {
// GetRouteHandler handles a route Get request identified by ID // GetRouteHandler handles a route Get request identified by ID
func (h *Routes) GetRouteHandler(w http.ResponseWriter, r *http.Request) { func (h *Routes) GetRouteHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)

View File

@ -4,9 +4,6 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/route"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -14,6 +11,10 @@ import (
"strconv" "strconv"
"testing" "testing"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/route"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/magiconair/properties/assert" "github.com/magiconair/properties/assert"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
@ -142,16 +143,15 @@ func initRoutesTestData() *Routes {
return testingAccount, testingAccount.Users["test_user"], nil return testingAccount, testingAccount.Users["test_user"], nil
}, },
}, },
authAudience: "", claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtExtractor: jwtclaims.ClaimsExtractor{ jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{ return jwtclaims.AuthorizationClaims{
UserId: "test_user", UserId: "test_user",
Domain: "hotmail.com", Domain: "hotmail.com",
AccountId: testAccountID, AccountId: testAccountID,
} }
}, }),
}, ),
} }
} }

View File

@ -2,6 +2,8 @@ package http
import ( import (
"encoding/json" "encoding/json"
"net/http"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
@ -9,27 +11,27 @@ import (
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
"github.com/rs/xid" "github.com/rs/xid"
"net/http"
) )
// Rules is a handler that returns rules of the account // Rules is a handler that returns rules of the account
type Rules struct { type Rules struct {
jwtExtractor jwtclaims.ClaimsExtractor
accountManager server.AccountManager accountManager server.AccountManager
authAudience string claimsExtractor *jwtclaims.ClaimsExtractor
} }
func NewRules(accountManager server.AccountManager, authAudience string) *Rules { func NewRules(accountManager server.AccountManager, authCfg AuthCfg) *Rules {
return &Rules{ return &Rules{
accountManager: accountManager, accountManager: accountManager,
authAudience: authAudience, claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtExtractor: *jwtclaims.NewClaimsExtractor(nil), jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
} }
} }
// GetAllRulesHandler list for the account // GetAllRulesHandler list for the account
func (h *Rules) GetAllRulesHandler(w http.ResponseWriter, r *http.Request) { func (h *Rules) GetAllRulesHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)
@ -51,7 +53,7 @@ func (h *Rules) GetAllRulesHandler(w http.ResponseWriter, r *http.Request) {
// UpdateRuleHandler handles update to a rule identified by a given ID // UpdateRuleHandler handles update to a rule identified by a given ID
func (h *Rules) UpdateRuleHandler(w http.ResponseWriter, r *http.Request) { func (h *Rules) UpdateRuleHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)
@ -122,7 +124,7 @@ func (h *Rules) UpdateRuleHandler(w http.ResponseWriter, r *http.Request) {
// PatchRuleHandler handles patch updates to a rule identified by a given ID // PatchRuleHandler handles patch updates to a rule identified by a given ID
func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) { func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) claims := h.claimsExtractor.FromRequestContext(r)
account, _, err := h.accountManager.GetAccountFromToken(claims) account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)
@ -253,7 +255,6 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
} }
rule, err := h.accountManager.UpdateRule(account.Id, ruleID, operations) rule, err := h.accountManager.UpdateRule(account.Id, ruleID, operations)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)
return return
@ -266,7 +267,7 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
// CreateRuleHandler handles rule creation request // CreateRuleHandler handles rule creation request
func (h *Rules) CreateRuleHandler(w http.ResponseWriter, r *http.Request) { func (h *Rules) CreateRuleHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)
@ -325,7 +326,7 @@ func (h *Rules) CreateRuleHandler(w http.ResponseWriter, r *http.Request) {
// DeleteRuleHandler handles rule deletion request // DeleteRuleHandler handles rule deletion request
func (h *Rules) DeleteRuleHandler(w http.ResponseWriter, r *http.Request) { func (h *Rules) DeleteRuleHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)
@ -350,7 +351,7 @@ func (h *Rules) DeleteRuleHandler(w http.ResponseWriter, r *http.Request) {
// GetRuleHandler handles a group Get request identified by ID // GetRuleHandler handles a group Get request identified by ID
func (h *Rules) GetRuleHandler(w http.ResponseWriter, r *http.Request) { func (h *Rules) GetRuleHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)

View File

@ -4,13 +4,14 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/netbirdio/netbird/management/server/http/api"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings" "strings"
"testing" "testing"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
@ -71,7 +72,7 @@ func initRulesTestData(rules ...*server.Rule) *Rules {
return &server.Account{ return &server.Account{
Id: claims.AccountId, Id: claims.AccountId,
Domain: "hotmail.com", Domain: "hotmail.com",
Rules: map[string]*server.Rule{"id-existed": &server.Rule{ID: "id-existed"}}, Rules: map[string]*server.Rule{"id-existed": {ID: "id-existed"}},
Groups: map[string]*server.Group{ Groups: map[string]*server.Group{
"F": {ID: "F"}, "F": {ID: "F"},
"G": {ID: "G"}, "G": {ID: "G"},
@ -82,16 +83,15 @@ func initRulesTestData(rules ...*server.Rule) *Rules {
}, user, nil }, user, nil
}, },
}, },
authAudience: "", claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtExtractor: jwtclaims.ClaimsExtractor{ jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{ return jwtclaims.AuthorizationClaims{
UserId: "test_user", UserId: "test_user",
Domain: "hotmail.com", Domain: "hotmail.com",
AccountId: "test_id", AccountId: "test_id",
} }
}, }),
}, ),
} }
} }
@ -264,7 +264,8 @@ func TestRulesWriteRule(t *testing.T) {
Flow: server.TrafficFlowBidirectString, Flow: server.TrafficFlowBidirectString,
Sources: []api.GroupMinimum{ Sources: []api.GroupMinimum{
{Id: "G"}, {Id: "G"},
{Id: "F"}}, {Id: "F"},
},
}, },
}, },
} }
@ -306,7 +307,6 @@ func TestRulesWriteRule(t *testing.T) {
} }
assert.Equal(t, got, tc.expectedRule) assert.Equal(t, got, tc.expectedRule)
}) })
} }
} }

View File

@ -2,34 +2,36 @@ package http
import ( import (
"encoding/json" "encoding/json"
"net/http"
"time"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
"net/http"
"time"
) )
// SetupKeys is a handler that returns a list of setup keys of the account // SetupKeys is a handler that returns a list of setup keys of the account
type SetupKeys struct { type SetupKeys struct {
accountManager server.AccountManager accountManager server.AccountManager
jwtExtractor jwtclaims.ClaimsExtractor claimsExtractor *jwtclaims.ClaimsExtractor
authAudience string
} }
func NewSetupKeysHandler(accountManager server.AccountManager, authAudience string) *SetupKeys { func NewSetupKeysHandler(accountManager server.AccountManager, authCfg AuthCfg) *SetupKeys {
return &SetupKeys{ return &SetupKeys{
accountManager: accountManager, accountManager: accountManager,
authAudience: authAudience, claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtExtractor: *jwtclaims.NewClaimsExtractor(nil), jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
} }
} }
// CreateSetupKeyHandler is a POST requests that creates a new SetupKey // CreateSetupKeyHandler is a POST requests that creates a new SetupKey
func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request) { func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)
@ -72,7 +74,7 @@ func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request
// GetSetupKeyHandler is a GET request to get a SetupKey by ID // GetSetupKeyHandler is a GET request to get a SetupKey by ID
func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) { func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)
@ -97,7 +99,7 @@ func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
// UpdateSetupKeyHandler is a PUT request to update server.SetupKey // UpdateSetupKeyHandler is a PUT request to update server.SetupKey
func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request) { func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)
@ -144,8 +146,7 @@ func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request
// GetAllSetupKeysHandler is a GET request that returns a list of SetupKey // GetAllSetupKeysHandler is a GET request that returns a list of SetupKey
func (h *SetupKeys) GetAllSetupKeysHandler(w http.ResponseWriter, r *http.Request) { func (h *SetupKeys) GetAllSetupKeysHandler(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)

View File

@ -4,16 +4,17 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/status"
"github.com/stretchr/testify/assert"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"time" "time"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/status"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
@ -28,7 +29,8 @@ const (
) )
func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.SetupKey, updatedSetupKey *server.SetupKey, func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.SetupKey, updatedSetupKey *server.SetupKey,
user *server.User) *SetupKeys { user *server.User,
) *SetupKeys {
return &SetupKeys{ return &SetupKeys{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
@ -43,11 +45,13 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
}, },
Groups: map[string]*server.Group{ Groups: map[string]*server.Group{
"group-1": {ID: "group-1", Peers: []string{"A", "B"}}, "group-1": {ID: "group-1", Peers: []string{"A", "B"}},
"id-all": {ID: "id-all", Name: "All"}}, "id-all": {ID: "id-all", Name: "All"},
},
}, user, nil }, user, nil
}, },
CreateSetupKeyFunc: func(_ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string, CreateSetupKeyFunc: func(_ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string,
_ int, _ string) (*server.SetupKey, error) { _ int, _ string,
) (*server.SetupKey, error) {
if keyName == newKey.Name || typ != newKey.Type { if keyName == newKey.Name || typ != newKey.Type {
return newKey, nil return newKey, nil
} }
@ -75,16 +79,15 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
return []*server.SetupKey{defaultKey}, nil return []*server.SetupKey{defaultKey}, nil
}, },
}, },
authAudience: "", claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtExtractor: jwtclaims.ClaimsExtractor{ jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
ExtractClaimsFromRequestContext: func(r *http.Request, authAudience string) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{ return jwtclaims.AuthorizationClaims{
UserId: user.Id, UserId: user.Id,
Domain: "hotmail.com", Domain: "hotmail.com",
AccountId: testAccountID, AccountId: testAccountID,
} }
}, }),
}, ),
} }
} }
@ -209,7 +212,6 @@ func TestSetupKeysHandlers(t *testing.T) {
assertKeys(t, got[0], tc.expectedSetupKeys[0]) assertKeys(t, got[0], tc.expectedSetupKeys[0])
return return
} }
}) })
} }
} }

View File

@ -2,11 +2,12 @@ package http
import ( import (
"encoding/json" "encoding/json"
"net/http"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
"net/http"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
@ -14,15 +15,16 @@ import (
type UserHandler struct { type UserHandler struct {
accountManager server.AccountManager accountManager server.AccountManager
authAudience string claimsExtractor *jwtclaims.ClaimsExtractor
jwtExtractor jwtclaims.ClaimsExtractor
} }
func NewUserHandler(accountManager server.AccountManager, authAudience string) *UserHandler { func NewUserHandler(accountManager server.AccountManager, authCfg AuthCfg) *UserHandler {
return &UserHandler{ return &UserHandler{
accountManager: accountManager, accountManager: accountManager,
authAudience: authAudience, claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtExtractor: *jwtclaims.NewClaimsExtractor(nil), jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
} }
} }
@ -33,7 +35,7 @@ func (h *UserHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
return return
} }
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)
@ -69,8 +71,7 @@ func (h *UserHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
util.WriteError(err, w) util.WriteError(err, w)
return return
} }
util.WriteJSONObject(w, toUserResponse(newUser)) util.WriteJSONObject(w, toUserResponse(newUser, claims.UserId))
} }
// CreateUserHandler creates a User in the system with a status "invited" (effectively this is a user invite). // CreateUserHandler creates a User in the system with a status "invited" (effectively this is a user invite).
@ -80,7 +81,7 @@ func (h *UserHandler) CreateUserHandler(w http.ResponseWriter, r *http.Request)
return return
} }
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)
@ -109,7 +110,7 @@ func (h *UserHandler) CreateUserHandler(w http.ResponseWriter, r *http.Request)
util.WriteError(err, w) util.WriteError(err, w)
return return
} }
util.WriteJSONObject(w, toUserResponse(newUser)) util.WriteJSONObject(w, toUserResponse(newUser, claims.UserId))
} }
// GetUsers returns a list of users of the account this user belongs to. // GetUsers returns a list of users of the account this user belongs to.
@ -120,7 +121,7 @@ func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) {
return return
} }
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims) account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)
@ -135,14 +136,13 @@ func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) {
users := make([]*api.User, 0) users := make([]*api.User, 0)
for _, r := range data { for _, r := range data {
users = append(users, toUserResponse(r)) users = append(users, toUserResponse(r, claims.UserId))
} }
util.WriteJSONObject(w, users) util.WriteJSONObject(w, users)
} }
func toUserResponse(user *server.UserInfo) *api.User { func toUserResponse(user *server.UserInfo, currenUserID string) *api.User {
autoGroups := user.AutoGroups autoGroups := user.AutoGroups
if autoGroups == nil { if autoGroups == nil {
autoGroups = []string{} autoGroups = []string{}
@ -158,6 +158,7 @@ func toUserResponse(user *server.UserInfo) *api.User {
userStatus = api.UserStatusDisabled userStatus = api.UserStatusDisabled
} }
isCurrent := user.ID == currenUserID
return &api.User{ return &api.User{
Id: user.ID, Id: user.ID,
Name: user.Name, Name: user.Name,
@ -165,5 +166,6 @@ func toUserResponse(user *server.UserInfo) *api.User {
Role: user.Role, Role: user.Role,
AutoGroups: autoGroups, AutoGroups: autoGroups,
Status: userStatus, Status: userStatus,
IsCurrent: &isCurrent,
} }
} }

View File

@ -40,16 +40,15 @@ func initUsers(user ...*server.User) *UserHandler {
return users, nil return users, nil
}, },
}, },
authAudience: "", claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtExtractor: jwtclaims.ClaimsExtractor{ jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{ return jwtclaims.AuthorizationClaims{
UserId: "1", UserId: "1",
Domain: "hotmail.com", Domain: "hotmail.com",
AccountId: "test_id", AccountId: "test_id",
} }
}, }),
}, ),
} }
} }
@ -57,7 +56,7 @@ func TestGetUsers(t *testing.T) {
users := []*server.User{{Id: "1", Role: "admin"}, {Id: "2", Role: "user"}, {Id: "3", Role: "user"}} users := []*server.User{{Id: "1", Role: "admin"}, {Id: "2", Role: "user"}, {Id: "3", Role: "user"}}
userHandler := initUsers(users...) userHandler := initUsers(users...)
var tt = []struct { tt := []struct {
name string name string
expectedStatus int expectedStatus int
requestType string requestType string

View File

@ -1,8 +1,9 @@
package jwtclaims package jwtclaims
import ( import (
"github.com/golang-jwt/jwt"
"net/http" "net/http"
"github.com/golang-jwt/jwt"
) )
const ( const (
@ -14,51 +15,85 @@ const (
) )
// Extract function type // Extract function type
type ExtractClaims func(r *http.Request, authAudiance string) AuthorizationClaims type ExtractClaims func(r *http.Request) AuthorizationClaims
// ClaimsExtractor struct that holds the extract function // ClaimsExtractor struct that holds the extract function
type ClaimsExtractor struct { type ClaimsExtractor struct {
ExtractClaimsFromRequestContext ExtractClaims authAudience string
userIDClaim string
FromRequestContext ExtractClaims
}
// ClaimsExtractorOption is a function that configures the ClaimsExtractor
type ClaimsExtractorOption func(*ClaimsExtractor)
// WithAudience sets the audience for the extractor
func WithAudience(audience string) ClaimsExtractorOption {
return func(c *ClaimsExtractor) {
c.authAudience = audience
}
}
// WithUserIDClaim sets the user id claim for the extractor
func WithUserIDClaim(userIDClaim string) ClaimsExtractorOption {
return func(c *ClaimsExtractor) {
c.userIDClaim = userIDClaim
}
}
// WithFromRequestContext sets the function that extracts claims from the request context
func WithFromRequestContext(ec ExtractClaims) ClaimsExtractorOption {
return func(c *ClaimsExtractor) {
c.FromRequestContext = ec
}
} }
// NewClaimsExtractor returns an extractor, and if provided with a function with ExtractClaims signature, // NewClaimsExtractor returns an extractor, and if provided with a function with ExtractClaims signature,
// then it will use that logic. Uses ExtractClaimsFromRequestContext by default // then it will use that logic. Uses ExtractClaimsFromRequestContext by default
func NewClaimsExtractor(e ExtractClaims) *ClaimsExtractor { func NewClaimsExtractor(options ...ClaimsExtractorOption) *ClaimsExtractor {
var extractFunc ExtractClaims ce := &ClaimsExtractor{}
if extractFunc = e; extractFunc == nil { for _, option := range options {
extractFunc = ExtractClaimsFromRequestContext option(ce)
}
if ce.FromRequestContext == nil {
ce.FromRequestContext = ce.fromRequestContext
}
if ce.userIDClaim == "" {
ce.userIDClaim = UserIDClaim
}
return ce
} }
return &ClaimsExtractor{ // FromToken extracts claims from the token (after auth)
ExtractClaimsFromRequestContext: extractFunc, func (c *ClaimsExtractor) FromToken(token *jwt.Token) AuthorizationClaims {
}
}
// ExtractClaimsFromRequestContext extracts claims from the request context previously filled by the JWT token (after auth)
func ExtractClaimsFromRequestContext(r *http.Request, authAudience string) AuthorizationClaims {
if r.Context().Value(TokenUserProperty) == nil {
return AuthorizationClaims{}
}
token := r.Context().Value(TokenUserProperty).(*jwt.Token)
return ExtractClaimsWithToken(token, authAudience)
}
// ExtractClaimsWithToken extracts claims from the token (after auth)
func ExtractClaimsWithToken(token *jwt.Token, authAudience string) AuthorizationClaims {
claims := token.Claims.(jwt.MapClaims) claims := token.Claims.(jwt.MapClaims)
jwtClaims := AuthorizationClaims{} jwtClaims := AuthorizationClaims{}
jwtClaims.UserId = claims[UserIDClaim].(string) userID, ok := claims[c.userIDClaim].(string)
accountIdClaim, ok := claims[authAudience+AccountIDSuffix] if !ok {
if ok { return jwtClaims
jwtClaims.AccountId = accountIdClaim.(string)
} }
domainClaim, ok := claims[authAudience+DomainIDSuffix] jwtClaims.UserId = userID
accountIDClaim, ok := claims[c.authAudience+AccountIDSuffix]
if ok {
jwtClaims.AccountId = accountIDClaim.(string)
}
domainClaim, ok := claims[c.authAudience+DomainIDSuffix]
if ok { if ok {
jwtClaims.Domain = domainClaim.(string) jwtClaims.Domain = domainClaim.(string)
} }
domainCategoryClaim, ok := claims[authAudience+DomainCategorySuffix] domainCategoryClaim, ok := claims[c.authAudience+DomainCategorySuffix]
if ok { if ok {
jwtClaims.DomainCategory = domainCategoryClaim.(string) jwtClaims.DomainCategory = domainCategoryClaim.(string)
} }
return jwtClaims return jwtClaims
} }
// fromRequestContext extracts claims from the request context previously filled by the JWT token (after auth)
func (c *ClaimsExtractor) fromRequestContext(r *http.Request) AuthorizationClaims {
if r.Context().Value(TokenUserProperty) == nil {
return AuthorizationClaims{}
}
token := r.Context().Value(TokenUserProperty).(*jwt.Token)
return c.FromToken(token)
}

View File

@ -2,10 +2,11 @@ package jwtclaims
import ( import (
"context" "context"
"github.com/golang-jwt/jwt"
"github.com/stretchr/testify/require"
"net/http" "net/http"
"testing" "testing"
"github.com/golang-jwt/jwt"
"github.com/stretchr/testify/require"
) )
func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audiance string) *http.Request { func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audiance string) *http.Request {
@ -31,7 +32,6 @@ func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audiance st
} }
func TestExtractClaimsFromRequestContext(t *testing.T) { func TestExtractClaimsFromRequestContext(t *testing.T) {
type test struct { type test struct {
name string name string
inputAuthorizationClaims AuthorizationClaims inputAuthorizationClaims AuthorizationClaims
@ -99,12 +99,84 @@ func TestExtractClaimsFromRequestContext(t *testing.T) {
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5} { for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5} {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
request := newTestRequestWithJWT(t, testCase.inputAuthorizationClaims, testCase.inputAudiance) request := newTestRequestWithJWT(t, testCase.inputAuthorizationClaims, testCase.inputAudiance)
extractedClaims := ExtractClaimsFromRequestContext(request, testCase.inputAudiance) extractor := NewClaimsExtractor(WithAudience(testCase.inputAudiance))
extractedClaims := extractor.FromRequestContext(request)
testCase.testingFunc(t, testCase.inputAuthorizationClaims, extractedClaims, testCase.expectedMSG) testCase.testingFunc(t, testCase.inputAuthorizationClaims, extractedClaims, testCase.expectedMSG)
}) })
} }
} }
func TestExtractClaimsSetOptions(t *testing.T) {
type test struct {
name string
extractor *ClaimsExtractor
check func(t *testing.T, c test)
}
testCase1 := test{
name: "No custom options",
extractor: NewClaimsExtractor(),
check: func(t *testing.T, c test) {
if c.extractor.authAudience != "" {
t.Error("audience should be empty")
return
}
if c.extractor.userIDClaim != UserIDClaim {
t.Errorf("user id claim should be default, expected %s, got %s", UserIDClaim, c.extractor.userIDClaim)
return
}
if c.extractor.FromRequestContext == nil {
t.Error("from request context should not be nil")
return
}
},
}
testCase2 := test{
name: "Custom audience",
extractor: NewClaimsExtractor(WithAudience("https://login/")),
check: func(t *testing.T, c test) {
if c.extractor.authAudience != "https://login/" {
t.Errorf("audience expected %s, got %s", "https://login/", c.extractor.authAudience)
return
}
},
}
testCase3 := test{
name: "Custom user id claim",
extractor: NewClaimsExtractor(WithUserIDClaim("customUserId")),
check: func(t *testing.T, c test) {
if c.extractor.userIDClaim != "customUserId" {
t.Errorf("user id claim expected %s, got %s", "customUserId", c.extractor.userIDClaim)
return
}
},
}
testCase4 := test{
name: "Custom extractor from request context",
extractor: NewClaimsExtractor(
WithFromRequestContext(func(r *http.Request) AuthorizationClaims {
return AuthorizationClaims{
UserId: "testCustomRequest",
}
})),
check: func(t *testing.T, c test) {
claims := c.extractor.FromRequestContext(&http.Request{})
if claims.UserId != "testCustomRequest" {
t.Errorf("user id claim expected %s, got %s", "testCustomRequest", claims.UserId)
return
}
},
}
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4} {
t.Run(testCase.name, func(t *testing.T) {
testCase.check(t, testCase)
})
}
}