[management] refactor auth (#3296)

This commit is contained in:
Pedro Maia Costa 2025-02-20 20:24:40 +00:00 committed by GitHub
parent d7d5b1b1d6
commit 77e40f41f2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
64 changed files with 2085 additions and 1937 deletions

View File

@ -95,7 +95,7 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
}
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil, nil)
if err != nil {
t.Fatal(err)
}

View File

@ -1226,7 +1226,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
}
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil, nil)
if err != nil {
return nil, "", err
}

View File

@ -134,7 +134,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
}
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil, nil)
if err != nil {
return nil, "", err
}

2
go.mod
View File

@ -60,7 +60,7 @@ require (
github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20250115083837-a09722b8d2a6
github.com/netbirdio/management-integrations/integrations v0.0.0-20250220173202-e599d83524fc
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0

4
go.sum
View File

@ -529,8 +529,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250115083837-a09722b8d2a6 h1:I/ODkZ8rSDOzlJbhEjD2luSI71zl+s5JgNvFHY0+mBU=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250115083837-a09722b8d2a6/go.mod h1:izUUs1NT7ja+PwSX3kJ7ox8Kkn478tboBJSjL4kU6J0=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250220173202-e599d83524fc h1:18xvjOy2tZVIK7rihNpf9DF/3mAiljYKWaQlWa9vJgI=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250220173202-e599d83524fc/go.mod h1:izUUs1NT7ja+PwSX3kJ7ox8Kkn478tboBJSjL4kU6J0=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=

View File

@ -78,7 +78,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
}
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil, nil)
if err != nil {
t.Fatal(err)
}

View File

@ -39,13 +39,12 @@ import (
"github.com/netbirdio/netbird/formatter"
mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/auth"
nbContext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/groups"
nbhttp "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/metrics"
"github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
@ -255,24 +254,13 @@ var (
tlsEnabled = true
}
jwtValidator, err := jwtclaims.NewJWTValidator(
ctx,
authManager := auth.NewManager(store,
config.HttpConfig.AuthIssuer,
config.GetAuthAudiences(),
config.HttpConfig.AuthAudience,
config.HttpConfig.AuthKeysLocation,
config.HttpConfig.IdpSignKeyRefreshEnabled,
)
if err != nil {
return fmt.Errorf("failed creating JWT validator: %v", err)
}
httpAPIAuthCfg := configs.AuthCfg{
Issuer: config.HttpConfig.AuthIssuer,
Audience: config.HttpConfig.AuthAudience,
UserIDClaim: config.HttpConfig.AuthUserIDClaim,
KeysLocation: config.HttpConfig.AuthKeysLocation,
}
config.HttpConfig.AuthUserIDClaim,
config.GetAuthAudiences(),
config.HttpConfig.IdpSignKeyRefreshEnabled)
userManager := users.NewManager(store)
settingsManager := settings.NewManager(store)
permissionsManager := permissions.NewManager(userManager, settingsManager)
@ -281,7 +269,7 @@ var (
routersManager := routers.NewManager(store, permissionsManager, accountManager)
networksManager := networks.NewManager(store, permissionsManager, resourcesManager, routersManager, accountManager)
httpAPIHandler, err := nbhttp.NewAPIHandler(ctx, accountManager, networksManager, resourcesManager, routersManager, groupsManager, geo, jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator)
httpAPIHandler, err := nbhttp.NewAPIHandler(ctx, accountManager, networksManager, resourcesManager, routersManager, groupsManager, geo, authManager, appMetrics, config, integratedPeerValidator)
if err != nil {
return fmt.Errorf("failed creating HTTP API handler: %v", err)
}
@ -290,7 +278,7 @@ var (
ephemeralManager.LoadInitialPeers(ctx)
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
srv, err := server.NewServer(ctx, config, accountManager, settingsManager, peersUpdateManager, secretsManager, appMetrics, ephemeralManager)
srv, err := server.NewServer(ctx, config, accountManager, settingsManager, peersUpdateManager, secretsManager, appMetrics, ephemeralManager, authManager)
if err != nil {
return fmt.Errorf("failed creating gRPC API handler: %v", err)
}

View File

@ -2,11 +2,8 @@ package server
import (
"context"
"crypto/sha256"
b64 "encoding/base64"
"errors"
"fmt"
"hash/crc32"
"math/rand"
"net"
"net/netip"
@ -24,14 +21,13 @@ import (
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/base62"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/server/activity"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/integrated_validator"
"github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status"
@ -77,13 +73,10 @@ type AccountManager interface {
GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error)
AccountExists(ctx context.Context, accountID string) (bool, error)
GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error)
GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
GetPATInfo(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, string, string, error)
GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error)
DeleteAccount(ctx context.Context, accountID, userID string) error
MarkPATUsed(ctx context.Context, tokenID string) error
GetUserByID(ctx context.Context, id string) (*types.User, error)
GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error)
GetUserFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error)
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error
@ -150,6 +143,7 @@ type AccountManager interface {
DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error
UpdateAccountPeers(ctx context.Context, accountID string)
BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error
}
type DefaultAccountManager struct {
@ -954,11 +948,11 @@ func (am *DefaultAccountManager) removeUserFromCache(ctx context.Context, accoun
}
// updateAccountDomainAttributesIfNotUpToDate updates the account domain attributes if they are not up to date and then, saves the account changes
func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims,
func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx context.Context, accountID string, userAuth nbcontext.UserAuth,
primaryDomain bool,
) error {
if claims.Domain == "" {
log.WithContext(ctx).Errorf("claims don't contain a valid domain, skipping domain attributes update. Received claims: %v", claims)
if userAuth.Domain == "" {
log.WithContext(ctx).Errorf("claims don't contain a valid domain, skipping domain attributes update. Received claims: %v", userAuth)
return nil
}
@ -971,11 +965,11 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx
return err
}
if domainIsUpToDate(accountDomain, domainCategory, claims) {
if domainIsUpToDate(accountDomain, domainCategory, userAuth) {
return nil
}
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, claims.UserId)
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
if err != nil {
log.WithContext(ctx).Errorf("error getting user: %v", err)
return err
@ -984,13 +978,13 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx
newDomain := accountDomain
newCategoty := domainCategory
lowerDomain := strings.ToLower(claims.Domain)
lowerDomain := strings.ToLower(userAuth.Domain)
if accountDomain != lowerDomain && user.HasAdminPower() {
newDomain = lowerDomain
}
if accountDomain == lowerDomain {
newCategoty = claims.DomainCategory
newCategoty = userAuth.DomainCategory
}
return am.Store.UpdateAccountDomainAttributes(ctx, accountID, newDomain, newCategoty, primaryDomain)
@ -1006,16 +1000,16 @@ func (am *DefaultAccountManager) handleExistingUserAccount(
ctx context.Context,
userAccountID string,
domainAccountID string,
claims jwtclaims.AuthorizationClaims,
userAuth nbcontext.UserAuth,
) error {
primaryDomain := domainAccountID == "" || userAccountID == domainAccountID
err := am.updateAccountDomainAttributesIfNotUpToDate(ctx, userAccountID, claims, primaryDomain)
err := am.updateAccountDomainAttributesIfNotUpToDate(ctx, userAccountID, userAuth, primaryDomain)
if err != nil {
return err
}
// we should register the account ID to this user's metadata in our IDP manager
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, userAccountID)
err = am.addAccountIDToIDPAppMeta(ctx, userAuth.UserId, userAccountID)
if err != nil {
return err
}
@ -1025,20 +1019,20 @@ func (am *DefaultAccountManager) handleExistingUserAccount(
// addNewPrivateAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account,
// otherwise it will create a new account and make it primary account for the domain.
func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) {
if claims.UserId == "" {
func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domainAccountID string, userAuth nbcontext.UserAuth) (string, error) {
if userAuth.UserId == "" {
return "", fmt.Errorf("user ID is empty")
}
lowerDomain := strings.ToLower(claims.Domain)
lowerDomain := strings.ToLower(userAuth.Domain)
newAccount, err := am.newAccount(ctx, claims.UserId, lowerDomain)
newAccount, err := am.newAccount(ctx, userAuth.UserId, lowerDomain)
if err != nil {
return "", err
}
newAccount.Domain = lowerDomain
newAccount.DomainCategory = claims.DomainCategory
newAccount.DomainCategory = userAuth.DomainCategory
newAccount.IsDomainPrimaryAccount = true
err = am.Store.SaveAccount(ctx, newAccount)
@ -1046,33 +1040,33 @@ func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domai
return "", err
}
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, newAccount.Id)
err = am.addAccountIDToIDPAppMeta(ctx, userAuth.UserId, newAccount.Id)
if err != nil {
return "", err
}
am.StoreEvent(ctx, claims.UserId, claims.UserId, newAccount.Id, activity.UserJoined, nil)
am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, newAccount.Id, activity.UserJoined, nil)
return newAccount.Id, nil
}
func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) {
func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, userAuth nbcontext.UserAuth) (string, error) {
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
defer unlockAccount()
newUser := types.NewRegularUser(claims.UserId)
newUser := types.NewRegularUser(userAuth.UserId)
newUser.AccountID = domainAccountID
err := am.Store.SaveUser(ctx, store.LockingStrengthUpdate, newUser)
if err != nil {
return "", err
}
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, domainAccountID)
err = am.addAccountIDToIDPAppMeta(ctx, userAuth.UserId, domainAccountID)
if err != nil {
return "", err
}
am.StoreEvent(ctx, claims.UserId, claims.UserId, domainAccountID, activity.UserJoined, nil)
am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, domainAccountID, activity.UserJoined, nil)
return domainAccountID, nil
}
@ -1112,76 +1106,11 @@ func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID str
return nil
}
// MarkPATUsed marks a personal access token as used
func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string) error {
return am.Store.MarkPATUsed(ctx, store.LockingStrengthUpdate, tokenID)
}
// GetAccount returns an account associated with this account ID.
func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID string) (*types.Account, error) {
return am.Store.GetAccount(ctx, accountID)
}
// GetPATInfo retrieves user, personal access token, domain, and category details from a personal access token.
func (am *DefaultAccountManager) GetPATInfo(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) {
user, pat, err = am.extractPATFromToken(ctx, token)
if err != nil {
return nil, nil, "", "", err
}
domain, category, err = am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, user.AccountID)
if err != nil {
return nil, nil, "", "", err
}
return user, pat, domain, category, nil
}
// extractPATFromToken validates the token structure and retrieves associated User and PAT.
func (am *DefaultAccountManager) extractPATFromToken(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, error) {
if len(token) != types.PATLength {
return nil, nil, fmt.Errorf("token has incorrect length")
}
prefix := token[:len(types.PATPrefix)]
if prefix != types.PATPrefix {
return nil, nil, fmt.Errorf("token has wrong prefix")
}
secret := token[len(types.PATPrefix) : len(types.PATPrefix)+types.PATSecretLength]
encodedChecksum := token[len(types.PATPrefix)+types.PATSecretLength : len(types.PATPrefix)+types.PATSecretLength+types.PATChecksumLength]
verificationChecksum, err := base62.Decode(encodedChecksum)
if err != nil {
return nil, nil, fmt.Errorf("token checksum decoding failed: %w", err)
}
secretChecksum := crc32.ChecksumIEEE([]byte(secret))
if secretChecksum != verificationChecksum {
return nil, nil, fmt.Errorf("token checksum does not match")
}
hashedToken := sha256.Sum256([]byte(token))
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
var user *types.User
var pat *types.PersonalAccessToken
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthShare, encodedHashedToken)
if err != nil {
return err
}
user, err = transaction.GetUserByPATID(ctx, store.LockingStrengthShare, pat.ID)
return err
})
if err != nil {
return nil, nil, err
}
return user, pat, nil
}
// GetAccountByID returns an account associated with this account ID.
func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error) {
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
@ -1196,58 +1125,56 @@ func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID s
return am.Store.GetAccount(ctx, accountID)
}
// GetAccountIDFromToken returns an account ID associated with this token.
func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
if claims.UserId == "" {
func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
if userAuth.UserId == "" {
return "", "", errors.New(emptyUserID)
}
if am.singleAccountMode && am.singleAccountModeDomain != "" {
// This section is mostly related to self-hosted installations.
// We override incoming domain claims to group users under a single account.
claims.Domain = am.singleAccountModeDomain
claims.DomainCategory = types.PrivateCategory
userAuth.Domain = am.singleAccountModeDomain
userAuth.DomainCategory = types.PrivateCategory
log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
}
accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, claims)
accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, userAuth)
if err != nil {
return "", "", err
}
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, claims.UserId)
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
if err != nil {
// this is not really possible because we got an account by user ID
return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId)
return "", "", status.Errorf(status.NotFound, "user %s not found", userAuth.UserId)
}
if userAuth.IsChild {
return accountID, user.Id, nil
}
if user.AccountID != accountID {
return "", "", status.Errorf(status.PermissionDenied, "user %s is not part of the account %s", claims.UserId, accountID)
return "", "", status.Errorf(status.PermissionDenied, "user %s is not part of the account %s", userAuth.UserId, accountID)
}
if !user.IsServiceUser && claims.Invited {
if !user.IsServiceUser && userAuth.Invited {
err = am.redeemInvite(ctx, accountID, user.Id)
if err != nil {
return "", "", err
}
}
if err = am.syncJWTGroups(ctx, accountID, claims); err != nil {
return "", "", err
}
return accountID, user.Id, nil
}
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
// and propagates changes to peers if group propagation is enabled.
func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims) error {
if claim, exists := claims.Raw[jwtclaims.IsToken]; exists {
if isToken, ok := claim.(bool); ok && isToken {
return nil
}
// requires userAuth to have been ValidateAndParseToken and EnsureUserAccessByJWTGroups by the AuthManager
func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error {
if userAuth.IsChild || userAuth.IsPAT {
return nil
}
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, userAuth.AccountId)
if err != nil {
return err
}
@ -1261,9 +1188,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
return nil
}
jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID)
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, userAuth.AccountId)
defer func() {
if unlockAccount != nil {
unlockAccount()
@ -1275,17 +1200,17 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
var hasChanges bool
var user *types.User
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
user, err = transaction.GetUserByUserID(ctx, store.LockingStrengthShare, claims.UserId)
user, err = transaction.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
if err != nil {
return fmt.Errorf("error getting user: %w", err)
}
groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, userAuth.AccountId)
if err != nil {
return fmt.Errorf("error getting account groups: %w", err)
}
changed, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(user, groups, jwtGroupsNames)
changed, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(user, groups, userAuth.Groups)
if err != nil {
return fmt.Errorf("error getting JWT groups changes: %w", err)
}
@ -1310,7 +1235,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
// Propagate changes to peers if group propagation is enabled
if settings.GroupsPropagationEnabled {
groups, err = transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
groups, err = transaction.GetAccountGroups(ctx, store.LockingStrengthShare, userAuth.AccountId)
if err != nil {
return fmt.Errorf("error getting account groups: %w", err)
}
@ -1320,7 +1245,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
groupsMap[group.ID] = group
}
peers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, accountID, claims.UserId)
peers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, userAuth.AccountId, userAuth.UserId)
if err != nil {
return fmt.Errorf("error getting user peers: %w", err)
}
@ -1334,7 +1259,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
return fmt.Errorf("error saving groups: %w", err)
}
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, userAuth.AccountId); err != nil {
return fmt.Errorf("error incrementing network serial: %w", err)
}
}
@ -1352,45 +1277,45 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
}
for _, g := range addNewGroups {
group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, accountID, g)
group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, userAuth.AccountId, g)
if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, userAuth.AccountId)
} else {
meta := map[string]any{
"group": group.Name, "group_id": group.ID,
"is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName,
}
am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupAddedToUser, meta)
am.StoreEvent(ctx, user.Id, user.Id, userAuth.AccountId, activity.GroupAddedToUser, meta)
}
}
for _, g := range removeOldGroups {
group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, accountID, g)
group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, userAuth.AccountId, g)
if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, userAuth.AccountId)
} else {
meta := map[string]any{
"group": group.Name, "group_id": group.ID,
"is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName,
}
am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupRemovedFromUser, meta)
am.StoreEvent(ctx, user.Id, user.Id, userAuth.AccountId, activity.GroupRemovedFromUser, meta)
}
}
if settings.GroupsPropagationEnabled {
removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, removeOldGroups)
removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, removeOldGroups)
if err != nil {
return err
}
newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, addNewGroups)
newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, addNewGroups)
if err != nil {
return err
}
if removedGroupAffectsPeers || newGroupsAffectsPeers {
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
am.UpdateAccountPeers(ctx, accountID)
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId)
am.UpdateAccountPeers(ctx, userAuth.AccountId)
}
}
@ -1415,24 +1340,34 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
// Existing user + Existing account + Existing Indexed Domain -> Nothing changes
//
// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain)
func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) {
//
// UserAuth IsChild -> checks that account exists
func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) {
log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"",
claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory)
userAuth.UserId, userAuth.AccountId, userAuth.Domain, userAuth.DomainCategory)
if claims.UserId == "" {
if userAuth.UserId == "" {
return "", errors.New(emptyUserID)
}
if claims.DomainCategory != types.PrivateCategory || !isDomainValid(claims.Domain) {
return am.GetAccountIDByUserID(ctx, claims.UserId, claims.Domain)
if userAuth.IsChild {
exists, err := am.Store.AccountExists(ctx, store.LockingStrengthShare, userAuth.AccountId)
if err != nil || !exists {
return "", err
}
return userAuth.AccountId, nil
}
if claims.AccountId != "" {
return am.handlePrivateAccountWithIDFromClaim(ctx, claims)
if userAuth.DomainCategory != types.PrivateCategory || !isDomainValid(userAuth.Domain) {
return am.GetAccountIDByUserID(ctx, userAuth.UserId, userAuth.Domain)
}
if userAuth.AccountId != "" {
return am.handlePrivateAccountWithIDFromClaim(ctx, userAuth)
}
// We checked if the domain has a primary account already
domainAccountID, cancel, err := am.getPrivateDomainWithGlobalLock(ctx, claims.Domain)
domainAccountID, cancel, err := am.getPrivateDomainWithGlobalLock(ctx, userAuth.Domain)
if cancel != nil {
defer cancel()
}
@ -1440,14 +1375,14 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
return "", err
}
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, claims.UserId)
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
return "", err
}
if userAccountID != "" {
if err = am.handleExistingUserAccount(ctx, userAccountID, domainAccountID, claims); err != nil {
if err = am.handleExistingUserAccount(ctx, userAccountID, domainAccountID, userAuth); err != nil {
return "", err
}
@ -1455,10 +1390,10 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
}
if domainAccountID != "" {
return am.addNewUserToDomainAccount(ctx, domainAccountID, claims)
return am.addNewUserToDomainAccount(ctx, domainAccountID, userAuth)
}
return am.addNewPrivateAccount(ctx, domainAccountID, claims)
return am.addNewPrivateAccount(ctx, domainAccountID, userAuth)
}
func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Context, domain string) (string, context.CancelFunc, error) {
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain)
@ -1486,40 +1421,40 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont
return domainAccountID, cancel, nil
}
func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) {
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, claims.UserId)
func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) {
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
if err != nil {
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
return "", err
}
if userAccountID != claims.AccountId {
return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
if userAccountID != userAuth.AccountId {
return "", fmt.Errorf("user %s is not part of the account id %s", userAuth.UserId, userAuth.AccountId)
}
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, claims.AccountId)
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, userAuth.AccountId)
if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf("error getting account domain and category: %v", err)
return "", err
}
if domainIsUpToDate(accountDomain, domainCategory, claims) {
return claims.AccountId, nil
if domainIsUpToDate(accountDomain, domainCategory, userAuth) {
return userAuth.AccountId, nil
}
// We checked if the domain has a primary account already
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, claims.Domain)
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, userAuth.Domain)
if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
return "", err
}
err = am.handleExistingUserAccount(ctx, claims.AccountId, domainAccountID, claims)
err = am.handleExistingUserAccount(ctx, userAuth.AccountId, domainAccountID, userAuth)
if err != nil {
return "", err
}
return claims.AccountId, nil
return userAuth.AccountId, nil
}
func handleNotFound(err error) error {
@ -1534,8 +1469,8 @@ func handleNotFound(err error) error {
return nil
}
func domainIsUpToDate(domain string, domainCategory string, claims jwtclaims.AuthorizationClaims) bool {
return domainCategory == types.PrivateCategory || claims.DomainCategory != types.PrivateCategory || domain != claims.Domain
func domainIsUpToDate(domain string, domainCategory string, userAuth nbcontext.UserAuth) bool {
return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain
}
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
@ -1617,34 +1552,6 @@ func (am *DefaultAccountManager) GetDNSDomain() string {
return am.dnsDomain
}
// CheckUserAccessByJWTGroups checks if the user has access, particularly in cases where the admin enabled JWT
// group propagation and set the list of groups with access permissions.
func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error {
accountID, _, err := am.GetAccountIDFromToken(ctx, claims)
if err != nil {
return err
}
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return err
}
// Ensures JWT group synchronization to the management is enabled before,
// filtering access based on the allowed groups.
if settings != nil && settings.JWTGroupsEnabled {
if allowedGroups := settings.JWTAllowGroups; len(allowedGroups) > 0 {
userJWTGroups := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
if !userHasAllowedGroup(allowedGroups, userJWTGroups) {
return fmt.Errorf("user does not belong to any of the allowed JWT groups")
}
}
}
return nil
}
func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string) {
log.WithContext(ctx).Debugf("validated peers has been invalidated for account %s", accountID)
am.UpdateAccountPeers(ctx, accountID)
@ -1802,39 +1709,6 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *ty
return acc
}
// extractJWTGroups extracts the group names from a JWT token's claims.
func extractJWTGroups(ctx context.Context, claimName string, claims jwtclaims.AuthorizationClaims) []string {
userJWTGroups := make([]string, 0)
if claim, ok := claims.Raw[claimName]; ok {
if claimGroups, ok := claim.([]interface{}); ok {
for _, g := range claimGroups {
if group, ok := g.(string); ok {
userJWTGroups = append(userJWTGroups, group)
} else {
log.WithContext(ctx).Debugf("JWT claim %q contains a non-string group (type: %T): %v", claimName, g, g)
}
}
}
} else {
log.WithContext(ctx).Debugf("JWT claim %q is not a string array", claimName)
}
return userJWTGroups
}
// userHasAllowedGroup checks if a user belongs to any of the allowed groups.
func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool {
for _, userGroup := range userGroups {
for _, allowedGroup := range allowedGroups {
if userGroup == allowedGroup {
return true
}
}
}
return false
}
// separateGroups separates user's auto groups into non-JWT and JWT groups.
// Returns the list of standard auto groups and a map of JWT auto groups,
// where the keys are the group names and the values are the group IDs.

View File

@ -2,8 +2,6 @@ package server
import (
"context"
"crypto/sha256"
b64 "encoding/base64"
"encoding/json"
"fmt"
"io"
@ -15,8 +13,6 @@ import (
"testing"
"time"
"github.com/golang-jwt/jwt"
"github.com/netbirdio/netbird/management/server/util"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
@ -30,7 +26,7 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/jwtclaims"
nbcontext "github.com/netbirdio/netbird/management/server/context"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store"
@ -437,7 +433,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
}
func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
type initUserParams jwtclaims.AuthorizationClaims
type initUserParams nbcontext.UserAuth
var (
publicDomain = "public.com"
@ -460,7 +456,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
testCases := []struct {
name string
inputClaims jwtclaims.AuthorizationClaims
inputClaims nbcontext.UserAuth
inputInitUserParams initUserParams
inputUpdateAttrs bool
inputUpdateClaimAccount bool
@ -475,7 +471,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
}{
{
name: "New User With Public Domain",
inputClaims: jwtclaims.AuthorizationClaims{
inputClaims: nbcontext.UserAuth{
Domain: publicDomain,
UserId: "pub-domain-user",
DomainCategory: types.PublicCategory,
@ -492,7 +488,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
},
{
name: "New User With Unknown Domain",
inputClaims: jwtclaims.AuthorizationClaims{
inputClaims: nbcontext.UserAuth{
Domain: unknownDomain,
UserId: "unknown-domain-user",
DomainCategory: types.UnknownCategory,
@ -509,7 +505,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
},
{
name: "New User With Private Domain",
inputClaims: jwtclaims.AuthorizationClaims{
inputClaims: nbcontext.UserAuth{
Domain: privateDomain,
UserId: "pvt-domain-user",
DomainCategory: types.PrivateCategory,
@ -526,7 +522,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
},
{
name: "New Regular User With Existing Private Domain",
inputClaims: jwtclaims.AuthorizationClaims{
inputClaims: nbcontext.UserAuth{
Domain: privateDomain,
UserId: "new-pvt-domain-user",
DomainCategory: types.PrivateCategory,
@ -544,7 +540,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
},
{
name: "Existing User With Existing Reclassified Private Domain",
inputClaims: jwtclaims.AuthorizationClaims{
inputClaims: nbcontext.UserAuth{
Domain: defaultInitAccount.Domain,
UserId: defaultInitAccount.UserId,
DomainCategory: types.PrivateCategory,
@ -561,7 +557,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
},
{
name: "Existing Account Id With Existing Reclassified Private Domain",
inputClaims: jwtclaims.AuthorizationClaims{
inputClaims: nbcontext.UserAuth{
Domain: defaultInitAccount.Domain,
UserId: defaultInitAccount.UserId,
DomainCategory: types.PrivateCategory,
@ -579,7 +575,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
},
{
name: "User With Private Category And Empty Domain",
inputClaims: jwtclaims.AuthorizationClaims{
inputClaims: nbcontext.UserAuth{
Domain: "",
UserId: "pvt-domain-user",
DomainCategory: types.PrivateCategory,
@ -608,7 +604,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
require.NoError(t, err, "get init account failed")
if testCase.inputUpdateAttrs {
err = manager.updateAccountDomainAttributesIfNotUpToDate(context.Background(), initAccount.Id, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
err = manager.updateAccountDomainAttributesIfNotUpToDate(context.Background(), initAccount.Id, nbcontext.UserAuth{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
require.NoError(t, err, "update init user failed")
}
@ -616,7 +612,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
testCase.inputClaims.AccountId = initAccount.Id
}
accountID, _, err = manager.GetAccountIDFromToken(context.Background(), testCase.inputClaims)
accountID, _, err = manager.GetAccountIDFromUserAuth(context.Background(), testCase.inputClaims)
require.NoError(t, err, "support function failed")
account, err := manager.Store.GetAccount(context.Background(), accountID)
@ -635,14 +631,12 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
}
}
func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) {
userId := "user-id"
domain := "test.domain"
_ = newAccountWithId(context.Background(), "", userId, domain)
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain)
require.NoError(t, err, "create init user failed")
// as initAccount was created without account id we have to take the id after account initialization
@ -650,65 +644,50 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
// it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it
initAccount, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get init account failed")
claims := jwtclaims.AuthorizationClaims{
claims := nbcontext.UserAuth{
AccountId: accountID, // is empty as it is based on accountID right after initialization of initAccount
Domain: domain,
UserId: userId,
DomainCategory: "test-category",
Raw: jwt.MapClaims{"idp-groups": []interface{}{"group1", "group2"}},
Groups: []string{"group1", "group2"},
}
t.Run("JWT groups disabled", func(t *testing.T) {
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed")
err := manager.SyncUserJWTGroups(context.Background(), claims)
require.NoError(t, err, "synt user jwt groups failed")
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get account failed")
require.Len(t, account.Groups, 1, "only ALL group should exists")
})
t.Run("JWT groups enabled without claim name", func(t *testing.T) {
initAccount.Settings.JWTGroupsEnabled = true
err := manager.Store.SaveAccount(context.Background(), initAccount)
require.NoError(t, err, "save account failed")
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed")
err = manager.SyncUserJWTGroups(context.Background(), claims)
require.NoError(t, err, "synt user jwt groups failed")
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get account failed")
require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT")
})
t.Run("JWT groups enabled", func(t *testing.T) {
initAccount.Settings.JWTGroupsEnabled = true
initAccount.Settings.JWTGroupsClaimName = "idp-groups"
err := manager.Store.SaveAccount(context.Background(), initAccount)
require.NoError(t, err, "save account failed")
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed")
err = manager.SyncUserJWTGroups(context.Background(), claims)
require.NoError(t, err, "synt user jwt groups failed")
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get account failed")
require.Len(t, account.Groups, 3, "groups should be added to the account")
groupsByNames := map[string]*types.Group{}
for _, g := range account.Groups {
groupsByNames[g.Name] = g
}
g1, ok := groupsByNames["group1"]
require.True(t, ok, "group1 should be added to the account")
require.Equal(t, g1.Name, "group1", "group1 name should match")
require.Equal(t, g1.Issued, types.GroupIssuedJWT, "group1 issued should match")
g2, ok := groupsByNames["group2"]
require.True(t, ok, "group2 should be added to the account")
require.Equal(t, g2.Name, "group2", "group2 name should match")
@ -716,88 +695,6 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
})
}
func TestAccountManager_GetAccountFromPAT(t *testing.T) {
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
if err != nil {
t.Fatalf("Error when creating store: %s", err)
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
hashedToken := sha256.Sum256([]byte(token))
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
account.Users["someUser"] = &types.User{
Id: "someUser",
PATs: map[string]*types.PersonalAccessToken{
"tokenId": {
ID: "tokenId",
UserID: "someUser",
HashedToken: encodedHashedToken,
},
},
}
err = store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{
Store: store,
}
user, pat, _, _, err := am.GetPATInfo(context.Background(), token)
if err != nil {
t.Fatalf("Error when getting Account from PAT: %s", err)
}
assert.Equal(t, "account_id", user.AccountID)
assert.Equal(t, "someUser", user.Id)
assert.Equal(t, account.Users["someUser"].PATs["tokenId"].ID, pat.ID)
}
func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
if err != nil {
t.Fatalf("Error when creating store: %s", err)
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
hashedToken := sha256.Sum256([]byte(token))
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
account.Users["someUser"] = &types.User{
Id: "someUser",
PATs: map[string]*types.PersonalAccessToken{
"tokenId": {
ID: "tokenId",
HashedToken: encodedHashedToken,
},
},
}
err = store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{
Store: store,
}
err = am.MarkPATUsed(context.Background(), "tokenId")
if err != nil {
t.Fatalf("Error when marking PAT used: %s", err)
}
account, err = am.Store.GetAccount(context.Background(), "account_id")
if err != nil {
t.Fatalf("Error when getting account: %s", err)
}
assert.True(t, !account.Users["someUser"].PATs["tokenId"].GetLastUsed().IsZero())
}
func TestAccountManager_PrivateAccount(t *testing.T) {
manager, err := createManager(t)
if err != nil {
@ -962,13 +859,13 @@ func TestAccountManager_DeleteAccount(t *testing.T) {
}
func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
claims := jwtclaims.AuthorizationClaims{
claims := nbcontext.UserAuth{
Domain: "example.com",
UserId: "pvt-domain-user",
DomainCategory: types.PrivateCategory,
}
publicClaims := jwtclaims.AuthorizationClaims{
publicClaims := nbcontext.UserAuth{
Domain: "test.com",
UserId: "public-domain-user",
DomainCategory: types.PublicCategory,
@ -2683,11 +2580,13 @@ func TestAccount_SetJWTGroups(t *testing.T) {
assert.NoError(t, manager.Store.SaveAccount(context.Background(), account), "unable to save account")
t.Run("skip sync for token auth type", func(t *testing.T) {
claims := jwtclaims.AuthorizationClaims{
UserId: "user1",
Raw: jwt.MapClaims{"groups": []interface{}{"group3"}, "is_token": true},
claims := nbcontext.UserAuth{
UserId: "user1",
AccountId: "accountID",
Groups: []string{"group3"},
IsPAT: true,
}
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
@ -2696,11 +2595,12 @@ func TestAccount_SetJWTGroups(t *testing.T) {
})
t.Run("empty jwt groups", func(t *testing.T) {
claims := jwtclaims.AuthorizationClaims{
UserId: "user1",
Raw: jwt.MapClaims{"groups": []interface{}{}},
claims := nbcontext.UserAuth{
UserId: "user1",
AccountId: "accountID",
Groups: []string{},
}
err := manager.syncJWTGroups(context.Background(), "accountID", claims)
err := manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
@ -2709,11 +2609,12 @@ func TestAccount_SetJWTGroups(t *testing.T) {
})
t.Run("jwt match existing api group", func(t *testing.T) {
claims := jwtclaims.AuthorizationClaims{
UserId: "user1",
Raw: jwt.MapClaims{"groups": []interface{}{"group1"}},
claims := nbcontext.UserAuth{
UserId: "user1",
AccountId: "accountID",
Groups: []string{"group1"},
}
err := manager.syncJWTGroups(context.Background(), "accountID", claims)
err := manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
@ -2729,11 +2630,12 @@ func TestAccount_SetJWTGroups(t *testing.T) {
account.Users["user1"].AutoGroups = []string{"group1"}
assert.NoError(t, manager.Store.SaveUser(context.Background(), store.LockingStrengthUpdate, account.Users["user1"]))
claims := jwtclaims.AuthorizationClaims{
UserId: "user1",
Raw: jwt.MapClaims{"groups": []interface{}{"group1"}},
claims := nbcontext.UserAuth{
UserId: "user1",
AccountId: "accountID",
Groups: []string{"group1"},
}
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
@ -2746,11 +2648,12 @@ func TestAccount_SetJWTGroups(t *testing.T) {
})
t.Run("add jwt group", func(t *testing.T) {
claims := jwtclaims.AuthorizationClaims{
UserId: "user1",
Raw: jwt.MapClaims{"groups": []interface{}{"group1", "group2"}},
claims := nbcontext.UserAuth{
UserId: "user1",
AccountId: "accountID",
Groups: []string{"group1", "group2"},
}
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
@ -2759,11 +2662,12 @@ func TestAccount_SetJWTGroups(t *testing.T) {
})
t.Run("existed group not update", func(t *testing.T) {
claims := jwtclaims.AuthorizationClaims{
UserId: "user1",
Raw: jwt.MapClaims{"groups": []interface{}{"group2"}},
claims := nbcontext.UserAuth{
UserId: "user1",
AccountId: "accountID",
Groups: []string{"group2"},
}
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
@ -2772,11 +2676,12 @@ func TestAccount_SetJWTGroups(t *testing.T) {
})
t.Run("add new group", func(t *testing.T) {
claims := jwtclaims.AuthorizationClaims{
UserId: "user2",
Raw: jwt.MapClaims{"groups": []interface{}{"group1", "group3"}},
claims := nbcontext.UserAuth{
UserId: "user2",
AccountId: "accountID",
Groups: []string{"group1", "group3"},
}
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups")
groups, err := manager.Store.GetAccountGroups(context.Background(), store.LockingStrengthShare, "accountID")
@ -2789,11 +2694,12 @@ func TestAccount_SetJWTGroups(t *testing.T) {
})
t.Run("remove all JWT groups when list is empty", func(t *testing.T) {
claims := jwtclaims.AuthorizationClaims{
UserId: "user1",
Raw: jwt.MapClaims{"groups": []interface{}{}},
claims := nbcontext.UserAuth{
UserId: "user1",
AccountId: "accountID",
Groups: []string{},
}
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
@ -2803,11 +2709,12 @@ func TestAccount_SetJWTGroups(t *testing.T) {
})
t.Run("remove all JWT groups when claim does not exist", func(t *testing.T) {
claims := jwtclaims.AuthorizationClaims{
UserId: "user2",
Raw: jwt.MapClaims{},
claims := nbcontext.UserAuth{
UserId: "user2",
AccountId: "accountID",
Groups: []string{},
}
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2")

View File

@ -1,15 +1,17 @@
package jwtclaims
package jwt
import (
"net/http"
"errors"
"net/url"
"time"
"github.com/golang-jwt/jwt"
log "github.com/sirupsen/logrus"
nbcontext "github.com/netbirdio/netbird/management/server/context"
)
const (
// TokenUserProperty key for the user property in the request context
TokenUserProperty = "user"
// AccountIDSuffix suffix for the account id claim
AccountIDSuffix = "wt_account_id"
// DomainIDSuffix suffix for the domain id claim
@ -22,19 +24,16 @@ const (
LastLoginSuffix = "nb_last_login"
// Invited claim indicates that an incoming JWT is from a user that just accepted an invitation
Invited = "nb_invited"
// IsToken claim indicates that auth type from the user is a token
IsToken = "is_token"
)
// ExtractClaims Extract function type
type ExtractClaims func(r *http.Request) AuthorizationClaims
var (
errUserIDClaimEmpty = errors.New("user ID claim token value is empty")
)
// ClaimsExtractor struct that holds the extract function
type ClaimsExtractor struct {
authAudience string
userIDClaim string
FromRequestContext ExtractClaims
}
// ClaimsExtractorOption is a function that configures the ClaimsExtractor
@ -54,13 +53,6 @@ func WithUserIDClaim(userIDClaim string) ClaimsExtractorOption {
}
}
// WithFromRequestContext sets the function that extracts claims from the request context
func WithFromRequestContext(ec ExtractClaims) ClaimsExtractorOption {
return func(c *ClaimsExtractor) {
c.FromRequestContext = ec
}
}
// NewClaimsExtractor returns an extractor, and if provided with a function with ExtractClaims signature,
// then it will use that logic. Uses ExtractClaimsFromRequestContext by default
func NewClaimsExtractor(options ...ClaimsExtractorOption) *ClaimsExtractor {
@ -68,49 +60,13 @@ func NewClaimsExtractor(options ...ClaimsExtractorOption) *ClaimsExtractor {
for _, option := range options {
option(ce)
}
if ce.FromRequestContext == nil {
ce.FromRequestContext = ce.fromRequestContext
}
if ce.userIDClaim == "" {
ce.userIDClaim = UserIDClaim
}
return ce
}
// FromToken extracts claims from the token (after auth)
func (c *ClaimsExtractor) FromToken(token *jwt.Token) AuthorizationClaims {
claims := token.Claims.(jwt.MapClaims)
jwtClaims := AuthorizationClaims{
Raw: claims,
}
userID, ok := claims[c.userIDClaim].(string)
if !ok {
return jwtClaims
}
jwtClaims.UserId = userID
accountIDClaim, ok := claims[c.authAudience+AccountIDSuffix]
if ok {
jwtClaims.AccountId = accountIDClaim.(string)
}
domainClaim, ok := claims[c.authAudience+DomainIDSuffix]
if ok {
jwtClaims.Domain = domainClaim.(string)
}
domainCategoryClaim, ok := claims[c.authAudience+DomainCategorySuffix]
if ok {
jwtClaims.DomainCategory = domainCategoryClaim.(string)
}
LastLoginClaimString, ok := claims[c.authAudience+LastLoginSuffix]
if ok {
jwtClaims.LastLogin = parseTime(LastLoginClaimString.(string))
}
invitedBool, ok := claims[c.authAudience+Invited]
if ok {
jwtClaims.Invited = invitedBool.(bool)
}
return jwtClaims
}
func parseTime(timeString string) time.Time {
if timeString == "" {
return time.Time{}
@ -122,11 +78,67 @@ func parseTime(timeString string) time.Time {
return parsedTime
}
// fromRequestContext extracts claims from the request context previously filled by the JWT token (after auth)
func (c *ClaimsExtractor) fromRequestContext(r *http.Request) AuthorizationClaims {
if r.Context().Value(TokenUserProperty) == nil {
return AuthorizationClaims{}
func (c ClaimsExtractor) audienceClaim(claimName string) string {
url, err := url.JoinPath(c.authAudience, claimName)
if err != nil {
return c.authAudience + claimName // as it was previously
}
token := r.Context().Value(TokenUserProperty).(*jwt.Token)
return c.FromToken(token)
return url
}
func (c *ClaimsExtractor) ToUserAuth(token *jwt.Token) (nbcontext.UserAuth, error) {
claims := token.Claims.(jwt.MapClaims)
userAuth := nbcontext.UserAuth{}
userID, ok := claims[c.userIDClaim].(string)
if !ok {
return userAuth, errUserIDClaimEmpty
}
userAuth.UserId = userID
if accountIDClaim, ok := claims[c.audienceClaim(AccountIDSuffix)]; ok {
userAuth.AccountId = accountIDClaim.(string)
}
if domainClaim, ok := claims[c.audienceClaim(DomainIDSuffix)]; ok {
userAuth.Domain = domainClaim.(string)
}
if domainCategoryClaim, ok := claims[c.audienceClaim(DomainCategorySuffix)]; ok {
userAuth.DomainCategory = domainCategoryClaim.(string)
}
if lastLoginClaimString, ok := claims[c.audienceClaim(LastLoginSuffix)]; ok {
userAuth.LastLogin = parseTime(lastLoginClaimString.(string))
}
if invitedBool, ok := claims[c.audienceClaim(Invited)]; ok {
if value, ok := invitedBool.(bool); ok {
userAuth.Invited = value
}
}
return userAuth, nil
}
func (c *ClaimsExtractor) ToGroups(token *jwt.Token, claimName string) []string {
claims := token.Claims.(jwt.MapClaims)
userJWTGroups := make([]string, 0)
if claim, ok := claims[claimName]; ok {
if claimGroups, ok := claim.([]interface{}); ok {
for _, g := range claimGroups {
if group, ok := g.(string); ok {
userJWTGroups = append(userJWTGroups, group)
} else {
log.Debugf("JWT claim %q contains a non-string group (type: %T): %v", claimName, g, g)
}
}
}
} else {
log.Debugf("JWT claim %q is not a string array", claimName)
}
return userJWTGroups
}

View File

@ -0,0 +1,302 @@
package jwt
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"math/big"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"time"
"github.com/golang-jwt/jwt"
log "github.com/sirupsen/logrus"
)
// Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation
type Jwks struct {
Keys []JSONWebKey `json:"keys"`
expiresInTime time.Time
}
// The supported elliptic curves types
const (
// p256 represents a cryptographic elliptical curve type.
p256 = "P-256"
// p384 represents a cryptographic elliptical curve type.
p384 = "P-384"
// p521 represents a cryptographic elliptical curve type.
p521 = "P-521"
)
// JSONWebKey is a representation of a Jason Web Key
type JSONWebKey struct {
Kty string `json:"kty"`
Kid string `json:"kid"`
Use string `json:"use"`
N string `json:"n"`
E string `json:"e"`
Crv string `json:"crv"`
X string `json:"x"`
Y string `json:"y"`
X5c []string `json:"x5c"`
}
type Validator struct {
lock sync.Mutex
issuer string
audienceList []string
keysLocation string
idpSignkeyRefreshEnabled bool
keys *Jwks
}
var (
errKeyNotFound = errors.New("unable to find appropriate key")
errInvalidAudience = errors.New("invalid audience")
errInvalidIssuer = errors.New("invalid issuer")
errTokenEmpty = errors.New("required authorization token not found")
errTokenInvalid = errors.New("token is invalid")
errTokenParsing = errors.New("token could not be parsed")
)
func NewValidator(issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) *Validator {
keys, err := getPemKeys(keysLocation)
if err != nil {
log.WithField("keysLocation", keysLocation).Errorf("could not get keys from location: %s", err)
}
return &Validator{
keys: keys,
issuer: issuer,
audienceList: audienceList,
keysLocation: keysLocation,
idpSignkeyRefreshEnabled: idpSignkeyRefreshEnabled,
}
}
func (v *Validator) getKeyFunc(ctx context.Context) jwt.Keyfunc {
return func(token *jwt.Token) (interface{}, error) {
// Verify 'aud' claim
var checkAud bool
for _, audience := range v.audienceList {
checkAud = token.Claims.(jwt.MapClaims).VerifyAudience(audience, false)
if checkAud {
break
}
}
if !checkAud {
return token, errInvalidAudience
}
// Verify 'issuer' claim
checkIss := token.Claims.(jwt.MapClaims).VerifyIssuer(v.issuer, false)
if !checkIss {
return token, errInvalidIssuer
}
// If keys are rotated, verify the keys prior to token validation
if v.idpSignkeyRefreshEnabled {
// If the keys are invalid, retrieve new ones
// @todo propose a separate go routine to regularly check these to prevent blocking when actually
// validating the token
if !v.keys.stillValid() {
v.lock.Lock()
defer v.lock.Unlock()
refreshedKeys, err := getPemKeys(v.keysLocation)
if err != nil {
log.WithContext(ctx).Debugf("cannot get JSONWebKey: %v, falling back to old keys", err)
refreshedKeys = v.keys
}
log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.expiresInTime.UTC())
v.keys = refreshedKeys
}
}
publicKey, err := getPublicKey(token, v.keys)
if err == nil {
return publicKey, nil
}
msg := fmt.Sprintf("getPublicKey error: %s", err)
if errors.Is(err, errKeyNotFound) && !v.idpSignkeyRefreshEnabled {
msg = fmt.Sprintf("getPublicKey error: %s. You can enable key refresh by setting HttpServerConfig.IdpSignKeyRefreshEnabled to true in your management.json file and restart the service", err)
}
log.WithContext(ctx).Error(msg)
return nil, err
}
}
// ValidateAndParse validates the token and returns the parsed token
func (m *Validator) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) {
// If the token is empty...
if token == "" {
// If we get here, the required token is missing
log.WithContext(ctx).Debugf(" Error: No credentials found (CredentialsOptional=false)")
return nil, errTokenEmpty
}
// Now parse the token
parsedToken, err := jwt.Parse(token, m.getKeyFunc(ctx))
// Check if there was an error in parsing...
if err != nil {
err = fmt.Errorf("%w: %s", errTokenParsing, err)
log.WithContext(ctx).Error(err.Error())
return nil, err
}
// Check if the parsed token is valid...
if !parsedToken.Valid {
log.WithContext(ctx).Debug(errTokenInvalid.Error())
return nil, errTokenInvalid
}
return parsedToken, nil
}
// stillValid returns true if the JSONWebKey still valid and have enough time to be used
func (jwks *Jwks) stillValid() bool {
return !jwks.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(jwks.expiresInTime)
}
func getPemKeys(keysLocation string) (*Jwks, error) {
jwks := &Jwks{}
url, err := url.ParseRequestURI(keysLocation)
if err != nil {
return jwks, err
}
resp, err := http.Get(url.String())
if err != nil {
return jwks, err
}
defer resp.Body.Close()
err = json.NewDecoder(resp.Body).Decode(jwks)
if err != nil {
return jwks, err
}
cacheControlHeader := resp.Header.Get("Cache-Control")
expiresIn := getMaxAgeFromCacheHeader(cacheControlHeader)
jwks.expiresInTime = time.Now().Add(time.Duration(expiresIn) * time.Second)
return jwks, nil
}
func getPublicKey(token *jwt.Token, jwks *Jwks) (interface{}, error) {
// todo as we load the jkws when the server is starting, we should build a JKS map with the pem cert at the boot time
for k := range jwks.Keys {
if token.Header["kid"] != jwks.Keys[k].Kid {
continue
}
if len(jwks.Keys[k].X5c) != 0 {
cert := "-----BEGIN CERTIFICATE-----\n" + jwks.Keys[k].X5c[0] + "\n-----END CERTIFICATE-----"
return jwt.ParseRSAPublicKeyFromPEM([]byte(cert))
}
if jwks.Keys[k].Kty == "RSA" {
return getPublicKeyFromRSA(jwks.Keys[k])
}
if jwks.Keys[k].Kty == "EC" {
return getPublicKeyFromECDSA(jwks.Keys[k])
}
}
return nil, errKeyNotFound
}
func getPublicKeyFromECDSA(jwk JSONWebKey) (publicKey *ecdsa.PublicKey, err error) {
if jwk.X == "" || jwk.Y == "" || jwk.Crv == "" {
return nil, fmt.Errorf("ecdsa key incomplete")
}
var xCoordinate []byte
if xCoordinate, err = base64.RawURLEncoding.DecodeString(jwk.X); err != nil {
return nil, err
}
var yCoordinate []byte
if yCoordinate, err = base64.RawURLEncoding.DecodeString(jwk.Y); err != nil {
return nil, err
}
publicKey = &ecdsa.PublicKey{}
var curve elliptic.Curve
switch jwk.Crv {
case p256:
curve = elliptic.P256()
case p384:
curve = elliptic.P384()
case p521:
curve = elliptic.P521()
}
publicKey.Curve = curve
publicKey.X = big.NewInt(0).SetBytes(xCoordinate)
publicKey.Y = big.NewInt(0).SetBytes(yCoordinate)
return publicKey, nil
}
func getPublicKeyFromRSA(jwk JSONWebKey) (*rsa.PublicKey, error) {
decodedE, err := base64.RawURLEncoding.DecodeString(jwk.E)
if err != nil {
return nil, err
}
decodedN, err := base64.RawURLEncoding.DecodeString(jwk.N)
if err != nil {
return nil, err
}
var n, e big.Int
e.SetBytes(decodedE)
n.SetBytes(decodedN)
return &rsa.PublicKey{
E: int(e.Int64()),
N: &n,
}, nil
}
// getMaxAgeFromCacheHeader extracts max-age directive from the Cache-Control header
func getMaxAgeFromCacheHeader(cacheControl string) int {
// Split into individual directives
directives := strings.Split(cacheControl, ",")
for _, directive := range directives {
directive = strings.TrimSpace(directive)
if strings.HasPrefix(directive, "max-age=") {
// Extract the max-age value
maxAgeStr := strings.TrimPrefix(directive, "max-age=")
maxAge, err := strconv.Atoi(maxAgeStr)
if err != nil {
return 0
}
return maxAge
}
}
return 0
}

View File

@ -0,0 +1,170 @@
package auth
import (
"context"
"crypto/sha256"
"encoding/base64"
"fmt"
"hash/crc32"
"github.com/golang-jwt/jwt"
"github.com/netbirdio/netbird/base62"
nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
)
var _ Manager = (*manager)(nil)
type Manager interface {
ValidateAndParseToken(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error)
EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error)
MarkPATUsed(ctx context.Context, tokenID string) error
GetPATInfo(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error)
}
type manager struct {
store store.Store
validator *nbjwt.Validator
extractor *nbjwt.ClaimsExtractor
}
func NewManager(store store.Store, issuer, audience, keysLocation, userIdClaim string, allAudiences []string, idpRefreshKeys bool) Manager {
// @note if invalid/missing parameters are sent the validator will instantiate
// but it will fail when validating and parsing the token
jwtValidator := nbjwt.NewValidator(
issuer,
allAudiences,
keysLocation,
idpRefreshKeys,
)
claimsExtractor := nbjwt.NewClaimsExtractor(
nbjwt.WithAudience(audience),
nbjwt.WithUserIDClaim(userIdClaim),
)
return &manager{
store: store,
validator: jwtValidator,
extractor: claimsExtractor,
}
}
func (m *manager) ValidateAndParseToken(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error) {
token, err := m.validator.ValidateAndParse(ctx, value)
if err != nil {
return nbcontext.UserAuth{}, nil, err
}
userAuth, err := m.extractor.ToUserAuth(token)
if err != nil {
return nbcontext.UserAuth{}, nil, err
}
return userAuth, token, err
}
func (m *manager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) {
if userAuth.IsChild || userAuth.IsPAT {
return userAuth, nil
}
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthShare, userAuth.AccountId)
if err != nil {
return userAuth, err
}
// Ensures JWT group synchronization to the management is enabled before,
// filtering access based on the allowed groups.
if settings != nil && settings.JWTGroupsEnabled {
userAuth.Groups = m.extractor.ToGroups(token, settings.JWTGroupsClaimName)
if allowedGroups := settings.JWTAllowGroups; len(allowedGroups) > 0 {
if !userHasAllowedGroup(allowedGroups, userAuth.Groups) {
return userAuth, fmt.Errorf("user does not belong to any of the allowed JWT groups")
}
}
}
return userAuth, nil
}
// MarkPATUsed marks a personal access token as used
func (am *manager) MarkPATUsed(ctx context.Context, tokenID string) error {
return am.store.MarkPATUsed(ctx, store.LockingStrengthUpdate, tokenID)
}
// GetPATInfo retrieves user, personal access token, domain, and category details from a personal access token.
func (am *manager) GetPATInfo(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) {
user, pat, err = am.extractPATFromToken(ctx, token)
if err != nil {
return nil, nil, "", "", err
}
domain, category, err = am.store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, user.AccountID)
if err != nil {
return nil, nil, "", "", err
}
return user, pat, domain, category, nil
}
// extractPATFromToken validates the token structure and retrieves associated User and PAT.
func (am *manager) extractPATFromToken(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, error) {
if len(token) != types.PATLength {
return nil, nil, fmt.Errorf("PAT has incorrect length")
}
prefix := token[:len(types.PATPrefix)]
if prefix != types.PATPrefix {
return nil, nil, fmt.Errorf("PAT has wrong prefix")
}
secret := token[len(types.PATPrefix) : len(types.PATPrefix)+types.PATSecretLength]
encodedChecksum := token[len(types.PATPrefix)+types.PATSecretLength : len(types.PATPrefix)+types.PATSecretLength+types.PATChecksumLength]
verificationChecksum, err := base62.Decode(encodedChecksum)
if err != nil {
return nil, nil, fmt.Errorf("PAT checksum decoding failed: %w", err)
}
secretChecksum := crc32.ChecksumIEEE([]byte(secret))
if secretChecksum != verificationChecksum {
return nil, nil, fmt.Errorf("PAT checksum does not match")
}
hashedToken := sha256.Sum256([]byte(token))
encodedHashedToken := base64.StdEncoding.EncodeToString(hashedToken[:])
var user *types.User
var pat *types.PersonalAccessToken
err = am.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthShare, encodedHashedToken)
if err != nil {
return err
}
user, err = transaction.GetUserByPATID(ctx, store.LockingStrengthShare, pat.ID)
return err
})
if err != nil {
return nil, nil, err
}
return user, pat, nil
}
// userHasAllowedGroup checks if a user belongs to any of the allowed groups.
func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool {
for _, userGroup := range userGroups {
for _, allowedGroup := range allowedGroups {
if userGroup == allowedGroup {
return true
}
}
}
return false
}

View File

@ -0,0 +1,54 @@
package auth
import (
"context"
"github.com/golang-jwt/jwt"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/types"
)
var (
_ Manager = (*MockManager)(nil)
)
// @note really dislike this mocking approach but rather than have to do additional test refactoring.
type MockManager struct {
ValidateAndParseTokenFunc func(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error)
EnsureUserAccessByJWTGroupsFunc func(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error)
MarkPATUsedFunc func(ctx context.Context, tokenID string) error
GetPATInfoFunc func(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error)
}
// EnsureUserAccessByJWTGroups implements Manager.
func (m *MockManager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) {
if m.EnsureUserAccessByJWTGroupsFunc != nil {
return m.EnsureUserAccessByJWTGroupsFunc(ctx, userAuth, token)
}
return nbcontext.UserAuth{}, nil
}
// GetPATInfo implements Manager.
func (m *MockManager) GetPATInfo(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) {
if m.GetPATInfoFunc != nil {
return m.GetPATInfoFunc(ctx, token)
}
return &types.User{}, &types.PersonalAccessToken{}, "", "", nil
}
// MarkPATUsed implements Manager.
func (m *MockManager) MarkPATUsed(ctx context.Context, tokenID string) error {
if m.MarkPATUsedFunc != nil {
return m.MarkPATUsedFunc(ctx, tokenID)
}
return nil
}
// ValidateAndParseToken implements Manager.
func (m *MockManager) ValidateAndParseToken(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error) {
if m.ValidateAndParseTokenFunc != nil {
return m.ValidateAndParseTokenFunc(ctx, value)
}
return nbcontext.UserAuth{}, &jwt.Token{}, nil
}

View File

@ -0,0 +1,407 @@
package auth_test
import (
"context"
"crypto/sha256"
"encoding/base64"
"fmt"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"
"github.com/golang-jwt/jwt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/auth"
nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
)
func TestAuthManager_GetAccountInfoFromPAT(t *testing.T) {
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
if err != nil {
t.Fatalf("Error when creating store: %s", err)
}
t.Cleanup(cleanup)
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
hashedToken := sha256.Sum256([]byte(token))
encodedHashedToken := base64.StdEncoding.EncodeToString(hashedToken[:])
account := &types.Account{
Id: "account_id",
Users: map[string]*types.User{"someUser": {
Id: "someUser",
PATs: map[string]*types.PersonalAccessToken{
"tokenId": {
ID: "tokenId",
UserID: "someUser",
HashedToken: encodedHashedToken,
},
},
}},
}
err = store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
manager := auth.NewManager(store, "", "", "", "", []string{}, false)
user, pat, _, _, err := manager.GetPATInfo(context.Background(), token)
if err != nil {
t.Fatalf("Error when getting Account from PAT: %s", err)
}
assert.Equal(t, "account_id", user.AccountID)
assert.Equal(t, "someUser", user.Id)
assert.Equal(t, account.Users["someUser"].PATs["tokenId"].ID, pat.ID)
}
func TestAuthManager_MarkPATUsed(t *testing.T) {
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
if err != nil {
t.Fatalf("Error when creating store: %s", err)
}
t.Cleanup(cleanup)
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
hashedToken := sha256.Sum256([]byte(token))
encodedHashedToken := base64.StdEncoding.EncodeToString(hashedToken[:])
account := &types.Account{
Id: "account_id",
Users: map[string]*types.User{"someUser": {
Id: "someUser",
PATs: map[string]*types.PersonalAccessToken{
"tokenId": {
ID: "tokenId",
HashedToken: encodedHashedToken,
},
},
}},
}
err = store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
manager := auth.NewManager(store, "", "", "", "", []string{}, false)
err = manager.MarkPATUsed(context.Background(), "tokenId")
if err != nil {
t.Fatalf("Error when marking PAT used: %s", err)
}
account, err = store.GetAccount(context.Background(), "account_id")
if err != nil {
t.Fatalf("Error when getting account: %s", err)
}
assert.True(t, !account.Users["someUser"].PATs["tokenId"].GetLastUsed().IsZero())
}
func TestAuthManager_EnsureUserAccessByJWTGroups(t *testing.T) {
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
if err != nil {
t.Fatalf("Error when creating store: %s", err)
}
t.Cleanup(cleanup)
userId := "user-id"
domain := "test.domain"
account := &types.Account{
Id: "account_id",
Domain: domain,
Users: map[string]*types.User{"someUser": {
Id: "someUser",
}},
Settings: &types.Settings{},
}
err = store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
// this has been validated and parsed by ValidateAndParseToken
userAuth := nbcontext.UserAuth{
AccountId: account.Id,
Domain: domain,
UserId: userId,
DomainCategory: "test-category",
// Groups: []string{"group1", "group2"},
}
// these tests only assert groups are parsed from token as per account settings
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{"idp-groups": []interface{}{"group1", "group2"}})
manager := auth.NewManager(store, "", "", "", "", []string{}, false)
t.Run("JWT groups disabled", func(t *testing.T) {
userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token)
require.NoError(t, err, "ensure user access by JWT groups failed")
require.Len(t, userAuth.Groups, 0, "account not enabled to ensure access by groups")
})
t.Run("User impersonated", func(t *testing.T) {
userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token)
require.NoError(t, err, "ensure user access by JWT groups failed")
require.Len(t, userAuth.Groups, 0, "account not enabled to ensure access by groups")
})
t.Run("User PAT", func(t *testing.T) {
userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token)
require.NoError(t, err, "ensure user access by JWT groups failed")
require.Len(t, userAuth.Groups, 0, "account not enabled to ensure access by groups")
})
t.Run("JWT groups enabled without claim name", func(t *testing.T) {
account.Settings.JWTGroupsEnabled = true
err := store.SaveAccount(context.Background(), account)
require.NoError(t, err, "save account failed")
userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token)
require.NoError(t, err, "ensure user access by JWT groups failed")
require.Len(t, userAuth.Groups, 0, "account missing groups claim name")
})
t.Run("JWT groups enabled without allowed groups", func(t *testing.T) {
account.Settings.JWTGroupsEnabled = true
account.Settings.JWTGroupsClaimName = "idp-groups"
err := store.SaveAccount(context.Background(), account)
require.NoError(t, err, "save account failed")
userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token)
require.NoError(t, err, "ensure user access by JWT groups failed")
require.Equal(t, []string{"group1", "group2"}, userAuth.Groups, "group parsed do not match")
})
t.Run("User in allowed JWT groups", func(t *testing.T) {
account.Settings.JWTGroupsEnabled = true
account.Settings.JWTGroupsClaimName = "idp-groups"
account.Settings.JWTAllowGroups = []string{"group1"}
err := store.SaveAccount(context.Background(), account)
require.NoError(t, err, "save account failed")
userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token)
require.NoError(t, err, "ensure user access by JWT groups failed")
require.Equal(t, []string{"group1", "group2"}, userAuth.Groups, "group parsed do not match")
})
t.Run("User not in allowed JWT groups", func(t *testing.T) {
account.Settings.JWTGroupsEnabled = true
account.Settings.JWTGroupsClaimName = "idp-groups"
account.Settings.JWTAllowGroups = []string{"not-a-group"}
err := store.SaveAccount(context.Background(), account)
require.NoError(t, err, "save account failed")
_, err = manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token)
require.Error(t, err, "ensure user access is not in allowed groups")
})
}
func TestAuthManager_ValidateAndParseToken(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Cache-Control", "max-age=30") // set a 30s expiry to these keys
http.ServeFile(w, r, "test_data/jwks.json")
}))
defer server.Close()
issuer := "http://issuer.local"
audience := "http://audience.local"
userIdClaim := "" // defaults to "sub"
// we're only testing with RSA256
keyData, _ := os.ReadFile("test_data/sample_key")
key, _ := jwt.ParseRSAPrivateKeyFromPEM(keyData)
keyId := "test-key"
// note, we can use a nil store because ValidateAndParseToken does not use it in it's flow
manager := auth.NewManager(nil, issuer, audience, server.URL, userIdClaim, []string{audience}, false)
customClaim := func(name string) string {
return fmt.Sprintf("%s/%s", audience, name)
}
lastLogin := time.Date(2025, 2, 12, 14, 25, 26, 0, time.UTC) //"2025-02-12T14:25:26.186Z"
tests := []struct {
name string
tokenFunc func() string
expected *nbcontext.UserAuth // nil indicates expected error
}{
{
name: "Valid with custom claims",
tokenFunc: func() string {
token := jwt.New(jwt.SigningMethodRS256)
token.Header["kid"] = keyId
token.Claims = jwt.MapClaims{
"iss": issuer,
"aud": []string{audience},
"iat": time.Now().Unix(),
"exp": time.Now().Add(time.Hour * 1).Unix(),
"sub": "user-id|123",
customClaim(nbjwt.AccountIDSuffix): "account-id|567",
customClaim(nbjwt.DomainIDSuffix): "http://localhost",
customClaim(nbjwt.DomainCategorySuffix): "private",
customClaim(nbjwt.LastLoginSuffix): lastLogin.Format(time.RFC3339),
customClaim(nbjwt.Invited): false,
}
tokenString, _ := token.SignedString(key)
return tokenString
},
expected: &nbcontext.UserAuth{
UserId: "user-id|123",
AccountId: "account-id|567",
Domain: "http://localhost",
DomainCategory: "private",
LastLogin: lastLogin,
Invited: false,
},
},
{
name: "Valid without custom claims",
tokenFunc: func() string {
token := jwt.New(jwt.SigningMethodRS256)
token.Header["kid"] = keyId
token.Claims = jwt.MapClaims{
"iss": issuer,
"aud": []string{audience},
"iat": time.Now().Unix(),
"exp": time.Now().Add(time.Hour).Unix(),
"sub": "user-id|123",
}
tokenString, _ := token.SignedString(key)
return tokenString
},
expected: &nbcontext.UserAuth{
UserId: "user-id|123",
},
},
{
name: "Expired token",
tokenFunc: func() string {
token := jwt.New(jwt.SigningMethodRS256)
token.Header["kid"] = keyId
token.Claims = jwt.MapClaims{
"iss": issuer,
"aud": []string{audience},
"iat": time.Now().Add(time.Hour * -2).Unix(),
"exp": time.Now().Add(time.Hour * -1).Unix(),
"sub": "user-id|123",
}
tokenString, _ := token.SignedString(key)
return tokenString
},
},
{
name: "Not yet valid",
tokenFunc: func() string {
token := jwt.New(jwt.SigningMethodRS256)
token.Header["kid"] = keyId
token.Claims = jwt.MapClaims{
"iss": issuer,
"aud": []string{audience},
"iat": time.Now().Add(time.Hour).Unix(),
"exp": time.Now().Add(time.Hour * 2).Unix(),
"sub": "user-id|123",
}
tokenString, _ := token.SignedString(key)
return tokenString
},
},
{
name: "Invalid signature",
tokenFunc: func() string {
token := jwt.New(jwt.SigningMethodRS256)
token.Header["kid"] = keyId
token.Claims = jwt.MapClaims{
"iss": issuer,
"aud": []string{audience},
"iat": time.Now().Unix(),
"exp": time.Now().Add(time.Hour).Unix(),
"sub": "user-id|123",
}
tokenString, _ := token.SignedString(key)
parts := strings.Split(tokenString, ".")
parts[2] = "invalid-signature"
return strings.Join(parts, ".")
},
},
{
name: "Invalid issuer",
tokenFunc: func() string {
token := jwt.New(jwt.SigningMethodRS256)
token.Header["kid"] = keyId
token.Claims = jwt.MapClaims{
"iss": "not-the-issuer",
"aud": []string{audience},
"iat": time.Now().Unix(),
"exp": time.Now().Add(time.Hour).Unix(),
"sub": "user-id|123",
}
tokenString, _ := token.SignedString(key)
return tokenString
},
},
{
name: "Invalid audience",
tokenFunc: func() string {
token := jwt.New(jwt.SigningMethodRS256)
token.Header["kid"] = keyId
token.Claims = jwt.MapClaims{
"iss": issuer,
"aud": []string{"not-the-audience"},
"iat": time.Now().Unix(),
"exp": time.Now().Add(time.Hour).Unix(),
"sub": "user-id|123",
}
tokenString, _ := token.SignedString(key)
return tokenString
},
},
{
name: "Invalid user claim",
tokenFunc: func() string {
token := jwt.New(jwt.SigningMethodRS256)
token.Header["kid"] = keyId
token.Claims = jwt.MapClaims{
"iss": issuer,
"aud": []string{audience},
"iat": time.Now().Unix(),
"exp": time.Now().Add(time.Hour).Unix(),
"not-sub": "user-id|123",
}
tokenString, _ := token.SignedString(key)
return tokenString
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tokenString := tt.tokenFunc()
userAuth, token, err := manager.ValidateAndParseToken(context.Background(), tokenString)
if tt.expected != nil {
assert.NoError(t, err)
assert.True(t, token.Valid)
assert.Equal(t, *tt.expected, userAuth)
} else {
assert.Error(t, err)
assert.Nil(t, token)
assert.Empty(t, userAuth)
}
})
}
}

View File

@ -0,0 +1,11 @@
{
"keys": [
{
"kty": "RSA",
"kid": "test-key",
"use": "sig",
"n": "4f5wg5l2hKsTeNem_V41fGnJm6gOdrj8ym3rFkEU_wT8RDtnSgFEZOQpHEgQ7JL38xUfU0Y3g6aYw9QT0hJ7mCpz9Er5qLaMXJwZxzHzAahlfA0icqabvJOMvQtzD6uQv6wPEyZtDTWiQi9AXwBpHssPnpYGIn20ZZuNlX2BrClciHhCPUIIZOQn_MmqTD31jSyjoQoV7MhhMTATKJx2XrHhR-1DcKJzQBSTAGnpYVaqpsARap-nwRipr3nUTuxyGohBTSmjJ2usSeQXHI3bODIRe1AuTyHceAbewn8b462yEWKARdpd9AjQW5SIVPfdsz5B6GlYQ5LdYKtznTuy7w",
"e": "AQAB"
}
]
}

View File

@ -0,0 +1,27 @@
-----BEGIN RSA PRIVATE KEY-----
MIIEowIBAAKCAQEA4f5wg5l2hKsTeNem/V41fGnJm6gOdrj8ym3rFkEU/wT8RDtn
SgFEZOQpHEgQ7JL38xUfU0Y3g6aYw9QT0hJ7mCpz9Er5qLaMXJwZxzHzAahlfA0i
cqabvJOMvQtzD6uQv6wPEyZtDTWiQi9AXwBpHssPnpYGIn20ZZuNlX2BrClciHhC
PUIIZOQn/MmqTD31jSyjoQoV7MhhMTATKJx2XrHhR+1DcKJzQBSTAGnpYVaqpsAR
ap+nwRipr3nUTuxyGohBTSmjJ2usSeQXHI3bODIRe1AuTyHceAbewn8b462yEWKA
Rdpd9AjQW5SIVPfdsz5B6GlYQ5LdYKtznTuy7wIDAQABAoIBAQCwia1k7+2oZ2d3
n6agCAbqIE1QXfCmh41ZqJHbOY3oRQG3X1wpcGH4Gk+O+zDVTV2JszdcOt7E5dAy
MaomETAhRxB7hlIOnEN7WKm+dGNrKRvV0wDU5ReFMRHg31/Lnu8c+5BvGjZX+ky9
POIhFFYJqwCRlopGSUIxmVj5rSgtzk3iWOQXr+ah1bjEXvlxDOWkHN6YfpV5ThdE
KdBIPGEVqa63r9n2h+qazKrtiRqJqGnOrHzOECYbRFYhexsNFz7YT02xdfSHn7gM
IvabDDP/Qp0PjE1jdouiMaFHYnLBbgvlnZW9yuVf/rpXTUq/njxIXMmvmEyyvSDn
FcFikB8pAoGBAPF77hK4m3/rdGT7X8a/gwvZ2R121aBcdPwEaUhvj/36dx596zvY
mEOjrWfZhF083/nYWE2kVquj2wjs+otCLfifEEgXcVPTnEOPO9Zg3uNSL0nNQghj
FuD3iGLTUBCtM66oTe0jLSslHe8gLGEQqyMzHOzYxNqibxcOZIe8Qt0NAoGBAO+U
I5+XWjWEgDmvyC3TrOSf/KCGjtu0TSv30ipv27bDLMrpvPmD/5lpptTFwcxvVhCs
2b+chCjlghFSWFbBULBrfci2FtliClOVMYrlNBdUSJhf3aYSG2Doe6Bgt1n2CpNn
/iu37Y3NfemZBJA7hNl4dYe+f+uzM87cdQ214+jrAoGAXA0XxX8ll2+ToOLJsaNT
OvNB9h9Uc5qK5X5w+7G7O998BN2PC/MWp8H+2fVqpXgNENpNXttkRm1hk1dych86
EunfdPuqsX+as44oCyJGFHVBnWpm33eWQw9YqANRI+pCJzP08I5WK3osnPiwshd+
hR54yjgfYhBFNI7B95PmEQkCgYBzFSz7h1+s34Ycr8SvxsOBWxymG5zaCsUbPsL0
4aCgLScCHb9J+E86aVbbVFdglYa5Id7DPTL61ixhl7WZjujspeXZGSbmq0Kcnckb
mDgqkLECiOJW2NHP/j0McAkDLL4tysF8TLDO8gvuvzNC+WQ6drO2ThrypLVZQ+ry
eBIPmwKBgEZxhqa0gVvHQG/7Od69KWj4eJP28kq13RhKay8JOoN0vPmspXJo1HY3
CKuHRG+AP579dncdUnOMvfXOtkdM4vk0+hWASBQzM9xzVcztCa+koAugjVaLS9A+
9uQoqEeVNTckxx0S2bYevRy7hGQmUJTyQm3j1zEUR5jpdbL83Fbq
-----END RSA PRIVATE KEY-----

View File

@ -0,0 +1,9 @@
-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA4f5wg5l2hKsTeNem/V41
fGnJm6gOdrj8ym3rFkEU/wT8RDtnSgFEZOQpHEgQ7JL38xUfU0Y3g6aYw9QT0hJ7
mCpz9Er5qLaMXJwZxzHzAahlfA0icqabvJOMvQtzD6uQv6wPEyZtDTWiQi9AXwBp
HssPnpYGIn20ZZuNlX2BrClciHhCPUIIZOQn/MmqTD31jSyjoQoV7MhhMTATKJx2
XrHhR+1DcKJzQBSTAGnpYVaqpsARap+nwRipr3nUTuxyGohBTSmjJ2usSeQXHI3b
ODIRe1AuTyHceAbewn8b462yEWKARdpd9AjQW5SIVPfdsz5B6GlYQ5LdYKtznTuy
7wIDAQAB
-----END PUBLIC KEY-----

View File

@ -2,7 +2,6 @@ package server
import (
"net/netip"
"net/url"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/store"
@ -180,9 +179,3 @@ type ReverseProxy struct {
// trusted IP prefixes.
TrustedPeers []netip.Prefix
}
// validateURL validates input http url
func validateURL(httpURL string) bool {
_, err := url.ParseRequestURI(httpURL)
return err == nil
}

View File

@ -0,0 +1,60 @@
package context
import (
"context"
"fmt"
"net/http"
"time"
)
type key int
const (
UserAuthContextKey key = iota
)
type UserAuth struct {
// The account id the user is accessing
AccountId string
// The account domain
Domain string
// The account domain category, TBC values
DomainCategory string
// Indicates whether this user was invited, TBC logic
Invited bool
// Indicates whether this is a child account
IsChild bool
// The user id
UserId string
// Last login time for this user
LastLogin time.Time
// The Groups the user belongs to on this account
Groups []string
// Indicates whether this user has authenticated with a Personal Access Token
IsPAT bool
}
func GetUserAuthFromRequest(r *http.Request) (UserAuth, error) {
return GetUserAuthFromContext(r.Context())
}
func SetUserAuthInRequest(r *http.Request, userAuth UserAuth) *http.Request {
return r.WithContext(SetUserAuthInContext(r.Context(), userAuth))
}
func GetUserAuthFromContext(ctx context.Context) (UserAuth, error) {
if userAuth, ok := ctx.Value(UserAuthContextKey).(UserAuth); ok {
return userAuth, nil
}
return UserAuth{}, fmt.Errorf("user auth not in context")
}
func SetUserAuthInContext(ctx context.Context, userAuth UserAuth) context.Context {
//nolint
ctx = context.WithValue(ctx, UserIDKey, userAuth.UserId)
//nolint
ctx = context.WithValue(ctx, AccountIDKey, userAuth.AccountId)
return context.WithValue(ctx, UserAuthContextKey, userAuth)
}

View File

@ -20,8 +20,8 @@ import (
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/auth"
nbContext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/settings"
@ -39,11 +39,10 @@ type GRPCServer struct {
peersUpdateManager *PeersUpdateManager
config *Config
secretsManager SecretsManager
jwtValidator jwtclaims.JWTValidator
jwtClaimsExtractor *jwtclaims.ClaimsExtractor
appMetrics telemetry.AppMetrics
ephemeralManager *EphemeralManager
peerLocks sync.Map
authManager auth.Manager
}
// NewServer creates a new Management server
@ -56,29 +55,13 @@ func NewServer(
secretsManager SecretsManager,
appMetrics telemetry.AppMetrics,
ephemeralManager *EphemeralManager,
authManager auth.Manager,
) (*GRPCServer, error) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, err
}
var jwtValidator jwtclaims.JWTValidator
if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) {
jwtValidator, err = jwtclaims.NewJWTValidator(
ctx,
config.HttpConfig.AuthIssuer,
config.GetAuthAudiences(),
config.HttpConfig.AuthKeysLocation,
config.HttpConfig.IdpSignKeyRefreshEnabled,
)
if err != nil {
return nil, status.Errorf(codes.Internal, "unable to create new jwt middleware, err: %v", err)
}
} else {
log.WithContext(ctx).Debug("unable to use http config to create new jwt middleware")
}
if appMetrics != nil {
// update gauge based on number of connected peers which is equal to open gRPC streams
err = appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 {
@ -89,16 +72,6 @@ func NewServer(
}
}
var audience, userIDClaim string
if config.HttpConfig != nil {
audience = config.HttpConfig.AuthAudience
userIDClaim = config.HttpConfig.AuthUserIDClaim
}
jwtClaimsExtractor := jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(audience),
jwtclaims.WithUserIDClaim(userIDClaim),
)
return &GRPCServer{
wgKey: key,
// peerKey -> event channel
@ -107,8 +80,7 @@ func NewServer(
settingsManager: settingsManager,
config: config,
secretsManager: secretsManager,
jwtValidator: jwtValidator,
jwtClaimsExtractor: jwtClaimsExtractor,
authManager: authManager,
appMetrics: appMetrics,
ephemeralManager: ephemeralManager,
}, nil
@ -294,26 +266,32 @@ func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, p
}
func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string, error) {
if s.jwtValidator == nil {
return "", status.Error(codes.Internal, "no jwt validator set")
if s.authManager == nil {
return "", status.Errorf(codes.Internal, "missing auth manager")
}
token, err := s.jwtValidator.ValidateAndParse(ctx, jwtToken)
userAuth, token, err := s.authManager.ValidateAndParseToken(ctx, jwtToken)
if err != nil {
return "", status.Errorf(codes.InvalidArgument, "invalid jwt token, err: %v", err)
}
claims := s.jwtClaimsExtractor.FromToken(token)
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
_, _, err = s.accountManager.GetAccountIDFromToken(ctx, claims)
_, _, err = s.accountManager.GetAccountIDFromUserAuth(ctx, userAuth)
if err != nil {
return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err)
}
if err := s.accountManager.CheckUserAccessByJWTGroups(ctx, claims); err != nil {
userAuth, err = s.authManager.EnsureUserAccessByJWTGroups(ctx, userAuth, token)
if err != nil {
return "", status.Error(codes.PermissionDenied, err.Error())
}
return claims.UserId, nil
err = s.accountManager.SyncUserJWTGroups(ctx, userAuth)
if err != nil {
log.WithContext(ctx).Errorf("gRPC server failed to sync user JWT groups: %s", err)
}
return userAuth.UserId, nil
}
func (s *GRPCServer) acquirePeerLockByUID(ctx context.Context, uniqueID string) (unlock func()) {

View File

@ -11,9 +11,9 @@ import (
"github.com/netbirdio/management-integrations/integrations"
s "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/auth"
"github.com/netbirdio/netbird/management/server/geolocation"
nbgroups "github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/handlers/accounts"
"github.com/netbirdio/netbird/management/server/http/handlers/dns"
"github.com/netbirdio/netbird/management/server/http/handlers/events"
@ -26,7 +26,6 @@ import (
"github.com/netbirdio/netbird/management/server/http/handlers/users"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/integrated_validator"
"github.com/netbirdio/netbird/management/server/jwtclaims"
nbnetworks "github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
@ -36,55 +35,51 @@ import (
const apiPrefix = "/api"
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
func NewAPIHandler(ctx context.Context, accountManager s.AccountManager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg configs.AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) {
claimsExtractor := jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
)
func NewAPIHandler(
ctx context.Context,
accountManager s.AccountManager,
networksManager nbnetworks.Manager,
resourceManager resources.Manager,
routerManager routers.Manager,
groupsManager nbgroups.Manager,
LocationManager geolocation.Geolocation,
authManager auth.Manager,
appMetrics telemetry.AppMetrics,
config *s.Config,
integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) {
authMiddleware := middleware.NewAuthMiddleware(
accountManager.GetPATInfo,
jwtValidator.ValidateAndParse,
accountManager.MarkPATUsed,
accountManager.CheckUserAccessByJWTGroups,
claimsExtractor,
authCfg.Audience,
authCfg.UserIDClaim,
authManager,
accountManager.GetAccountIDFromUserAuth,
accountManager.SyncUserJWTGroups,
)
corsMiddleware := cors.AllowAll()
claimsExtractor = jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
)
acMiddleware := middleware.NewAccessControl(
authCfg.Audience,
authCfg.UserIDClaim,
accountManager.GetUser)
acMiddleware := middleware.NewAccessControl(accountManager.GetUserFromUserAuth)
rootRouter := mux.NewRouter()
metricsMiddleware := appMetrics.HTTPMiddleware()
prefix := apiPrefix
router := rootRouter.PathPrefix(prefix).Subrouter()
router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler, acMiddleware.Handler)
if _, err := integrations.RegisterHandlers(ctx, prefix, router, accountManager, claimsExtractor, integratedValidator, appMetrics.GetMeter()); err != nil {
if _, err := integrations.RegisterHandlers(ctx, prefix, router, accountManager, integratedValidator, appMetrics.GetMeter()); err != nil {
return nil, fmt.Errorf("register integrations endpoints: %w", err)
}
accounts.AddEndpoints(accountManager, authCfg, router)
peers.AddEndpoints(accountManager, authCfg, router)
users.AddEndpoints(accountManager, authCfg, router)
setup_keys.AddEndpoints(accountManager, authCfg, router)
policies.AddEndpoints(accountManager, LocationManager, authCfg, router)
groups.AddEndpoints(accountManager, authCfg, router)
routes.AddEndpoints(accountManager, authCfg, router)
dns.AddEndpoints(accountManager, authCfg, router)
events.AddEndpoints(accountManager, authCfg, router)
networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, accountManager.GetAccountIDFromToken, authCfg, router)
accounts.AddEndpoints(accountManager, router)
peers.AddEndpoints(accountManager, router)
users.AddEndpoints(accountManager, router)
setup_keys.AddEndpoints(accountManager, router)
policies.AddEndpoints(accountManager, LocationManager, router)
groups.AddEndpoints(accountManager, router)
routes.AddEndpoints(accountManager, router)
dns.AddEndpoints(accountManager, router)
events.AddEndpoints(accountManager, router)
networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, router)
return rootRouter, nil
}

View File

@ -9,47 +9,42 @@ import (
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
)
// handler is a handler that handles the server.Account HTTP endpoints
type handler struct {
accountManager server.AccountManager
claimsExtractor *jwtclaims.ClaimsExtractor
accountManager server.AccountManager
}
func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
accountsHandler := newHandler(accountManager, authCfg)
func AddEndpoints(accountManager server.AccountManager, router *mux.Router) {
accountsHandler := newHandler(accountManager)
router.HandleFunc("/accounts/{accountId}", accountsHandler.updateAccount).Methods("PUT", "OPTIONS")
router.HandleFunc("/accounts/{accountId}", accountsHandler.deleteAccount).Methods("DELETE", "OPTIONS")
router.HandleFunc("/accounts", accountsHandler.getAllAccounts).Methods("GET", "OPTIONS")
}
// newHandler creates a new handler HTTP handler
func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler {
func newHandler(accountManager server.AccountManager) *handler {
return &handler{
accountManager: accountManager,
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
}
}
// getAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account.
func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
settings, err := h.accountManager.GetAccountSettings(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
@ -62,13 +57,14 @@ func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) {
// updateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings)
func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
_, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
_, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
accountID := vars["accountId"]
if len(accountID) == 0 {
@ -125,7 +121,12 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
// deleteAccount is a HTTP DELETE handler to delete an account
func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
targetAccountID := vars["accountId"]
if len(targetAccountID) == 0 {
@ -133,7 +134,7 @@ func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
return
}
err := h.accountManager.DeleteAccount(r.Context(), targetAccountID, claims.UserId)
err = h.accountManager.DeleteAccount(r.Context(), targetAccountID, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@ -13,19 +13,16 @@ import (
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
)
func initAccountsTestData(account *types.Account, admin *types.User) *handler {
func initAccountsTestData(account *types.Account) *handler {
return &handler{
accountManager: &mock_server.MockAccountManager{
GetAccountIDFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return account.Id, admin.Id, nil
},
GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*types.Settings, error) {
return account.Settings, nil
},
@ -44,15 +41,6 @@ func initAccountsTestData(account *types.Account, admin *types.User) *handler {
return accCopy, nil
},
},
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_account",
}
}),
),
}
}
@ -75,7 +63,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
PeerLoginExpiration: time.Hour,
RegularUsersViewBlocked: true,
},
}, adminUser)
})
tt := []struct {
name string
@ -191,6 +179,11 @@ func TestAccounts_AccountsHandler(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
UserId: adminUser.Id,
AccountId: accountID,
Domain: "hotmail.com",
})
router := mux.NewRouter()
router.HandleFunc("/api/accounts", handler.getAllAccounts).Methods("GET")

View File

@ -8,51 +8,44 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/types"
)
// dnsSettingsHandler is a handler that returns the DNS settings of the account
type dnsSettingsHandler struct {
accountManager server.AccountManager
claimsExtractor *jwtclaims.ClaimsExtractor
accountManager server.AccountManager
}
func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
addDNSSettingEndpoint(accountManager, authCfg, router)
addDNSNameserversEndpoint(accountManager, authCfg, router)
func AddEndpoints(accountManager server.AccountManager, router *mux.Router) {
addDNSSettingEndpoint(accountManager, router)
addDNSNameserversEndpoint(accountManager, router)
}
func addDNSSettingEndpoint(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
dnsSettingsHandler := newDNSSettingsHandler(accountManager, authCfg)
func addDNSSettingEndpoint(accountManager server.AccountManager, router *mux.Router) {
dnsSettingsHandler := newDNSSettingsHandler(accountManager)
router.HandleFunc("/dns/settings", dnsSettingsHandler.getDNSSettings).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/settings", dnsSettingsHandler.updateDNSSettings).Methods("PUT", "OPTIONS")
}
// newDNSSettingsHandler returns a new instance of dnsSettingsHandler handler
func newDNSSettingsHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *dnsSettingsHandler {
return &dnsSettingsHandler{
accountManager: accountManager,
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
}
func newDNSSettingsHandler(accountManager server.AccountManager) *dnsSettingsHandler {
return &dnsSettingsHandler{accountManager: accountManager}
}
// getDNSSettings returns the DNS settings for the account
func (h *dnsSettingsHandler) getDNSSettings(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
@ -68,13 +61,14 @@ func (h *dnsSettingsHandler) getDNSSettings(w http.ResponseWriter, r *http.Reque
// updateDNSSettings handles update to DNS settings of an account
func (h *dnsSettingsHandler) updateDNSSettings(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
var req api.PutApiDnsSettingsJSONRequestBody
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {

View File

@ -17,7 +17,8 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/jwtclaims"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/mock_server"
)
@ -52,19 +53,7 @@ func initDNSSettingsTestData() *dnsSettingsHandler {
}
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
},
GetAccountIDFromTokenFunc: func(ctx context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) {
return testingDNSSettingsAccount.Id, testingDNSSettingsAccount.Users[testDNSSettingsUserID].Id, nil
},
},
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: testDNSSettingsAccountID,
}
}),
),
}
}
@ -118,6 +107,11 @@ func TestDNSSettingsHandlers(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
UserId: testingDNSSettingsAccount.Users[testDNSSettingsUserID].Id,
AccountId: testingDNSSettingsAccount.Id,
Domain: testingDNSSettingsAccount.Domain,
})
router := mux.NewRouter()
router.HandleFunc("/api/dns/settings", p.getDNSSettings).Methods("GET")

View File

@ -10,21 +10,19 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
)
// nameserversHandler is the nameserver group handler of the account
type nameserversHandler struct {
accountManager server.AccountManager
claimsExtractor *jwtclaims.ClaimsExtractor
accountManager server.AccountManager
}
func addDNSNameserversEndpoint(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
nameserversHandler := newNameserversHandler(accountManager, authCfg)
func addDNSNameserversEndpoint(accountManager server.AccountManager, router *mux.Router) {
nameserversHandler := newNameserversHandler(accountManager)
router.HandleFunc("/dns/nameservers", nameserversHandler.getAllNameservers).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/nameservers", nameserversHandler.createNameserverGroup).Methods("POST", "OPTIONS")
router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.updateNameserverGroup).Methods("PUT", "OPTIONS")
@ -33,26 +31,21 @@ func addDNSNameserversEndpoint(accountManager server.AccountManager, authCfg con
}
// newNameserversHandler returns a new instance of nameserversHandler handler
func newNameserversHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *nameserversHandler {
return &nameserversHandler{
accountManager: accountManager,
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
}
func newNameserversHandler(accountManager server.AccountManager) *nameserversHandler {
return &nameserversHandler{accountManager: accountManager}
}
// getAllNameservers returns the list of nameserver groups for the account
func (h *nameserversHandler) getAllNameservers(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
@ -69,13 +62,14 @@ func (h *nameserversHandler) getAllNameservers(w http.ResponseWriter, r *http.Re
// createNameserverGroup handles nameserver group creation request
func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
var req api.PostApiDnsNameserversJSONRequestBody
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
@ -102,13 +96,14 @@ func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *htt
// updateNameserverGroup handles update to a nameserver group identified by a given ID
func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
nsGroupID := mux.Vars(r)["nsgroupId"]
if len(nsGroupID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
@ -153,13 +148,14 @@ func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *htt
// deleteNameserverGroup handles nameserver group deletion request
func (h *nameserversHandler) deleteNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
nsGroupID := mux.Vars(r)["nsgroupId"]
if len(nsGroupID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
@ -177,14 +173,14 @@ func (h *nameserversHandler) deleteNameserverGroup(w http.ResponseWriter, r *htt
// getNameserverGroup handles a nameserver group Get request identified by ID
func (h *nameserversHandler) getNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
nsGroupID := mux.Vars(r)["nsgroupId"]
if len(nsGroupID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)

View File

@ -18,7 +18,8 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/jwtclaims"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/mock_server"
)
@ -81,19 +82,7 @@ func initNameserversTestData() *nameserversHandler {
}
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID)
},
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return claims.AccountId, claims.UserId, nil
},
},
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: testNSGroupAccountID,
}
}),
),
}
}
@ -204,6 +193,11 @@ func TestNameserversHandlers(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
UserId: "test_user",
AccountId: testNSGroupAccountID,
Domain: "hotmail.com",
})
router := mux.NewRouter()
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.getNameserverGroup).Methods("GET")

View File

@ -10,44 +10,37 @@ import (
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
)
// handler HTTP handler
type handler struct {
accountManager server.AccountManager
claimsExtractor *jwtclaims.ClaimsExtractor
accountManager server.AccountManager
}
func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
eventsHandler := newHandler(accountManager, authCfg)
func AddEndpoints(accountManager server.AccountManager, router *mux.Router) {
eventsHandler := newHandler(accountManager)
router.HandleFunc("/events", eventsHandler.getAllEvents).Methods("GET", "OPTIONS")
}
// newHandler creates a new events handler
func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler {
return &handler{
accountManager: accountManager,
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
}
func newHandler(accountManager server.AccountManager) *handler {
return &handler{accountManager: accountManager}
}
// getAllEvents list of the given account
func (h *handler) getAllEvents(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
accountEvents, err := h.accountManager.GetEvents(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)

View File

@ -13,9 +13,10 @@ import (
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/types"
)
@ -29,22 +30,10 @@ func initEventsTestData(account string, events ...*activity.Event) *handler {
}
return []*activity.Event{}, nil
},
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return claims.AccountId, claims.UserId, nil
},
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) (map[string]*types.UserInfo, error) {
return make(map[string]*types.UserInfo), nil
},
},
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_account",
}
}),
),
}
}
@ -199,6 +188,11 @@ func TestEvents_GetEvents(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_account",
})
router := mux.NewRouter()
router.HandleFunc("/api/events/", handler.getAllEvents).Methods("GET")

View File

@ -7,24 +7,23 @@ import (
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
nbcontext "github.com/netbirdio/netbird/management/server/context"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
)
// handler is a handler that returns groups of the account
type handler struct {
accountManager server.AccountManager
claimsExtractor *jwtclaims.ClaimsExtractor
accountManager server.AccountManager
}
func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
groupsHandler := newHandler(accountManager, authCfg)
func AddEndpoints(accountManager server.AccountManager, router *mux.Router) {
groupsHandler := newHandler(accountManager)
router.HandleFunc("/groups", groupsHandler.getAllGroups).Methods("GET", "OPTIONS")
router.HandleFunc("/groups", groupsHandler.createGroup).Methods("POST", "OPTIONS")
router.HandleFunc("/groups/{groupId}", groupsHandler.updateGroup).Methods("PUT", "OPTIONS")
@ -33,25 +32,21 @@ func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg,
}
// newHandler creates a new groups handler
func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler {
func newHandler(accountManager server.AccountManager) *handler {
return &handler{
accountManager: accountManager,
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
}
}
// getAllGroups list for the account
func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
groups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
if err != nil {
@ -75,13 +70,14 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) {
// updateGroup handles update to a group identified by a given ID
func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
groupID, ok := vars["groupId"]
if !ok {
@ -164,13 +160,14 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) {
// createGroup handles group creation request
func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
var req api.PostApiGroupsJSONRequestBody
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
@ -223,13 +220,14 @@ func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) {
// deleteGroup handles group deletion request
func (h *handler) deleteGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
groupID := mux.Vars(r)["groupId"]
if len(groupID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
@ -253,12 +251,13 @@ func (h *handler) deleteGroup(w http.ResponseWriter, r *http.Request) {
// getGroup returns a group
func (h *handler) getGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
groupID := mux.Vars(r)["groupId"]
if len(groupID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)

View File

@ -18,9 +18,9 @@ import (
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/management/server"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
@ -59,9 +59,6 @@ func initGroupTestData(initGroups ...*types.Group) *handler {
return group, nil
},
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return claims.AccountId, claims.UserId, nil
},
GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*types.Group, error) {
if groupName == "All" {
return &types.Group{ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI}, nil
@ -87,15 +84,6 @@ func initGroupTestData(initGroups ...*types.Group) *handler {
return nil
},
},
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",
}
}),
),
}
}
@ -134,6 +122,11 @@ func TestGetGroup(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",
})
router := mux.NewRouter()
router.HandleFunc("/api/groups/{groupId}", p.getGroup).Methods("GET")
@ -255,6 +248,11 @@ func TestWriteGroup(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",
})
router := mux.NewRouter()
router.HandleFunc("/api/groups", p.createGroup).Methods("POST")
@ -332,7 +330,11 @@ func TestDeleteGroup(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",
})
router := mux.NewRouter()
router.HandleFunc("/api/groups/{groupId}", p.deleteGroup).Methods("DELETE")
router.ServeHTTP(recorder, req)

View File

@ -10,11 +10,10 @@ import (
log "github.com/sirupsen/logrus"
s "github.com/netbirdio/netbird/management/server"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
@ -31,16 +30,14 @@ type handler struct {
routerManager routers.Manager
accountManager s.AccountManager
groupsManager groups.Manager
extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
claimsExtractor *jwtclaims.ClaimsExtractor
groupsManager groups.Manager
}
func AddEndpoints(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager s.AccountManager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg, router *mux.Router) {
addRouterEndpoints(routerManager, extractFromToken, authCfg, router)
addResourceEndpoints(resourceManager, groupsManager, extractFromToken, authCfg, router)
func AddEndpoints(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager s.AccountManager, router *mux.Router) {
addRouterEndpoints(routerManager, router)
addResourceEndpoints(resourceManager, groupsManager, router)
networksHandler := newHandler(networksManager, resourceManager, routerManager, groupsManager, accountManager, extractFromToken, authCfg)
networksHandler := newHandler(networksManager, resourceManager, routerManager, groupsManager, accountManager)
router.HandleFunc("/networks", networksHandler.getAllNetworks).Methods("GET", "OPTIONS")
router.HandleFunc("/networks", networksHandler.createNetwork).Methods("POST", "OPTIONS")
router.HandleFunc("/networks/{networkId}", networksHandler.getNetwork).Methods("GET", "OPTIONS")
@ -48,29 +45,25 @@ func AddEndpoints(networksManager networks.Manager, resourceManager resources.Ma
router.HandleFunc("/networks/{networkId}", networksHandler.deleteNetwork).Methods("DELETE", "OPTIONS")
}
func newHandler(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager s.AccountManager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg) *handler {
func newHandler(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager s.AccountManager) *handler {
return &handler{
networksManager: networksManager,
resourceManager: resourceManager,
routerManager: routerManager,
groupsManager: groupsManager,
accountManager: accountManager,
extractFromToken: extractFromToken,
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
networksManager: networksManager,
resourceManager: resourceManager,
routerManager: routerManager,
groupsManager: groupsManager,
accountManager: accountManager,
}
}
func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.extractFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
networks, err := h.networksManager.GetAllNetworks(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
@ -105,12 +98,12 @@ func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request) {
}
func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.extractFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
var req api.NetworkRequest
err = json.NewDecoder(r.Body).Decode(&req)
@ -141,12 +134,12 @@ func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request) {
}
func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.extractFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
networkID := vars["networkId"]
@ -179,13 +172,13 @@ func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request) {
}
func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.extractFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
networkID := vars["networkId"]
if len(networkID) == 0 {
@ -229,13 +222,13 @@ func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) {
}
func (h *handler) deleteNetwork(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.extractFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
networkID := vars["networkId"]
if len(networkID) == 0 {

View File

@ -1,30 +1,26 @@
package networks
import (
"context"
"encoding/json"
"net/http"
"github.com/gorilla/mux"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/resources/types"
)
type resourceHandler struct {
resourceManager resources.Manager
groupsManager groups.Manager
extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
claimsExtractor *jwtclaims.ClaimsExtractor
resourceManager resources.Manager
groupsManager groups.Manager
}
func addResourceEndpoints(resourcesManager resources.Manager, groupsManager groups.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg, router *mux.Router) {
resourceHandler := newResourceHandler(resourcesManager, groupsManager, extractFromToken, authCfg)
func addResourceEndpoints(resourcesManager resources.Manager, groupsManager groups.Manager, router *mux.Router) {
resourceHandler := newResourceHandler(resourcesManager, groupsManager)
router.HandleFunc("/networks/resources", resourceHandler.getAllResourcesInAccount).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/resources", resourceHandler.getAllResourcesInNetwork).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/resources", resourceHandler.createResource).Methods("POST", "OPTIONS")
@ -33,26 +29,21 @@ func addResourceEndpoints(resourcesManager resources.Manager, groupsManager grou
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.deleteResource).Methods("DELETE", "OPTIONS")
}
func newResourceHandler(resourceManager resources.Manager, groupsManager groups.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg) *resourceHandler {
func newResourceHandler(resourceManager resources.Manager, groupsManager groups.Manager) *resourceHandler {
return &resourceHandler{
resourceManager: resourceManager,
groupsManager: groupsManager,
extractFromToken: extractFromToken,
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
resourceManager: resourceManager,
groupsManager: groupsManager,
}
}
func (h *resourceHandler) getAllResourcesInNetwork(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.extractFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
networkID := mux.Vars(r)["networkId"]
resources, err := h.resourceManager.GetAllResourcesInNetwork(r.Context(), accountID, userID, networkID)
if err != nil {
@ -76,13 +67,14 @@ func (h *resourceHandler) getAllResourcesInNetwork(w http.ResponseWriter, r *htt
util.WriteJSONObject(r.Context(), w, resourcesResponse)
}
func (h *resourceHandler) getAllResourcesInAccount(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.extractFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
resources, err := h.resourceManager.GetAllResourcesInAccount(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
@ -106,13 +98,14 @@ func (h *resourceHandler) getAllResourcesInAccount(w http.ResponseWriter, r *htt
}
func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.extractFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
var req api.NetworkResourceRequest
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
@ -144,13 +137,13 @@ func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request)
}
func (h *resourceHandler) getResource(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.extractFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
networkID := mux.Vars(r)["networkId"]
resourceID := mux.Vars(r)["resourceId"]
resource, err := h.resourceManager.GetResource(r.Context(), accountID, userID, networkID, resourceID)
@ -171,13 +164,13 @@ func (h *resourceHandler) getResource(w http.ResponseWriter, r *http.Request) {
}
func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.extractFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
var req api.NetworkResourceRequest
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
@ -209,12 +202,12 @@ func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request)
}
func (h *resourceHandler) deleteResource(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.extractFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
networkID := mux.Vars(r)["networkId"]
resourceID := mux.Vars(r)["resourceId"]

View File

@ -1,28 +1,24 @@
package networks
import (
"context"
"encoding/json"
"net/http"
"github.com/gorilla/mux"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/networks/routers"
"github.com/netbirdio/netbird/management/server/networks/routers/types"
)
type routersHandler struct {
routersManager routers.Manager
extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
claimsExtractor *jwtclaims.ClaimsExtractor
routersManager routers.Manager
}
func addRouterEndpoints(routersManager routers.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg, router *mux.Router) {
routersHandler := newRoutersHandler(routersManager, extractFromToken, authCfg)
func addRouterEndpoints(routersManager routers.Manager, router *mux.Router) {
routersHandler := newRoutersHandler(routersManager)
router.HandleFunc("/networks/{networkId}/routers", routersHandler.getAllRouters).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers", routersHandler.createRouter).Methods("POST", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.getRouter).Methods("GET", "OPTIONS")
@ -30,25 +26,21 @@ func addRouterEndpoints(routersManager routers.Manager, extractFromToken func(ct
router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.deleteRouter).Methods("DELETE", "OPTIONS")
}
func newRoutersHandler(routersManager routers.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg) *routersHandler {
func newRoutersHandler(routersManager routers.Manager) *routersHandler {
return &routersHandler{
routersManager: routersManager,
extractFromToken: extractFromToken,
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
routersManager: routersManager,
}
}
func (h *routersHandler) getAllRouters(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.extractFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
networkID := mux.Vars(r)["networkId"]
routers, err := h.routersManager.GetAllRoutersInNetwork(r.Context(), accountID, userID, networkID)
if err != nil {
@ -65,13 +57,14 @@ func (h *routersHandler) getAllRouters(w http.ResponseWriter, r *http.Request) {
}
func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.extractFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
networkID := mux.Vars(r)["networkId"]
var req api.NetworkRouterRequest
err = json.NewDecoder(r.Body).Decode(&req)
@ -96,13 +89,14 @@ func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) {
}
func (h *routersHandler) getRouter(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.extractFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
routerID := mux.Vars(r)["routerId"]
networkID := mux.Vars(r)["networkId"]
router, err := h.routersManager.GetRouter(r.Context(), accountID, userID, networkID, routerID)
@ -115,13 +109,14 @@ func (h *routersHandler) getRouter(w http.ResponseWriter, r *http.Request) {
}
func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.extractFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
var req api.NetworkRouterRequest
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
@ -146,13 +141,13 @@ func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request) {
}
func (h *routersHandler) deleteRouter(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.extractFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
routerID := mux.Vars(r)["routerId"]
networkID := mux.Vars(r)["networkId"]
err = h.routersManager.DeleteRouter(r.Context(), accountID, userID, networkID, routerID)

View File

@ -10,11 +10,10 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
@ -22,12 +21,11 @@ import (
// Handler is a handler that returns peers of the account
type Handler struct {
accountManager server.AccountManager
claimsExtractor *jwtclaims.ClaimsExtractor
accountManager server.AccountManager
}
func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
peersHandler := NewHandler(accountManager, authCfg)
func AddEndpoints(accountManager server.AccountManager, router *mux.Router) {
peersHandler := NewHandler(accountManager)
router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS")
router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer).
Methods("GET", "PUT", "DELETE", "OPTIONS")
@ -35,13 +33,9 @@ func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg,
}
// NewHandler creates a new peers Handler
func NewHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *Handler {
func NewHandler(accountManager server.AccountManager) *Handler {
return &Handler{
accountManager: accountManager,
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
}
}
@ -149,12 +143,13 @@ func (h *Handler) deletePeer(ctx context.Context, accountID, userID string, peer
// HandlePeer handles all peer requests for GET, PUT and DELETE operations
func (h *Handler) HandlePeer(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
peerID := vars["peerId"]
if len(peerID) == 0 {
@ -179,13 +174,14 @@ func (h *Handler) HandlePeer(w http.ResponseWriter, r *http.Request) {
// GetAllPeers returns a list of all peers associated with a provided account
func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
peers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
@ -230,13 +226,14 @@ func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPee
// GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network.
func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
peerID := vars["peerId"]
if len(peerID) == 0 {

View File

@ -15,8 +15,8 @@ import (
"github.com/gorilla/mux"
"golang.org/x/exp/maps"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
@ -25,16 +25,13 @@ import (
"github.com/netbirdio/netbird/management/server/mock_server"
)
type ctxKey string
const (
testPeerID = "test_peer"
noUpdateChannelTestPeerID = "no-update-channel"
adminUser = "admin_user"
regularUser = "regular_user"
serviceUser = "service_user"
userIDKey ctxKey = "user_id"
adminUser = "admin_user"
regularUser = "regular_user"
serviceUser = "service_user"
)
func initTestMetaData(peers ...*nbpeer.Peer) *Handler {
@ -146,9 +143,6 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler {
GetDNSDomainFunc: func() string {
return "netbird.selfhosted"
},
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return claims.AccountId, claims.UserId, nil
},
GetAccountFunc: func(ctx context.Context, accountID string) (*types.Account, error) {
return account, nil
},
@ -167,16 +161,6 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler {
return ok
},
},
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
userID := r.Context().Value(userIDKey).(string)
return jwtclaims.AuthorizationClaims{
UserId: userID,
Domain: "hotmail.com",
AccountId: "test_id",
}
}),
),
}
}
@ -267,8 +251,11 @@ func TestGetPeers(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
ctx := context.WithValue(context.Background(), userIDKey, "admin_user")
req = req.WithContext(ctx)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
UserId: "admin_user",
Domain: "hotmail.com",
AccountId: "test_id",
})
router := mux.NewRouter()
router.HandleFunc("/api/peers/", p.GetAllPeers).Methods("GET")
@ -412,8 +399,11 @@ func TestGetAccessiblePeers(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/peers/%s/accessible-peers", tc.peerID), nil)
ctx := context.WithValue(context.Background(), userIDKey, tc.callerUserID)
req = req.WithContext(ctx)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
UserId: tc.callerUserID,
Domain: "hotmail.com",
AccountId: "test_id",
})
router := mux.NewRouter()
router.HandleFunc("/api/peers/{peerId}/accessible-peers", p.GetAccessiblePeers).Methods("GET")

View File

@ -13,9 +13,9 @@ import (
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util"
@ -43,23 +43,11 @@ func initGeolocationTestData(t *testing.T) *geolocationsHandler {
return &geolocationsHandler{
accountManager: &mock_server.MockAccountManager{
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return claims.AccountId, claims.UserId, nil
},
GetUserByIDFunc: func(ctx context.Context, id string) (*types.User, error) {
return types.NewAdminUser(id), nil
},
},
geolocationManager: geo,
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",
}
}),
),
}
}
@ -112,6 +100,11 @@ func TestGetCitiesByCountry(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",
})
router := mux.NewRouter()
router.HandleFunc("/api/locations/countries/{country}/cities", geolocationHandler.getCitiesByCountry).Methods("GET")
@ -200,6 +193,11 @@ func TestGetAllCountries(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",
})
router := mux.NewRouter()
router.HandleFunc("/api/locations/countries", geolocationHandler.getAllCountries).Methods("GET")

View File

@ -7,11 +7,10 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
)
@ -23,24 +22,19 @@ var (
type geolocationsHandler struct {
accountManager server.AccountManager
geolocationManager geolocation.Geolocation
claimsExtractor *jwtclaims.ClaimsExtractor
}
func addLocationsEndpoint(accountManager server.AccountManager, locationManager geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) {
locationHandler := newGeolocationsHandlerHandler(accountManager, locationManager, authCfg)
func addLocationsEndpoint(accountManager server.AccountManager, locationManager geolocation.Geolocation, router *mux.Router) {
locationHandler := newGeolocationsHandlerHandler(accountManager, locationManager)
router.HandleFunc("/locations/countries", locationHandler.getAllCountries).Methods("GET", "OPTIONS")
router.HandleFunc("/locations/countries/{country}/cities", locationHandler.getCitiesByCountry).Methods("GET", "OPTIONS")
}
// newGeolocationsHandlerHandler creates a new Geolocations handler
func newGeolocationsHandlerHandler(accountManager server.AccountManager, geolocationManager geolocation.Geolocation, authCfg configs.AuthCfg) *geolocationsHandler {
func newGeolocationsHandlerHandler(accountManager server.AccountManager, geolocationManager geolocation.Geolocation) *geolocationsHandler {
return &geolocationsHandler{
accountManager: accountManager,
geolocationManager: geolocationManager,
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
}
}
@ -104,12 +98,13 @@ func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.
}
func (l *geolocationsHandler) authenticateUser(r *http.Request) error {
claims := l.claimsExtractor.FromRequestContext(r)
_, userID, err := l.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
return err
}
_, userID := userAuth.AccountId, userAuth.UserId
user, err := l.accountManager.GetUserByID(r.Context(), userID)
if err != nil {
return err

View File

@ -8,51 +8,46 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
)
// handler is a handler that returns policy of the account
type handler struct {
accountManager server.AccountManager
claimsExtractor *jwtclaims.ClaimsExtractor
accountManager server.AccountManager
}
func AddEndpoints(accountManager server.AccountManager, locationManager geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) {
policiesHandler := newHandler(accountManager, authCfg)
func AddEndpoints(accountManager server.AccountManager, locationManager geolocation.Geolocation, router *mux.Router) {
policiesHandler := newHandler(accountManager)
router.HandleFunc("/policies", policiesHandler.getAllPolicies).Methods("GET", "OPTIONS")
router.HandleFunc("/policies", policiesHandler.createPolicy).Methods("POST", "OPTIONS")
router.HandleFunc("/policies/{policyId}", policiesHandler.updatePolicy).Methods("PUT", "OPTIONS")
router.HandleFunc("/policies/{policyId}", policiesHandler.getPolicy).Methods("GET", "OPTIONS")
router.HandleFunc("/policies/{policyId}", policiesHandler.deletePolicy).Methods("DELETE", "OPTIONS")
addPostureCheckEndpoint(accountManager, locationManager, authCfg, router)
addPostureCheckEndpoint(accountManager, locationManager, router)
}
// newHandler creates a new policies handler
func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler {
func newHandler(accountManager server.AccountManager) *handler {
return &handler{
accountManager: accountManager,
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
}
}
// getAllPolicies list for the account
func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
listPolicies, err := h.accountManager.ListPolicies(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
@ -80,13 +75,14 @@ func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request) {
// updatePolicy handles update to a policy identified by a given ID
func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
policyID := vars["policyId"]
if len(policyID) == 0 {
@ -105,13 +101,14 @@ func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request) {
// createPolicy handles policy creation request
func (h *handler) createPolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
h.savePolicy(w, r, accountID, userID, "")
}
@ -306,13 +303,13 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
// deletePolicy handles policy deletion request
func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
policyID := vars["policyId"]
if len(policyID) == 0 {
@ -330,13 +327,14 @@ func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request) {
// getPolicy handles a group Get request identified by ID
func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
policyID := vars["policyId"]
if len(policyID) == 0 {

View File

@ -13,8 +13,8 @@ import (
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
@ -44,9 +44,6 @@ func initPoliciesTestData(policies ...*types.Policy) *handler {
GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*types.Group, error) {
return []*types.Group{{ID: "F"}, {ID: "G"}}, nil
},
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return claims.AccountId, claims.UserId, nil
},
GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*types.Account, error) {
user := types.NewAdminUser(userID)
return &types.Account{
@ -65,15 +62,6 @@ func initPoliciesTestData(policies ...*types.Policy) *handler {
}, nil
},
},
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",
}
}),
),
}
}
@ -115,6 +103,11 @@ func TestPoliciesGetPolicy(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",
})
router := mux.NewRouter()
router.HandleFunc("/api/policies/{policyId}", p.getPolicy).Methods("GET")
@ -274,6 +267,11 @@ func TestPoliciesWritePolicy(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",
})
router := mux.NewRouter()
router.HandleFunc("/api/policies", p.createPolicy).Methods("POST")

View File

@ -7,11 +7,10 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status"
)
@ -20,40 +19,35 @@ import (
type postureChecksHandler struct {
accountManager server.AccountManager
geolocationManager geolocation.Geolocation
claimsExtractor *jwtclaims.ClaimsExtractor
}
func addPostureCheckEndpoint(accountManager server.AccountManager, locationManager geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) {
postureCheckHandler := newPostureChecksHandler(accountManager, locationManager, authCfg)
func addPostureCheckEndpoint(accountManager server.AccountManager, locationManager geolocation.Geolocation, router *mux.Router) {
postureCheckHandler := newPostureChecksHandler(accountManager, locationManager)
router.HandleFunc("/posture-checks", postureCheckHandler.getAllPostureChecks).Methods("GET", "OPTIONS")
router.HandleFunc("/posture-checks", postureCheckHandler.createPostureCheck).Methods("POST", "OPTIONS")
router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.updatePostureCheck).Methods("PUT", "OPTIONS")
router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.getPostureCheck).Methods("GET", "OPTIONS")
router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.deletePostureCheck).Methods("DELETE", "OPTIONS")
addLocationsEndpoint(accountManager, locationManager, authCfg, router)
addLocationsEndpoint(accountManager, locationManager, router)
}
// newPostureChecksHandler creates a new PostureChecks handler
func newPostureChecksHandler(accountManager server.AccountManager, geolocationManager geolocation.Geolocation, authCfg configs.AuthCfg) *postureChecksHandler {
func newPostureChecksHandler(accountManager server.AccountManager, geolocationManager geolocation.Geolocation) *postureChecksHandler {
return &postureChecksHandler{
accountManager: accountManager,
geolocationManager: geolocationManager,
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
}
}
// getAllPostureChecks list for the account
func (p *postureChecksHandler) getAllPostureChecks(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r)
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
listPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
@ -70,13 +64,14 @@ func (p *postureChecksHandler) getAllPostureChecks(w http.ResponseWriter, r *htt
// updatePostureCheck handles update to a posture check identified by a given ID
func (p *postureChecksHandler) updatePostureCheck(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r)
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
postureChecksID := vars["postureCheckId"]
if len(postureChecksID) == 0 {
@ -95,25 +90,26 @@ func (p *postureChecksHandler) updatePostureCheck(w http.ResponseWriter, r *http
// createPostureCheck handles posture check creation request
func (p *postureChecksHandler) createPostureCheck(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r)
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
p.savePostureChecks(w, r, accountID, userID, "")
}
// getPostureCheck handles a posture check Get request identified by ID
func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r)
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
postureChecksID := vars["postureCheckId"]
if len(postureChecksID) == 0 {
@ -132,13 +128,13 @@ func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Re
// deletePostureCheck handles posture check deletion request
func (p *postureChecksHandler) deletePostureCheck(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r)
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
postureChecksID := vars["postureCheckId"]
if len(postureChecksID) == 0 {

View File

@ -14,9 +14,9 @@ import (
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status"
@ -66,20 +66,8 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *postureChecksH
}
return accountPostureChecks, nil
},
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return claims.AccountId, claims.UserId, nil
},
},
geolocationManager: &geolocation.Mock{},
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",
}
}),
),
}
}
@ -187,6 +175,11 @@ func TestGetPostureCheck(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/posture-checks/"+tc.id, tc.requestBody)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",
})
router := mux.NewRouter()
router.HandleFunc("/api/posture-checks/{postureCheckId}", p.getPostureCheck).Methods("GET")
@ -835,6 +828,11 @@ func TestPostureCheckUpdate(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",
})
defaultHandler := *p
if tc.setupHandlerFunc != nil {

View File

@ -10,10 +10,9 @@ import (
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/server"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/route"
)
@ -22,12 +21,11 @@ const failedToConvertRoute = "failed to convert route to response: %v"
// handler is the routes handler of the account
type handler struct {
accountManager server.AccountManager
claimsExtractor *jwtclaims.ClaimsExtractor
accountManager server.AccountManager
}
func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
routesHandler := newHandler(accountManager, authCfg)
func AddEndpoints(accountManager server.AccountManager, router *mux.Router) {
routesHandler := newHandler(accountManager)
router.HandleFunc("/routes", routesHandler.getAllRoutes).Methods("GET", "OPTIONS")
router.HandleFunc("/routes", routesHandler.createRoute).Methods("POST", "OPTIONS")
router.HandleFunc("/routes/{routeId}", routesHandler.updateRoute).Methods("PUT", "OPTIONS")
@ -36,25 +34,22 @@ func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg,
}
// newHandler returns a new instance of routes handler
func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler {
func newHandler(accountManager server.AccountManager) *handler {
return &handler{
accountManager: accountManager,
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
}
}
// getAllRoutes returns the list of routes for the account
func (h *handler) getAllRoutes(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
routes, err := h.accountManager.ListRoutes(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
@ -75,13 +70,14 @@ func (h *handler) getAllRoutes(w http.ResponseWriter, r *http.Request) {
// createRoute handles route creation request
func (h *handler) createRoute(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
var req api.PostApiRoutesJSONRequestBody
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
@ -172,13 +168,13 @@ func (h *handler) validateRoute(req api.PostApiRoutesJSONRequestBody) error {
// updateRoute handles update to a route identified by a given ID
func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
routeID := vars["routeId"]
if len(routeID) == 0 {
@ -265,13 +261,13 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) {
// deleteRoute handles route deletion request
func (h *handler) deleteRoute(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
routeID := mux.Vars(r)["routeId"]
if len(routeID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w)
@ -289,13 +285,14 @@ func (h *handler) deleteRoute(w http.ResponseWriter, r *http.Request) {
// getRoute handles a route Get request identified by ID
func (h *handler) getRoute(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
routeID := mux.Vars(r)["routeId"]
if len(routeID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w)

View File

@ -16,12 +16,10 @@ import (
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/domain"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/route"
)
@ -60,32 +58,6 @@ var baseExistingRoute = &route.Route{
Groups: []string{existingGroupID},
}
var testingAccount = &types.Account{
Id: testAccountID,
Domain: "hotmail.com",
Peers: map[string]*nbpeer.Peer{
existingPeerID: {
Key: existingPeerKey,
IP: netip.MustParseAddr(existingPeerIP1).AsSlice(),
ID: existingPeerID,
Meta: nbpeer.PeerSystemMeta{
GoOS: "linux",
},
},
nonLinuxExistingPeerID: {
Key: nonLinuxExistingPeerID,
IP: netip.MustParseAddr(existingPeerIP2).AsSlice(),
ID: nonLinuxExistingPeerID,
Meta: nbpeer.PeerSystemMeta{
GoOS: "darwin",
},
},
},
Users: map[string]*types.User{
"test_user": types.NewAdminUser("test_user"),
},
}
func initRoutesTestData() *handler {
return &handler{
accountManager: &mock_server.MockAccountManager{
@ -150,20 +122,7 @@ func initRoutesTestData() *handler {
}
return nil
},
GetAccountIDFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) {
// return testingAccount, testingAccount.Users["test_user"], nil
return testingAccount.Id, testingAccount.Users["test_user"].Id, nil
},
},
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: testAccountID,
}
}),
),
}
}
@ -526,6 +485,11 @@ func TestRoutesHandlers(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: testAccountID,
})
router := mux.NewRouter()
router.HandleFunc("/api/routes/{routeId}", p.getRoute).Methods("GET")

View File

@ -10,22 +10,20 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
)
// handler is a handler that returns a list of setup keys of the account
type handler struct {
accountManager server.AccountManager
claimsExtractor *jwtclaims.ClaimsExtractor
accountManager server.AccountManager
}
func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
keysHandler := newHandler(accountManager, authCfg)
func AddEndpoints(accountManager server.AccountManager, router *mux.Router) {
keysHandler := newHandler(accountManager)
router.HandleFunc("/setup-keys", keysHandler.getAllSetupKeys).Methods("GET", "OPTIONS")
router.HandleFunc("/setup-keys", keysHandler.createSetupKey).Methods("POST", "OPTIONS")
router.HandleFunc("/setup-keys/{keyId}", keysHandler.getSetupKey).Methods("GET", "OPTIONS")
@ -34,25 +32,21 @@ func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg,
}
// newHandler creates a new setup key handler
func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler {
func newHandler(accountManager server.AccountManager) *handler {
return &handler{
accountManager: accountManager,
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
}
}
// createSetupKey is a POST requests that creates a new SetupKey
func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
req := &api.PostApiSetupKeysJSONRequestBody{}
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
@ -108,12 +102,12 @@ func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) {
// getSetupKey is a GET request to get a SetupKey by ID
func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
keyID := vars["keyId"]
@ -133,13 +127,13 @@ func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request) {
// updateSetupKey is a PUT request to update server.SetupKey
func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
keyID := vars["keyId"]
if len(keyID) == 0 {
@ -174,13 +168,13 @@ func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) {
// getAllSetupKeys is a GET request that returns a list of SetupKey
func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
@ -196,13 +190,13 @@ func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request) {
}
func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
keyID := vars["keyId"]
if len(keyID) == 0 {

View File

@ -14,8 +14,8 @@ import (
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
@ -28,14 +28,9 @@ const (
notFoundSetupKeyID = "notFoundSetupKeyID"
)
func initSetupKeysTestMetaData(defaultKey *types.SetupKey, newKey *types.SetupKey, updatedSetupKey *types.SetupKey,
user *types.User,
) *handler {
func initSetupKeysTestMetaData(defaultKey *types.SetupKey, newKey *types.SetupKey, updatedSetupKey *types.SetupKey) *handler {
return &handler{
accountManager: &mock_server.MockAccountManager{
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return claims.AccountId, claims.UserId, nil
},
CreateSetupKeyFunc: func(_ context.Context, _ string, keyName string, typ types.SetupKeyType, _ time.Duration, _ []string,
_ int, _ string, ephemeral bool, allowExtraDNSLabels bool,
) (*types.SetupKey, error) {
@ -76,15 +71,6 @@ func initSetupKeysTestMetaData(defaultKey *types.SetupKey, newKey *types.SetupKe
return status.Errorf(status.NotFound, "key %s not found", keyID)
},
},
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: user.Id,
Domain: "hotmail.com",
AccountId: "testAccountId",
}
}),
),
}
}
@ -171,12 +157,17 @@ func TestSetupKeysHandlers(t *testing.T) {
},
}
handler := initSetupKeysTestMetaData(defaultSetupKey, newSetupKey, updatedDefaultSetupKey, adminUser)
handler := initSetupKeysTestMetaData(defaultSetupKey, newSetupKey, updatedDefaultSetupKey)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
UserId: adminUser.Id,
Domain: "hotmail.com",
AccountId: "testAccountId",
})
router := mux.NewRouter()
router.HandleFunc("/api/setup-keys", handler.getAllSetupKeys).Methods("GET", "OPTIONS")

View File

@ -7,22 +7,20 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
)
// patHandler is the nameserver group handler of the account
type patHandler struct {
accountManager server.AccountManager
claimsExtractor *jwtclaims.ClaimsExtractor
accountManager server.AccountManager
}
func addUsersTokensEndpoint(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
tokenHandler := newPATsHandler(accountManager, authCfg)
func addUsersTokensEndpoint(accountManager server.AccountManager, router *mux.Router) {
tokenHandler := newPATsHandler(accountManager)
router.HandleFunc("/users/{userId}/tokens", tokenHandler.getAllTokens).Methods("GET", "OPTIONS")
router.HandleFunc("/users/{userId}/tokens", tokenHandler.createToken).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.getToken).Methods("GET", "OPTIONS")
@ -30,25 +28,21 @@ func addUsersTokensEndpoint(accountManager server.AccountManager, authCfg config
}
// newPATsHandler creates a new patHandler HTTP handler
func newPATsHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *patHandler {
func newPATsHandler(accountManager server.AccountManager) *patHandler {
return &patHandler{
accountManager: accountManager,
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
}
}
// getAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user
func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(userID) == 0 {
@ -72,13 +66,13 @@ func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request) {
// getToken is HTTP GET handler that returns a personal access token for the given user
func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
@ -103,13 +97,13 @@ func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) {
// createToken is HTTP POST handler that creates a personal access token for the given user
func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
@ -135,13 +129,13 @@ func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) {
// deleteToken is HTTP DELETE handler that deletes a personal access token for the given user
func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {

View File

@ -12,11 +12,12 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/util"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/util"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
@ -77,10 +78,6 @@ func initPATTestData() *patHandler {
PersonalAccessToken: types.PersonalAccessToken{},
}, nil
},
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return claims.AccountId, claims.UserId, nil
},
DeletePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
if accountID != existingAccountID {
return status.Errorf(status.NotFound, "account with ID %s not found", accountID)
@ -115,15 +112,6 @@ func initPATTestData() *patHandler {
return []*types.PersonalAccessToken{testAccount.Users[existingUserID].PATs[existingTokenID], testAccount.Users[existingUserID].PATs["token2"]}, nil
},
},
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: existingUserID,
Domain: testDomain,
AccountId: existingAccountID,
}
}),
),
}
}
@ -185,6 +173,11 @@ func TestTokenHandlers(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
UserId: existingUserID,
Domain: testDomain,
AccountId: existingAccountID,
})
router := mux.NewRouter()
router.HandleFunc("/api/users/{userId}/tokens", p.getAllTokens).Methods("GET")

View File

@ -9,39 +9,33 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/jwtclaims"
nbcontext "github.com/netbirdio/netbird/management/server/context"
)
// handler is a handler that returns users of the account
type handler struct {
accountManager server.AccountManager
claimsExtractor *jwtclaims.ClaimsExtractor
accountManager server.AccountManager
}
func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
userHandler := newHandler(accountManager, authCfg)
func AddEndpoints(accountManager server.AccountManager, router *mux.Router) {
userHandler := newHandler(accountManager)
router.HandleFunc("/users", userHandler.getAllUsers).Methods("GET", "OPTIONS")
router.HandleFunc("/users/{userId}", userHandler.updateUser).Methods("PUT", "OPTIONS")
router.HandleFunc("/users/{userId}", userHandler.deleteUser).Methods("DELETE", "OPTIONS")
router.HandleFunc("/users", userHandler.createUser).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/invite", userHandler.inviteUser).Methods("POST", "OPTIONS")
addUsersTokensEndpoint(accountManager, authCfg, router)
addUsersTokensEndpoint(accountManager, router)
}
// newHandler creates a new UsersHandler HTTP handler
func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler {
func newHandler(accountManager server.AccountManager) *handler {
return &handler{
accountManager: accountManager,
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
}
}
@ -52,13 +46,13 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) {
return
}
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
@ -103,7 +97,7 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId))
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userID))
}
// deleteUser is a DELETE request to delete a user
@ -113,13 +107,13 @@ func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request) {
return
}
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
@ -143,12 +137,12 @@ func (h *handler) createUser(w http.ResponseWriter, r *http.Request) {
return
}
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
req := &api.PostApiUsersJSONRequestBody{}
err = json.NewDecoder(r.Body).Decode(&req)
@ -184,7 +178,7 @@ func (h *handler) createUser(w http.ResponseWriter, r *http.Request) {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId))
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userID))
}
// getAllUsers returns a list of users of the account this user belongs to.
@ -195,13 +189,13 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
return
}
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
data, err := h.accountManager.GetUsersFromAccount(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
@ -216,7 +210,7 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
continue
}
if serviceUser == "" {
users = append(users, toUserResponse(d, claims.UserId))
users = append(users, toUserResponse(d, userID))
continue
}
@ -227,7 +221,7 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
return
}
if includeServiceUser == d.IsServiceUser {
users = append(users, toUserResponse(d, claims.UserId))
users = append(users, toUserResponse(d, userID))
}
}
@ -242,12 +236,12 @@ func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) {
return
}
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
targetUserID := vars["userId"]

View File

@ -13,8 +13,8 @@ import (
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
@ -64,9 +64,6 @@ var usersTestAccount = &types.Account{
func initUsersTestData() *handler {
return &handler{
accountManager: &mock_server.MockAccountManager{
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return usersTestAccount.Id, claims.UserId, nil
},
GetUserByIDFunc: func(ctx context.Context, id string) (*types.User, error) {
return usersTestAccount.Users[id], nil
},
@ -127,15 +124,6 @@ func initUsersTestData() *handler {
return nil
},
},
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: existingUserID,
Domain: testDomain,
AccountId: existingAccountID,
}
}),
),
}
}
@ -158,6 +146,11 @@ func TestGetUsers(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
UserId: existingUserID,
Domain: testDomain,
AccountId: existingAccountID,
})
userHandler.getAllUsers(recorder, req)
@ -263,6 +256,11 @@ func TestUpdateUser(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
UserId: existingUserID,
Domain: testDomain,
AccountId: existingAccountID,
})
router := mux.NewRouter()
router.HandleFunc("/api/users/{userId}", userHandler.updateUser).Methods("PUT")
@ -355,6 +353,11 @@ func TestCreateUser(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
rr := httptest.NewRecorder()
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
UserId: existingUserID,
Domain: testDomain,
AccountId: existingAccountID,
})
userHandler.createUser(rr, req)
@ -399,6 +402,12 @@ func TestInviteUser(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
req = mux.SetURLVars(req, tc.requestVars)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
UserId: existingUserID,
Domain: testDomain,
AccountId: existingAccountID,
})
rr := httptest.NewRecorder()
userHandler.inviteUser(rr, req)
@ -452,6 +461,12 @@ func TestDeleteUser(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
req = mux.SetURLVars(req, tc.requestVars)
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
UserId: existingUserID,
Domain: testDomain,
AccountId: existingAccountID,
})
rr := httptest.NewRecorder()
userHandler.deleteUser(rr, req)

View File

@ -7,30 +7,24 @@ import (
log "github.com/sirupsen/logrus"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/jwtclaims"
)
// GetUser function defines a function to fetch user from Account by jwtclaims.AuthorizationClaims
type GetUser func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error)
type GetUser func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error)
// AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only
type AccessControl struct {
claimsExtract jwtclaims.ClaimsExtractor
getUser GetUser
getUser GetUser
}
// NewAccessControl instance constructor
func NewAccessControl(audience, userIDClaim string, getUser GetUser) *AccessControl {
func NewAccessControl(getUser GetUser) *AccessControl {
return &AccessControl{
claimsExtract: *jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(audience),
jwtclaims.WithUserIDClaim(userIDClaim),
),
getUser: getUser,
}
}
@ -45,12 +39,16 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
return
}
claims := a.claimsExtract.FromRequestContext(r)
user, err := a.getUser(r.Context(), claims)
userAuth, err := nbcontext.GetUserAuthFromRequest(r)
if err != nil {
log.WithContext(r.Context()).Errorf("failed to get user from claims: %s", err)
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "invalid JWT"), w)
log.WithContext(r.Context()).Errorf("failed to get user auth from request: %s", err)
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "invalid user auth"), w)
}
user, err := a.getUser(r.Context(), userAuth)
if err != nil {
log.WithContext(r.Context()).Errorf("failed to get user: %s", err)
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "invalid user auth"), w)
return
}

View File

@ -8,67 +8,41 @@ import (
"strings"
"time"
"github.com/golang-jwt/jwt"
log "github.com/sirupsen/logrus"
nbContext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/auth"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
)
// GetAccountInfoFromPATFunc function
type GetAccountInfoFromPATFunc func(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error)
// ValidateAndParseTokenFunc function
type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error)
// MarkPATUsedFunc function
type MarkPATUsedFunc func(ctx context.Context, token string) error
// CheckUserAccessByJWTGroupsFunc function
type CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
type EnsureAccountFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error)
type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth nbcontext.UserAuth) error
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
type AuthMiddleware struct {
getAccountInfoFromPAT GetAccountInfoFromPATFunc
validateAndParseToken ValidateAndParseTokenFunc
markPATUsed MarkPATUsedFunc
checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc
claimsExtractor *jwtclaims.ClaimsExtractor
audience string
userIDClaim string
authManager auth.Manager
ensureAccount EnsureAccountFunc
syncUserJWTGroups SyncUserJWTGroupsFunc
}
const (
userProperty = "user"
)
// NewAuthMiddleware instance constructor
func NewAuthMiddleware(getAccountInfoFromPAT GetAccountInfoFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
markPATUsed MarkPATUsedFunc, checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc, claimsExtractor *jwtclaims.ClaimsExtractor,
audience string, userIdClaim string) *AuthMiddleware {
if userIdClaim == "" {
userIdClaim = jwtclaims.UserIDClaim
}
func NewAuthMiddleware(
authManager auth.Manager,
ensureAccount EnsureAccountFunc,
syncUserJWTGroups SyncUserJWTGroupsFunc,
) *AuthMiddleware {
return &AuthMiddleware{
getAccountInfoFromPAT: getAccountInfoFromPAT,
validateAndParseToken: validateAndParseToken,
markPATUsed: markPATUsed,
checkUserAccessByJWTGroups: checkUserAccessByJWTGroups,
claimsExtractor: claimsExtractor,
audience: audience,
userIDClaim: userIdClaim,
authManager: authManager,
ensureAccount: ensureAccount,
syncUserJWTGroups: syncUserJWTGroups,
}
}
// Handler method of the middleware which authenticates a user either by JWT claims or by PAT
func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if bypass.ShouldBypass(r.URL.Path, h, w, r) {
return
}
@ -84,108 +58,111 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
switch authType {
case "bearer":
err := m.checkJWTFromRequest(w, r, auth)
request, err := m.checkJWTFromRequest(r, auth)
if err != nil {
log.WithContext(r.Context()).Errorf("Error when validating JWT claims: %s", err.Error())
log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error())
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
return
}
h.ServeHTTP(w, request)
case "token":
err := m.checkPATFromRequest(w, r, auth)
request, err := m.checkPATFromRequest(r, auth)
if err != nil {
log.WithContext(r.Context()).Debugf("Error when validating PAT claims: %s", err.Error())
log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error())
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
return
}
h.ServeHTTP(w, request)
default:
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
return
}
claims := m.claimsExtractor.FromRequestContext(r)
//nolint
ctx := context.WithValue(r.Context(), nbContext.UserIDKey, claims.UserId)
//nolint
ctx = context.WithValue(ctx, nbContext.AccountIDKey, claims.AccountId)
h.ServeHTTP(w, r.WithContext(ctx))
})
}
// CheckJWTFromRequest checks if the JWT is valid
func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error {
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, auth []string) (*http.Request, error) {
token, err := getTokenFromJWTRequest(auth)
// If an error occurs, call the error handler and return an error
if err != nil {
return fmt.Errorf("Error extracting token: %w", err)
return r, fmt.Errorf("error extracting token: %w", err)
}
validatedToken, err := m.validateAndParseToken(r.Context(), token)
ctx := r.Context()
userAuth, validatedToken, err := m.authManager.ValidateAndParseToken(ctx, token)
if err != nil {
return err
return r, err
}
if validatedToken == nil {
return nil
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
userAuth.AccountId = impersonate[0]
userAuth.IsChild = ok
}
if err := m.verifyUserAccess(r.Context(), validatedToken); err != nil {
return err
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
accountId, _, err := m.ensureAccount(ctx, userAuth)
if err != nil {
return r, err
}
// If we get here, everything worked and we can set the
// user property in context.
newRequest := r.WithContext(context.WithValue(r.Context(), userProperty, validatedToken)) //nolint
// Update the current request with the new context information.
*r = *newRequest
return nil
}
if userAuth.AccountId != accountId {
log.WithContext(ctx).Debugf("Auth middleware sets accountId from ensure, before %s, now %s", userAuth.AccountId, accountId)
userAuth.AccountId = accountId
}
// verifyUserAccess checks if a user, based on a validated JWT token,
// is allowed access, particularly in cases where the admin enabled JWT
// group propagation and designated certain groups with access permissions.
func (m *AuthMiddleware) verifyUserAccess(ctx context.Context, validatedToken *jwt.Token) error {
authClaims := m.claimsExtractor.FromToken(validatedToken)
return m.checkUserAccessByJWTGroups(ctx, authClaims)
userAuth, err = m.authManager.EnsureUserAccessByJWTGroups(ctx, userAuth, validatedToken)
if err != nil {
return r, err
}
err = m.syncUserJWTGroups(ctx, userAuth)
if err != nil {
log.WithContext(ctx).Errorf("HTTP server failed to sync user JWT groups: %s", err)
}
return nbcontext.SetUserAuthInRequest(r, userAuth), nil
}
// CheckPATFromRequest checks if the PAT is valid
func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error {
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*http.Request, error) {
token, err := getTokenFromPATRequest(auth)
if err != nil {
return fmt.Errorf("error extracting token: %w", err)
return r, fmt.Errorf("error extracting token: %w", err)
}
user, pat, accDomain, accCategory, err := m.getAccountInfoFromPAT(r.Context(), token)
ctx := r.Context()
user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token)
if err != nil {
return fmt.Errorf("invalid Token: %w", err)
return r, fmt.Errorf("invalid Token: %w", err)
}
if time.Now().After(pat.GetExpirationDate()) {
return fmt.Errorf("token expired")
return r, fmt.Errorf("token expired")
}
err = m.markPATUsed(r.Context(), pat.ID)
err = m.authManager.MarkPATUsed(ctx, pat.ID)
if err != nil {
return err
return r, err
}
claimMaps := jwt.MapClaims{}
claimMaps[m.userIDClaim] = user.Id
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = user.AccountID
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = accDomain
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = accCategory
claimMaps[jwtclaims.IsToken] = true
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint
// Update the current request with the new context information.
*r = *newRequest
return nil
userAuth := nbcontext.UserAuth{
UserId: user.Id,
AccountId: user.AccountID,
Domain: accDomain,
DomainCategory: accCategory,
IsPAT: true,
}
return nbcontext.SetUserAuthInRequest(r, userAuth), nil
}
// getTokenFromJWTRequest is a "TokenExtractor" that takes auth header parts and extracts
// the JWT token from the Authorization header.
func getTokenFromJWTRequest(authHeaderParts []string) (string, error) {
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
return "", errors.New("Authorization header format must be Bearer {token}")
return "", errors.New("authorization header format must be Bearer {token}")
}
return authHeaderParts[1], nil
@ -195,7 +172,7 @@ func getTokenFromJWTRequest(authHeaderParts []string) (string, error) {
// the PAT token from the Authorization header.
func getTokenFromPATRequest(authHeaderParts []string) (string, error) {
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "token" {
return "", errors.New("Authorization header format must be Token {token}")
return "", errors.New("authorization header format must be Token {token}")
}
return authHeaderParts[1], nil

View File

@ -9,10 +9,14 @@ import (
"time"
"github.com/golang-jwt/jwt"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/auth"
nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/types"
)
@ -58,17 +62,23 @@ func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *types.Use
return nil, nil, "", "", fmt.Errorf("PAT invalid")
}
func mockValidateAndParseToken(_ context.Context, token string) (*jwt.Token, error) {
func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) {
if token == JWT {
return &jwt.Token{
Claims: jwt.MapClaims{
userIDClaim: userID,
audience + jwtclaims.AccountIDSuffix: accountID,
return nbcontext.UserAuth{
UserId: userID,
AccountId: accountID,
Domain: testAccount.Domain,
DomainCategory: testAccount.DomainCategory,
},
Valid: true,
}, nil
&jwt.Token{
Claims: jwt.MapClaims{
userIDClaim: userID,
audience + nbjwt.AccountIDSuffix: accountID,
},
Valid: true,
}, nil
}
return nil, fmt.Errorf("JWT invalid")
return nbcontext.UserAuth{}, nil, fmt.Errorf("JWT invalid")
}
func mockMarkPATUsed(_ context.Context, token string) error {
@ -78,16 +88,20 @@ func mockMarkPATUsed(_ context.Context, token string) error {
return fmt.Errorf("Should never get reached")
}
func mockCheckUserAccessByJWTGroups(_ context.Context, claims jwtclaims.AuthorizationClaims) error {
if testAccount.Id != claims.AccountId {
return fmt.Errorf("account with id %s does not exist", claims.AccountId)
func mockEnsureUserAccessByJWTGroups(_ context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) {
if userAuth.IsChild || userAuth.IsPAT {
return userAuth, nil
}
if _, ok := testAccount.Users[claims.UserId]; !ok {
return fmt.Errorf("user with id %s does not exist", claims.UserId)
if testAccount.Id != userAuth.AccountId {
return userAuth, fmt.Errorf("account with id %s does not exist", userAuth.AccountId)
}
return nil
if _, ok := testAccount.Users[userAuth.UserId]; !ok {
return userAuth, fmt.Errorf("user with id %s does not exist", userAuth.UserId)
}
return userAuth, nil
}
func TestAuthMiddleware_Handler(t *testing.T) {
@ -158,22 +172,24 @@ func TestAuthMiddleware_Handler(t *testing.T) {
}
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// do nothing
})
claimsExtractor := jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(audience),
jwtclaims.WithUserIDClaim(userIDClaim),
)
mockAuth := &auth.MockManager{
ValidateAndParseTokenFunc: mockValidateAndParseToken,
EnsureUserAccessByJWTGroupsFunc: mockEnsureUserAccessByJWTGroups,
MarkPATUsedFunc: mockMarkPATUsed,
GetPATInfoFunc: mockGetAccountInfoFromPAT,
}
authMiddleware := NewAuthMiddleware(
mockGetAccountInfoFromPAT,
mockValidateAndParseToken,
mockMarkPATUsed,
mockCheckUserAccessByJWTGroups,
claimsExtractor,
audience,
userIDClaim,
mockAuth,
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
return nil
},
)
handlerToTest := authMiddleware.Handler(nextHandler)
@ -195,9 +211,115 @@ func TestAuthMiddleware_Handler(t *testing.T) {
result := rec.Result()
defer result.Body.Close()
if result.StatusCode != tc.expectedStatusCode {
t.Errorf("expected status code %d, got %d", tc.expectedStatusCode, result.StatusCode)
}
})
}
}
func TestAuthMiddleware_Handler_Child(t *testing.T) {
tt := []struct {
name string
path string
authHeader string
expectedUserAuth *nbcontext.UserAuth // nil expects 401 response status
}{
{
name: "Valid PAT Token",
path: "/test",
authHeader: "Token " + PAT,
expectedUserAuth: &nbcontext.UserAuth{
AccountId: accountID,
UserId: userID,
Domain: testAccount.Domain,
DomainCategory: testAccount.DomainCategory,
IsPAT: true,
},
},
{
name: "Valid PAT Token ignores child",
path: "/test?account=xyz",
authHeader: "Token " + PAT,
expectedUserAuth: &nbcontext.UserAuth{
AccountId: accountID,
UserId: userID,
Domain: testAccount.Domain,
DomainCategory: testAccount.DomainCategory,
IsPAT: true,
},
},
{
name: "Valid JWT Token",
path: "/test",
authHeader: "Bearer " + JWT,
expectedUserAuth: &nbcontext.UserAuth{
AccountId: accountID,
UserId: userID,
Domain: testAccount.Domain,
DomainCategory: testAccount.DomainCategory,
},
},
{
name: "Valid JWT Token with child",
path: "/test?account=xyz",
authHeader: "Bearer " + JWT,
expectedUserAuth: &nbcontext.UserAuth{
AccountId: "xyz",
UserId: userID,
Domain: testAccount.Domain,
DomainCategory: testAccount.DomainCategory,
IsChild: true,
},
},
}
mockAuth := &auth.MockManager{
ValidateAndParseTokenFunc: mockValidateAndParseToken,
EnsureUserAccessByJWTGroupsFunc: mockEnsureUserAccessByJWTGroups,
MarkPATUsedFunc: mockMarkPATUsed,
GetPATInfoFunc: mockGetAccountInfoFromPAT,
}
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
return nil
},
)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
handlerToTest := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromRequest(r)
if tc.expectedUserAuth != nil {
assert.NoError(t, err)
assert.Equal(t, *tc.expectedUserAuth, userAuth)
} else {
assert.Error(t, err)
assert.Empty(t, userAuth)
}
}))
req := httptest.NewRequest("GET", "http://testing"+tc.path, nil)
req.Header.Set("Authorization", tc.authHeader)
rec := httptest.NewRecorder()
handlerToTest.ServeHTTP(rec, req)
result := rec.Result()
defer result.Body.Close()
if tc.expectedUserAuth != nil {
assert.Equal(t, 200, result.StatusCode)
} else {
assert.Equal(t, 401, result.StatusCode)
}
})
}
}

View File

@ -77,13 +77,13 @@ func BenchmarkUpdatePeer(b *testing.B) {
func BenchmarkGetOnePeer(b *testing.B) {
var expectedMetrics = map[string]testing_tools.PerformanceMetrics{
"Peers - XS": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 70},
"Peers - S": {MinMsPerOpLocal: 1, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 30},
"Peers - M": {MinMsPerOpLocal: 9, MaxMsPerOpLocal: 18, MinMsPerOpCICD: 15, MaxMsPerOpCICD: 50},
"Peers - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 130},
"Groups - L": {MinMsPerOpLocal: 80, MaxMsPerOpLocal: 130, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 200},
"Users - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 130},
"Setup Keys - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 130},
"Peers - XS": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 70},
"Peers - S": {MinMsPerOpLocal: 1, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 70},
"Peers - M": {MinMsPerOpLocal: 9, MaxMsPerOpLocal: 18, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 70},
"Peers - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 5, MaxMsPerOpCICD: 200},
"Groups - L": {MinMsPerOpLocal: 80, MaxMsPerOpLocal: 130, MinMsPerOpCICD: 5, MaxMsPerOpCICD: 200},
"Users - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 5, MaxMsPerOpCICD: 200},
"Setup Keys - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 5, MaxMsPerOpCICD: 200},
"Peers - XL": {MinMsPerOpLocal: 200, MaxMsPerOpLocal: 400, MinMsPerOpCICD: 200, MaxMsPerOpCICD: 750},
}
@ -111,9 +111,9 @@ func BenchmarkGetOnePeer(b *testing.B) {
func BenchmarkGetAllPeers(b *testing.B) {
var expectedMetrics = map[string]testing_tools.PerformanceMetrics{
"Peers - XS": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 70, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 150},
"Peers - S": {MinMsPerOpLocal: 2, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 5, MaxMsPerOpCICD: 30},
"Peers - M": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 70},
"Peers - XS": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 70, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 100},
"Peers - S": {MinMsPerOpLocal: 2, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 100},
"Peers - M": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 100},
"Peers - L": {MinMsPerOpLocal: 110, MaxMsPerOpLocal: 150, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 300},
"Groups - L": {MinMsPerOpLocal: 150, MaxMsPerOpLocal: 200, MinMsPerOpCICD: 130, MaxMsPerOpCICD: 500},
"Users - L": {MinMsPerOpLocal: 100, MaxMsPerOpLocal: 170, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 400},

View File

@ -48,13 +48,12 @@ func BenchmarkUpdateUser(b *testing.B) {
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
recorder := httptest.NewRecorder()
for name, bc := range benchCasesUsers {
b.Run(name, func(b *testing.B) {
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
recorder := httptest.NewRecorder()
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
@ -97,13 +96,12 @@ func BenchmarkGetOneUser(b *testing.B) {
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
recorder := httptest.NewRecorder()
for name, bc := range benchCasesUsers {
b.Run(name, func(b *testing.B) {
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
recorder := httptest.NewRecorder()
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
@ -118,26 +116,25 @@ func BenchmarkGetOneUser(b *testing.B) {
func BenchmarkGetAllUsers(b *testing.B) {
var expectedMetrics = map[string]testing_tools.PerformanceMetrics{
"Users - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 10},
"Users - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 10},
"Users - M": {MinMsPerOpLocal: 3, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 5, MaxMsPerOpCICD: 15},
"Users - L": {MinMsPerOpLocal: 10, MaxMsPerOpLocal: 20, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 50},
"Peers - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 25, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 55},
"Groups - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 25, MinMsPerOpCICD: 25, MaxMsPerOpCICD: 55},
"Setup Keys - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 25, MinMsPerOpCICD: 25, MaxMsPerOpCICD: 55},
"Users - XL": {MinMsPerOpLocal: 80, MaxMsPerOpLocal: 120, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 300},
"Users - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 75},
"Users - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 75},
"Users - M": {MinMsPerOpLocal: 3, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 75},
"Users - L": {MinMsPerOpLocal: 10, MaxMsPerOpLocal: 20, MinMsPerOpCICD: 10, MaxMsPerOpCICD: 100},
"Peers - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 25, MinMsPerOpCICD: 10, MaxMsPerOpCICD: 100},
"Groups - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 25, MinMsPerOpCICD: 10, MaxMsPerOpCICD: 100},
"Setup Keys - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 25, MinMsPerOpCICD: 10, MaxMsPerOpCICD: 100},
"Users - XL": {MinMsPerOpLocal: 80, MaxMsPerOpLocal: 120, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 300},
}
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
recorder := httptest.NewRecorder()
for name, bc := range benchCasesUsers {
b.Run(name, func(b *testing.B) {
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
recorder := httptest.NewRecorder()
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
@ -152,26 +149,25 @@ func BenchmarkGetAllUsers(b *testing.B) {
func BenchmarkDeleteUsers(b *testing.B) {
var expectedMetrics = map[string]testing_tools.PerformanceMetrics{
"Users - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
"Users - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
"Users - M": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
"Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
"Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
"Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
"Setup Keys - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
"Users - XL": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
"Users - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 50},
"Users - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 50},
"Users - M": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 50},
"Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 50},
"Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 50},
"Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 50},
"Setup Keys - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 50},
"Users - XL": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 50},
}
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
recorder := httptest.NewRecorder()
for name, bc := range benchCasesUsers {
b.Run(name, func(b *testing.B) {
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, 1000, bc.SetupKeys)
recorder := httptest.NewRecorder()
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {

View File

@ -3,6 +3,7 @@ package testing_tools
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
@ -13,17 +14,17 @@ import (
"testing"
"time"
"github.com/netbirdio/netbird/management/server/util"
"github.com/golang-jwt/jwt"
"github.com/stretchr/testify/assert"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/auth"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/groups"
nbhttp "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
@ -32,6 +33,7 @@ import (
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util"
)
const (
@ -115,11 +117,20 @@ func BuildApiBlackBoxWithDBState(t TB, sqlFile string, expectedPeerUpdate *serve
t.Fatalf("Failed to create manager: %v", err)
}
// @note this is required so that PAT's validate from store, but JWT's are mocked
authManager := auth.NewManager(store, "", "", "", "", []string{}, false)
authManagerMock := &auth.MockManager{
ValidateAndParseTokenFunc: mockValidateAndParseToken,
EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups,
MarkPATUsedFunc: authManager.MarkPATUsed,
GetPATInfoFunc: authManager.GetPATInfo,
}
networksManagerMock := networks.NewManagerMock()
resourcesManagerMock := resources.NewManagerMock()
routersManagerMock := routers.NewManagerMock()
groupsManagerMock := groups.NewManagerMock()
apiHandler, err := nbhttp.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, &jwtclaims.JwtValidatorMock{}, metrics, configs.AuthCfg{}, validatorMock)
apiHandler, err := nbhttp.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, &server.Config{}, validatorMock)
if err != nil {
t.Fatalf("Failed to create API handler: %v", err)
}
@ -309,3 +320,25 @@ func EvaluateBenchmarkResults(b *testing.B, name string, duration time.Duration,
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", name, msPerOp, maxExpected)
}
}
func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) {
userAuth := nbcontext.UserAuth{}
switch token {
case "testUserId", "testAdminId", "testOwnerId", "testServiceUserId", "testServiceAdminId", "blockedUserId":
userAuth.UserId = token
userAuth.AccountId = "testAccountId"
userAuth.Domain = "test.com"
userAuth.DomainCategory = "private"
case "otherUserId":
userAuth.UserId = "otherUserId"
userAuth.AccountId = "otherAccountId"
userAuth.Domain = "other.com"
userAuth.DomainCategory = "private"
case "invalidToken":
return userAuth, nil, errors.New("invalid token")
}
jwtToken := jwt.New(jwt.SigningMethodHS256)
return userAuth, jwtToken, nil
}

View File

@ -1,19 +0,0 @@
package jwtclaims
import (
"time"
"github.com/golang-jwt/jwt"
)
// AuthorizationClaims stores authorization information from JWTs
type AuthorizationClaims struct {
UserId string
AccountId string
Domain string
DomainCategory string
LastLogin time.Time
Invited bool
Raw jwt.MapClaims
}

View File

@ -1,227 +0,0 @@
package jwtclaims
import (
"context"
"net/http"
"testing"
"time"
"github.com/golang-jwt/jwt"
"github.com/stretchr/testify/require"
)
func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audience string) *http.Request {
t.Helper()
const layout = "2006-01-02T15:04:05.999Z"
claimMaps := jwt.MapClaims{}
if claims.UserId != "" {
claimMaps[UserIDClaim] = claims.UserId
}
if claims.AccountId != "" {
claimMaps[audience+AccountIDSuffix] = claims.AccountId
}
if claims.Domain != "" {
claimMaps[audience+DomainIDSuffix] = claims.Domain
}
if claims.DomainCategory != "" {
claimMaps[audience+DomainCategorySuffix] = claims.DomainCategory
}
if claims.LastLogin != (time.Time{}) {
claimMaps[audience+LastLoginSuffix] = claims.LastLogin.Format(layout)
}
if claims.Invited {
claimMaps[audience+Invited] = true
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
r, err := http.NewRequest(http.MethodGet, "http://localhost", nil)
require.NoError(t, err, "creating testing request failed")
testRequest := r.WithContext(context.WithValue(r.Context(), TokenUserProperty, token)) // nolint
return testRequest
}
func TestExtractClaimsFromRequestContext(t *testing.T) {
type test struct {
name string
inputAuthorizationClaims AuthorizationClaims
inputAudiance string
testingFunc require.ComparisonAssertionFunc
expectedMSG string
}
const layout = "2006-01-02T15:04:05.999Z"
lastLogin, _ := time.Parse(layout, "2023-08-17T09:30:40.465Z")
testCase1 := test{
name: "All Claim Fields",
inputAudiance: "https://login/",
inputAuthorizationClaims: AuthorizationClaims{
UserId: "test",
Domain: "test.com",
AccountId: "testAcc",
LastLogin: lastLogin,
DomainCategory: "public",
Invited: true,
Raw: jwt.MapClaims{
"https://login/wt_account_domain": "test.com",
"https://login/wt_account_domain_category": "public",
"https://login/wt_account_id": "testAcc",
"https://login/nb_last_login": lastLogin.Format(layout),
"sub": "test",
"https://login/" + Invited: true,
},
},
testingFunc: require.EqualValues,
expectedMSG: "extracted claims should match input claims",
}
testCase2 := test{
name: "Domain Is Empty",
inputAudiance: "https://login/",
inputAuthorizationClaims: AuthorizationClaims{
UserId: "test",
AccountId: "testAcc",
Raw: jwt.MapClaims{
"https://login/wt_account_id": "testAcc",
"sub": "test",
},
},
testingFunc: require.EqualValues,
expectedMSG: "extracted claims should match input claims",
}
testCase3 := test{
name: "Account ID Is Empty",
inputAudiance: "https://login/",
inputAuthorizationClaims: AuthorizationClaims{
UserId: "test",
Domain: "test.com",
Raw: jwt.MapClaims{
"https://login/wt_account_domain": "test.com",
"sub": "test",
},
},
testingFunc: require.EqualValues,
expectedMSG: "extracted claims should match input claims",
}
testCase4 := test{
name: "Category Is Empty",
inputAudiance: "https://login/",
inputAuthorizationClaims: AuthorizationClaims{
UserId: "test",
Domain: "test.com",
AccountId: "testAcc",
Raw: jwt.MapClaims{
"https://login/wt_account_domain": "test.com",
"https://login/wt_account_id": "testAcc",
"sub": "test",
},
},
testingFunc: require.EqualValues,
expectedMSG: "extracted claims should match input claims",
}
testCase5 := test{
name: "Only User ID Is set",
inputAudiance: "https://login/",
inputAuthorizationClaims: AuthorizationClaims{
UserId: "test",
Raw: jwt.MapClaims{
"sub": "test",
},
},
testingFunc: require.EqualValues,
expectedMSG: "extracted claims should match input claims",
}
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5} {
t.Run(testCase.name, func(t *testing.T) {
request := newTestRequestWithJWT(t, testCase.inputAuthorizationClaims, testCase.inputAudiance)
extractor := NewClaimsExtractor(WithAudience(testCase.inputAudiance))
extractedClaims := extractor.FromRequestContext(request)
testCase.testingFunc(t, testCase.inputAuthorizationClaims, extractedClaims, testCase.expectedMSG)
})
}
}
func TestExtractClaimsSetOptions(t *testing.T) {
t.Helper()
type test struct {
name string
extractor *ClaimsExtractor
check func(t *testing.T, c test)
}
testCase1 := test{
name: "No custom options",
extractor: NewClaimsExtractor(),
check: func(t *testing.T, c test) {
t.Helper()
if c.extractor.authAudience != "" {
t.Error("audience should be empty")
return
}
if c.extractor.userIDClaim != UserIDClaim {
t.Errorf("user id claim should be default, expected %s, got %s", UserIDClaim, c.extractor.userIDClaim)
return
}
if c.extractor.FromRequestContext == nil {
t.Error("from request context should not be nil")
return
}
},
}
testCase2 := test{
name: "Custom audience",
extractor: NewClaimsExtractor(WithAudience("https://login/")),
check: func(t *testing.T, c test) {
t.Helper()
if c.extractor.authAudience != "https://login/" {
t.Errorf("audience expected %s, got %s", "https://login/", c.extractor.authAudience)
return
}
},
}
testCase3 := test{
name: "Custom user id claim",
extractor: NewClaimsExtractor(WithUserIDClaim("customUserId")),
check: func(t *testing.T, c test) {
t.Helper()
if c.extractor.userIDClaim != "customUserId" {
t.Errorf("user id claim expected %s, got %s", "customUserId", c.extractor.userIDClaim)
return
}
},
}
testCase4 := test{
name: "Custom extractor from request context",
extractor: NewClaimsExtractor(
WithFromRequestContext(func(r *http.Request) AuthorizationClaims {
return AuthorizationClaims{
UserId: "testCustomRequest",
}
})),
check: func(t *testing.T, c test) {
t.Helper()
claims := c.extractor.FromRequestContext(&http.Request{})
if claims.UserId != "testCustomRequest" {
t.Errorf("user id claim expected %s, got %s", "testCustomRequest", claims.UserId)
return
}
},
}
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4} {
t.Run(testCase.name, func(t *testing.T) {
testCase.check(t, testCase)
})
}
}

View File

@ -1,349 +0,0 @@
package jwtclaims
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"math/big"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/golang-jwt/jwt"
log "github.com/sirupsen/logrus"
)
// Options is a struct for specifying configuration options for the middleware.
type Options struct {
// The function that will return the Key to validate the JWT.
// It can be either a shared secret or a public key.
// Default value: nil
ValidationKeyGetter jwt.Keyfunc
// The name of the property in the request where the user information
// from the JWT will be stored.
// Default value: "user"
UserProperty string
// The function that will be called when there's an error validating the token
// Default value:
CredentialsOptional bool
// A function that extracts the token from the request
// Default: FromAuthHeader (i.e., from Authorization header as bearer token)
Debug bool
// When set, all requests with the OPTIONS method will use authentication
// Default: false
EnableAuthOnOptions bool
}
// Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation
type Jwks struct {
Keys []JSONWebKey `json:"keys"`
expiresInTime time.Time
}
// The supported elliptic curves types
const (
// p256 represents a cryptographic elliptical curve type.
p256 = "P-256"
// p384 represents a cryptographic elliptical curve type.
p384 = "P-384"
// p521 represents a cryptographic elliptical curve type.
p521 = "P-521"
)
// JSONWebKey is a representation of a Jason Web Key
type JSONWebKey struct {
Kty string `json:"kty"`
Kid string `json:"kid"`
Use string `json:"use"`
N string `json:"n"`
E string `json:"e"`
Crv string `json:"crv"`
X string `json:"x"`
Y string `json:"y"`
X5c []string `json:"x5c"`
}
type JWTValidator interface {
ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error)
}
// jwtValidatorImpl struct to handle token validation and parsing
type jwtValidatorImpl struct {
options Options
}
var keyNotFound = errors.New("unable to find appropriate key")
// NewJWTValidator constructor
func NewJWTValidator(ctx context.Context, issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) (JWTValidator, error) {
keys, err := getPemKeys(ctx, keysLocation)
if err != nil {
return nil, err
}
var lock sync.Mutex
options := Options{
ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) {
// Verify 'aud' claim
var checkAud bool
for _, audience := range audienceList {
checkAud = token.Claims.(jwt.MapClaims).VerifyAudience(audience, false)
if checkAud {
break
}
}
if !checkAud {
return token, errors.New("invalid audience")
}
// Verify 'issuer' claim
checkIss := token.Claims.(jwt.MapClaims).VerifyIssuer(issuer, false)
if !checkIss {
return token, errors.New("invalid issuer")
}
// If keys are rotated, verify the keys prior to token validation
if idpSignkeyRefreshEnabled {
// If the keys are invalid, retrieve new ones
if !keys.stillValid() {
lock.Lock()
defer lock.Unlock()
refreshedKeys, err := getPemKeys(ctx, keysLocation)
if err != nil {
log.WithContext(ctx).Debugf("cannot get JSONWebKey: %v, falling back to old keys", err)
refreshedKeys = keys
}
log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.expiresInTime.UTC())
keys = refreshedKeys
}
}
publicKey, err := getPublicKey(ctx, token, keys)
if err == nil {
return publicKey, nil
}
msg := fmt.Sprintf("getPublicKey error: %s", err)
if errors.Is(err, keyNotFound) && !idpSignkeyRefreshEnabled {
msg = fmt.Sprintf("getPublicKey error: %s. You can enable key refresh by setting HttpServerConfig.IdpSignKeyRefreshEnabled to true in your management.json file and restart the service", err)
}
log.WithContext(ctx).Error(msg)
return nil, err
},
EnableAuthOnOptions: false,
}
if options.UserProperty == "" {
options.UserProperty = "user"
}
return &jwtValidatorImpl{
options: options,
}, nil
}
// ValidateAndParse validates the token and returns the parsed token
func (m *jwtValidatorImpl) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) {
// If the token is empty...
if token == "" {
// Check if it was required
if m.options.CredentialsOptional {
log.WithContext(ctx).Debugf("no credentials found (CredentialsOptional=true)")
// No error, just no token (and that is ok given that CredentialsOptional is true)
return nil, nil //nolint:nilnil
}
// If we get here, the required token is missing
errorMsg := "required authorization token not found"
log.WithContext(ctx).Debugf(" Error: No credentials found (CredentialsOptional=false)")
return nil, errors.New(errorMsg)
}
// Now parse the token
parsedToken, err := jwt.Parse(token, m.options.ValidationKeyGetter)
// Check if there was an error in parsing...
if err != nil {
log.WithContext(ctx).Errorf("error parsing token: %v", err)
return nil, fmt.Errorf("error parsing token: %w", err)
}
// Check if the parsed token is valid...
if !parsedToken.Valid {
errorMsg := "token is invalid"
log.WithContext(ctx).Debug(errorMsg)
return nil, errors.New(errorMsg)
}
return parsedToken, nil
}
// stillValid returns true if the JSONWebKey still valid and have enough time to be used
func (jwks *Jwks) stillValid() bool {
return !jwks.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(jwks.expiresInTime)
}
func getPemKeys(ctx context.Context, keysLocation string) (*Jwks, error) {
resp, err := http.Get(keysLocation)
if err != nil {
return nil, err
}
defer resp.Body.Close()
jwks := &Jwks{}
err = json.NewDecoder(resp.Body).Decode(jwks)
if err != nil {
return jwks, err
}
cacheControlHeader := resp.Header.Get("Cache-Control")
expiresIn := getMaxAgeFromCacheHeader(ctx, cacheControlHeader)
jwks.expiresInTime = time.Now().Add(time.Duration(expiresIn) * time.Second)
return jwks, err
}
func getPublicKey(ctx context.Context, token *jwt.Token, jwks *Jwks) (interface{}, error) {
// todo as we load the jkws when the server is starting, we should build a JKS map with the pem cert at the boot time
for k := range jwks.Keys {
if token.Header["kid"] != jwks.Keys[k].Kid {
continue
}
if len(jwks.Keys[k].X5c) != 0 {
cert := "-----BEGIN CERTIFICATE-----\n" + jwks.Keys[k].X5c[0] + "\n-----END CERTIFICATE-----"
return jwt.ParseRSAPublicKeyFromPEM([]byte(cert))
}
if jwks.Keys[k].Kty == "RSA" {
log.WithContext(ctx).Debugf("generating PublicKey from RSA JWK")
return getPublicKeyFromRSA(jwks.Keys[k])
}
if jwks.Keys[k].Kty == "EC" {
log.WithContext(ctx).Debugf("generating PublicKey from ECDSA JWK")
return getPublicKeyFromECDSA(jwks.Keys[k])
}
log.WithContext(ctx).Debugf("Key Type: %s not yet supported, please raise ticket!", jwks.Keys[k].Kty)
}
return nil, keyNotFound
}
func getPublicKeyFromECDSA(jwk JSONWebKey) (publicKey *ecdsa.PublicKey, err error) {
if jwk.X == "" || jwk.Y == "" || jwk.Crv == "" {
return nil, fmt.Errorf("ecdsa key incomplete")
}
var xCoordinate []byte
if xCoordinate, err = base64.RawURLEncoding.DecodeString(jwk.X); err != nil {
return nil, err
}
var yCoordinate []byte
if yCoordinate, err = base64.RawURLEncoding.DecodeString(jwk.Y); err != nil {
return nil, err
}
publicKey = &ecdsa.PublicKey{}
var curve elliptic.Curve
switch jwk.Crv {
case p256:
curve = elliptic.P256()
case p384:
curve = elliptic.P384()
case p521:
curve = elliptic.P521()
}
publicKey.Curve = curve
publicKey.X = big.NewInt(0).SetBytes(xCoordinate)
publicKey.Y = big.NewInt(0).SetBytes(yCoordinate)
return publicKey, nil
}
func getPublicKeyFromRSA(jwk JSONWebKey) (*rsa.PublicKey, error) {
decodedE, err := base64.RawURLEncoding.DecodeString(jwk.E)
if err != nil {
return nil, err
}
decodedN, err := base64.RawURLEncoding.DecodeString(jwk.N)
if err != nil {
return nil, err
}
var n, e big.Int
e.SetBytes(decodedE)
n.SetBytes(decodedN)
return &rsa.PublicKey{
E: int(e.Int64()),
N: &n,
}, nil
}
// getMaxAgeFromCacheHeader extracts max-age directive from the Cache-Control header
func getMaxAgeFromCacheHeader(ctx context.Context, cacheControl string) int {
// Split into individual directives
directives := strings.Split(cacheControl, ",")
for _, directive := range directives {
directive = strings.TrimSpace(directive)
if strings.HasPrefix(directive, "max-age=") {
// Extract the max-age value
maxAgeStr := strings.TrimPrefix(directive, "max-age=")
maxAge, err := strconv.Atoi(maxAgeStr)
if err != nil {
log.WithContext(ctx).Debugf("error parsing max-age: %v", err)
return 0
}
return maxAge
}
}
return 0
}
type JwtValidatorMock struct{}
func (j *JwtValidatorMock) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) {
claimMaps := jwt.MapClaims{}
switch token {
case "testUserId", "testAdminId", "testOwnerId", "testServiceUserId", "testServiceAdminId", "blockedUserId":
claimMaps[UserIDClaim] = token
claimMaps[AccountIDSuffix] = "testAccountId"
claimMaps[DomainIDSuffix] = "test.com"
claimMaps[DomainCategorySuffix] = "private"
case "otherUserId":
claimMaps[UserIDClaim] = "otherUserId"
claimMaps[AccountIDSuffix] = "otherAccountId"
claimMaps[DomainIDSuffix] = "other.com"
claimMaps[DomainCategorySuffix] = "private"
case "invalidToken":
return nil, errors.New("invalid token")
}
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
return jwtToken, nil
}

View File

@ -440,7 +440,7 @@ func startManagementForTest(t *testing.T, testFile string, config *Config) (*grp
secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
ephemeralMgr := NewEphemeralManager(store, accountManager)
mgmtServer, err := NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, ephemeralMgr)
mgmtServer, err := NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, ephemeralMgr, nil)
if err != nil {
return nil, nil, "", cleanup, err
}

View File

@ -204,6 +204,7 @@ func startServer(
secretsManager,
nil,
nil,
nil,
)
if err != nil {
t.Fatalf("failed creating management server: %v", err)

View File

@ -13,14 +13,16 @@ import (
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
)
var _ server.AccountManager = (*MockAccountManager)(nil)
type MockAccountManager struct {
GetOrCreateAccountByUserFunc func(ctx context.Context, userId, domain string) (*types.Account, error)
GetAccountFunc func(ctx context.Context, accountID string) (*types.Account, error)
@ -29,7 +31,7 @@ type MockAccountManager struct {
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error)
AccountExistsFunc func(ctx context.Context, accountID string) (bool, error)
GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error)
GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error)
GetUserFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error)
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error
@ -54,8 +56,6 @@ type MockAccountManager struct {
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*types.Policy, error)
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error)
GetPATInfoFunc func(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, string, string, error)
MarkPATUsedFunc func(ctx context.Context, pat string) error
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
@ -80,8 +80,7 @@ type MockAccountManager struct {
DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error
ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error)
CreateUserFunc func(ctx context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error)
GetAccountIDFromTokenFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
GetAccountIDFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error)
DeleteAccountFunc func(ctx context.Context, accountID, userID string) error
GetDNSDomainFunc func() string
StoreEventFunc func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
@ -240,14 +239,6 @@ func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey str
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
}
// GetPATInfo mock implementation of GetPATInfo from server.AccountManager interface
func (am *MockAccountManager) GetPATInfo(ctx context.Context, pat string) (*types.User, *types.PersonalAccessToken, string, string, error) {
if am.GetPATInfoFunc != nil {
return am.GetPATInfoFunc(ctx, pat)
}
return nil, nil, "", "", status.Errorf(codes.Unimplemented, "method GetPATInfo is not implemented")
}
// DeleteAccount mock implementation of DeleteAccount from server.AccountManager interface
func (am *MockAccountManager) DeleteAccount(ctx context.Context, accountID, userID string) error {
if am.DeleteAccountFunc != nil {
@ -256,14 +247,6 @@ func (am *MockAccountManager) DeleteAccount(ctx context.Context, accountID, user
return status.Errorf(codes.Unimplemented, "method DeleteAccount is not implemented")
}
// MarkPATUsed mock implementation of MarkPATUsed from server.AccountManager interface
func (am *MockAccountManager) MarkPATUsed(ctx context.Context, pat string) error {
if am.MarkPATUsedFunc != nil {
return am.MarkPATUsedFunc(ctx, pat)
}
return status.Errorf(codes.Unimplemented, "method MarkPATUsed is not implemented")
}
// CreatePAT mock implementation of GetPAT from server.AccountManager interface
func (am *MockAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, name string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) {
if am.CreatePATFunc != nil {
@ -430,11 +413,11 @@ func (am *MockAccountManager) UpdatePeerMeta(ctx context.Context, peerID string,
}
// GetUser mock implementation of GetUser from server.AccountManager interface
func (am *MockAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error) {
if am.GetUserFunc != nil {
return am.GetUserFunc(ctx, claims)
func (am *MockAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
if am.GetUserFromUserAuthFunc != nil {
return am.GetUserFromUserAuthFunc(ctx, userAuth)
}
return nil, status.Errorf(codes.Unimplemented, "method GetUser is not implemented")
return nil, status.Errorf(codes.Unimplemented, "method GetUserFromUserAuth is not implemented")
}
func (am *MockAccountManager) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) {
@ -614,19 +597,11 @@ func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID
return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented")
}
// GetAccountIDFromToken mocks GetAccountIDFromToken of the AccountManager interface
func (am *MockAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
if am.GetAccountIDFromTokenFunc != nil {
return am.GetAccountIDFromTokenFunc(ctx, claims)
func (am *MockAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
if am.GetAccountIDFromUserAuthFunc != nil {
return am.GetAccountIDFromUserAuthFunc(ctx, userAuth)
}
return "", "", status.Errorf(codes.Unimplemented, "method GetAccountIDFromToken is not implemented")
}
func (am *MockAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error {
if am.CheckUserAccessByJWTGroupsFunc != nil {
return am.CheckUserAccessByJWTGroupsFunc(ctx, claims)
}
return status.Errorf(codes.Unimplemented, "method CheckUserAccessByJWTGroups is not implemented")
return "", "", status.Errorf(codes.Unimplemented, "method GetAccountIDFromUserAuth is not implemented")
}
// GetPeers mocks GetPeers of the AccountManager interface
@ -859,3 +834,7 @@ func (am *MockAccountManager) BuildUserInfosForAccount(ctx context.Context, acco
}
return nil, status.Errorf(codes.Unimplemented, "method BuildUserInfosForAccount is not implemented")
}
func (am *MockAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error {
return status.Errorf(codes.Unimplemented, "method SyncUserJWTGroups is not implemented")
}

View File

@ -8,16 +8,16 @@ import (
"time"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity"
nbContext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util"
log "github.com/sirupsen/logrus"
)
// createServiceUser creates a new service user under the given account.
@ -174,31 +174,26 @@ func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*t
return am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, id)
}
// GetUser looks up a user by provided authorization claims.
// It will also create an account if didn't exist for this user before.
func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error) {
accountID, userID, err := am.GetAccountIDFromToken(ctx, claims)
if err != nil {
return nil, fmt.Errorf("failed to get account with token claims %v", err)
}
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
// GetUser looks up a user by provided nbContext.UserAuths.
// Expects account to have been created already.
func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth nbContext.UserAuth) (*types.User, error) {
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
if err != nil {
return nil, err
}
// this code should be outside of the am.GetAccountIDFromToken(claims) because this method is called also by the gRPC
// server when user authenticates a device. And we need to separate the Dashboard login event from the Device login event.
newLogin := user.LastDashboardLoginChanged(claims.LastLogin)
newLogin := user.LastDashboardLoginChanged(userAuth.LastLogin)
err = am.Store.SaveUserLastLogin(ctx, accountID, userID, claims.LastLogin)
err = am.Store.SaveUserLastLogin(ctx, userAuth.AccountId, userAuth.UserId, userAuth.LastLogin)
if err != nil {
log.WithContext(ctx).Errorf("failed saving user last login: %v", err)
}
if newLogin {
meta := map[string]any{"timestamp": claims.LastLogin}
am.StoreEvent(ctx, claims.UserId, claims.UserId, accountID, activity.DashboardLogin, meta)
meta := map[string]any{"timestamp": userAuth.LastLogin}
am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, userAuth.AccountId, activity.DashboardLogin, meta)
}
return user, nil

View File

@ -10,6 +10,8 @@ import (
"github.com/eko/gocache/v3/cache"
cacheStore "github.com/eko/gocache/v3/store"
"github.com/google/go-cmp/cmp"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/util"
"golang.org/x/exp/maps"
@ -25,7 +27,6 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/integration_reference"
"github.com/netbirdio/netbird/management/server/jwtclaims"
)
const (
@ -925,11 +926,12 @@ func TestDefaultAccountManager_GetUser(t *testing.T) {
eventStore: &activity.InMemoryEventStore{},
}
claims := jwtclaims.AuthorizationClaims{
UserId: mockUserID,
claims := nbcontext.UserAuth{
UserId: mockUserID,
AccountId: mockAccountID,
}
user, err := am.GetUser(context.Background(), claims)
user, err := am.GetUserFromUserAuth(context.Background(), claims)
if err != nil {
t.Fatalf("Error when checking user role: %s", err)
}