refactor jwtValidator and geo db to interfaces + first component test for setup-keys

This commit is contained in:
Pascal Fischer
2024-11-22 16:16:52 +01:00
parent 1bbabf70b0
commit 832e168869
14 changed files with 300 additions and 56 deletions

View File

@@ -264,7 +264,7 @@ var (
KeysLocation: config.HttpConfig.AuthKeysLocation, KeysLocation: config.HttpConfig.AuthKeysLocation,
} }
httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator) httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, geo, jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator)
if err != nil { if err != nil {
return fmt.Errorf("failed creating HTTP API handler: %v", err) return fmt.Errorf("failed creating HTTP API handler: %v", err)
} }

View File

@@ -166,8 +166,8 @@ type DefaultAccountManager struct {
cacheManager cache.CacheInterface[[]*idp.UserData] cacheManager cache.CacheInterface[[]*idp.UserData]
externalCacheManager ExternalCacheManager externalCacheManager ExternalCacheManager
ctx context.Context ctx context.Context
eventStore activity.Store EventStore activity.Store
geo *geolocation.Geolocation geo geolocation.Geolocation
requestBuffer *AccountRequestBuffer requestBuffer *AccountRequestBuffer
@@ -1041,7 +1041,7 @@ func BuildManager(
singleAccountModeDomain string, singleAccountModeDomain string,
dnsDomain string, dnsDomain string,
eventStore activity.Store, eventStore activity.Store,
geo *geolocation.Geolocation, geo geolocation.Geolocation,
userDeleteFromIDPEnabled bool, userDeleteFromIDPEnabled bool,
integratedPeerValidator integrated_validator.IntegratedValidator, integratedPeerValidator integrated_validator.IntegratedValidator,
metrics telemetry.AppMetrics, metrics telemetry.AppMetrics,
@@ -1055,7 +1055,7 @@ func BuildManager(
cacheMux: sync.Mutex{}, cacheMux: sync.Mutex{},
cacheLoading: map[string]chan struct{}{}, cacheLoading: map[string]chan struct{}{},
dnsDomain: dnsDomain, dnsDomain: dnsDomain,
eventStore: eventStore, EventStore: eventStore,
peerLoginExpiry: NewDefaultScheduler(), peerLoginExpiry: NewDefaultScheduler(),
peerInactivityExpiry: NewDefaultScheduler(), peerInactivityExpiry: NewDefaultScheduler(),
userDeleteFromIDPEnabled: userDeleteFromIDPEnabled, userDeleteFromIDPEnabled: userDeleteFromIDPEnabled,

View File

@@ -30,7 +30,7 @@ func (am *DefaultAccountManager) GetEvents(ctx context.Context, accountID, userI
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view events") return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view events")
} }
events, err := am.eventStore.Get(ctx, accountID, 0, 10000, true) events, err := am.EventStore.Get(ctx, accountID, 0, 10000, true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -58,7 +58,7 @@ func (am *DefaultAccountManager) GetEvents(ctx context.Context, accountID, userI
func (am *DefaultAccountManager) StoreEvent(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { func (am *DefaultAccountManager) StoreEvent(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
go func() { go func() {
_, err := am.eventStore.Save(ctx, &activity.Event{ _, err := am.EventStore.Save(ctx, &activity.Event{
Timestamp: time.Now().UTC(), Timestamp: time.Now().UTC(),
Activity: activityID, Activity: activityID,
InitiatorID: initiatorID, InitiatorID: initiatorID,

View File

@@ -14,7 +14,7 @@ func generateAndStoreEvents(t *testing.T, manager *DefaultAccountManager, typ ac
accountID string, count int) { accountID string, count int) {
t.Helper() t.Helper()
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
_, err := manager.eventStore.Save(context.Background(), &activity.Event{ _, err := manager.EventStore.Save(context.Background(), &activity.Event{
Timestamp: time.Now().UTC(), Timestamp: time.Now().UTC(),
Activity: typ, Activity: typ,
InitiatorID: initiatorID, InitiatorID: initiatorID,
@@ -41,7 +41,7 @@ func TestDefaultAccountManager_GetEvents(t *testing.T) {
return return
} }
assert.Len(t, events, 0) assert.Len(t, events, 0)
_ = manager.eventStore.Close(context.Background()) //nolint _ = manager.EventStore.Close(context.Background()) //nolint
}) })
t.Run("get events", func(t *testing.T) { t.Run("get events", func(t *testing.T) {
@@ -52,7 +52,7 @@ func TestDefaultAccountManager_GetEvents(t *testing.T) {
} }
assert.Len(t, events, 10) assert.Len(t, events, 10)
_ = manager.eventStore.Close(context.Background()) //nolint _ = manager.EventStore.Close(context.Background()) //nolint
}) })
t.Run("get events without duplicates", func(t *testing.T) { t.Run("get events without duplicates", func(t *testing.T) {
@@ -62,6 +62,6 @@ func TestDefaultAccountManager_GetEvents(t *testing.T) {
return return
} }
assert.Len(t, events, 1) assert.Len(t, events, 1)
_ = manager.eventStore.Close(context.Background()) //nolint _ = manager.EventStore.Close(context.Background()) //nolint
}) })
} }

View File

@@ -14,7 +14,14 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
type Geolocation struct { type Geolocation interface {
Lookup(ip net.IP) (*Record, error)
GetAllCountries() ([]Country, error)
GetCitiesByCountry(countryISOCode string) ([]City, error)
Stop() error
}
type GeolocationImpl struct {
mmdbPath string mmdbPath string
mux sync.RWMutex mux sync.RWMutex
db *maxminddb.Reader db *maxminddb.Reader
@@ -54,7 +61,7 @@ const (
geonamesdbPattern = "geonames_*.db" geonamesdbPattern = "geonames_*.db"
) )
func NewGeolocation(ctx context.Context, dataDir string, autoUpdate bool) (*Geolocation, error) { func NewGeolocation(ctx context.Context, dataDir string, autoUpdate bool) (Geolocation, error) {
mmdbGlobPattern := filepath.Join(dataDir, mmdbPattern) mmdbGlobPattern := filepath.Join(dataDir, mmdbPattern)
mmdbFile, err := getDatabaseFilename(ctx, geoLiteCityTarGZURL, mmdbGlobPattern, autoUpdate) mmdbFile, err := getDatabaseFilename(ctx, geoLiteCityTarGZURL, mmdbGlobPattern, autoUpdate)
if err != nil { if err != nil {
@@ -86,7 +93,7 @@ func NewGeolocation(ctx context.Context, dataDir string, autoUpdate bool) (*Geol
return nil, err return nil, err
} }
geo := &Geolocation{ geo := &GeolocationImpl{
mmdbPath: mmdbPath, mmdbPath: mmdbPath,
mux: sync.RWMutex{}, mux: sync.RWMutex{},
db: db, db: db,
@@ -113,7 +120,7 @@ func openDB(mmdbPath string) (*maxminddb.Reader, error) {
return db, nil return db, nil
} }
func (gl *Geolocation) Lookup(ip net.IP) (*Record, error) { func (gl *GeolocationImpl) Lookup(ip net.IP) (*Record, error) {
gl.mux.RLock() gl.mux.RLock()
defer gl.mux.RUnlock() defer gl.mux.RUnlock()
@@ -127,7 +134,7 @@ func (gl *Geolocation) Lookup(ip net.IP) (*Record, error) {
} }
// GetAllCountries retrieves a list of all countries. // GetAllCountries retrieves a list of all countries.
func (gl *Geolocation) GetAllCountries() ([]Country, error) { func (gl *GeolocationImpl) GetAllCountries() ([]Country, error) {
allCountries, err := gl.locationDB.GetAllCountries() allCountries, err := gl.locationDB.GetAllCountries()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -143,7 +150,7 @@ func (gl *Geolocation) GetAllCountries() ([]Country, error) {
} }
// GetCitiesByCountry retrieves a list of cities in a specific country based on the country's ISO code. // GetCitiesByCountry retrieves a list of cities in a specific country based on the country's ISO code.
func (gl *Geolocation) GetCitiesByCountry(countryISOCode string) ([]City, error) { func (gl *GeolocationImpl) GetCitiesByCountry(countryISOCode string) ([]City, error) {
allCities, err := gl.locationDB.GetCitiesByCountry(countryISOCode) allCities, err := gl.locationDB.GetCitiesByCountry(countryISOCode)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -158,7 +165,7 @@ func (gl *Geolocation) GetCitiesByCountry(countryISOCode string) ([]City, error)
return cities, nil return cities, nil
} }
func (gl *Geolocation) Stop() error { func (gl *GeolocationImpl) Stop() error {
close(gl.stopCh) close(gl.stopCh)
if gl.db != nil { if gl.db != nil {
if err := gl.db.Close(); err != nil { if err := gl.db.Close(); err != nil {
@@ -259,3 +266,21 @@ func cleanupMaxMindDatabases(ctx context.Context, dataDir string, mmdbFile strin
} }
return nil return nil
} }
type GeolocationMock struct{}
func (g *GeolocationMock) Lookup(ip net.IP) (*Record, error) {
return &Record{}, nil
}
func (g *GeolocationMock) GetAllCountries() ([]Country, error) {
return []Country{}, nil
}
func (g *GeolocationMock) GetCitiesByCountry(countryISOCode string) ([]City, error) {
return []City{}, nil
}
func (g *GeolocationMock) Stop() error {
return nil
}

View File

@@ -35,7 +35,7 @@ type GRPCServer struct {
peersUpdateManager *PeersUpdateManager peersUpdateManager *PeersUpdateManager
config *Config config *Config
secretsManager SecretsManager secretsManager SecretsManager
jwtValidator *jwtclaims.JWTValidator jwtValidator jwtclaims.JWTValidator
jwtClaimsExtractor *jwtclaims.ClaimsExtractor jwtClaimsExtractor *jwtclaims.ClaimsExtractor
appMetrics telemetry.AppMetrics appMetrics telemetry.AppMetrics
ephemeralManager *EphemeralManager ephemeralManager *EphemeralManager
@@ -57,7 +57,7 @@ func NewServer(
return nil, err return nil, err
} }
var jwtValidator *jwtclaims.JWTValidator var jwtValidator jwtclaims.JWTValidator
if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) { if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) {
jwtValidator, err = jwtclaims.NewJWTValidator( jwtValidator, err = jwtclaims.NewJWTValidator(

View File

@@ -21,12 +21,12 @@ var (
// GeolocationsHandler is a handler that returns locations. // GeolocationsHandler is a handler that returns locations.
type GeolocationsHandler struct { type GeolocationsHandler struct {
accountManager server.AccountManager accountManager server.AccountManager
geolocationManager *geolocation.Geolocation geolocationManager geolocation.Geolocation
claimsExtractor *jwtclaims.ClaimsExtractor claimsExtractor *jwtclaims.ClaimsExtractor
} }
// NewGeolocationsHandlerHandler creates a new Geolocations handler // NewGeolocationsHandlerHandler creates a new Geolocations handler
func NewGeolocationsHandlerHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg AuthCfg) *GeolocationsHandler { func NewGeolocationsHandlerHandler(accountManager server.AccountManager, geolocationManager geolocation.Geolocation, authCfg AuthCfg) *GeolocationsHandler {
return &GeolocationsHandler{ return &GeolocationsHandler{
accountManager: accountManager, accountManager: accountManager,
geolocationManager: geolocationManager, geolocationManager: geolocationManager,

View File

@@ -31,7 +31,7 @@ type AuthCfg struct {
type apiHandler struct { type apiHandler struct {
Router *mux.Router Router *mux.Router
AccountManager s.AccountManager AccountManager s.AccountManager
geolocationManager *geolocation.Geolocation geolocationManager geolocation.Geolocation
AuthCfg AuthCfg AuthCfg AuthCfg
} }
@@ -40,7 +40,7 @@ type emptyObject struct {
} }
// APIHandler creates the Management service HTTP API handler registering all the available endpoints. // APIHandler creates the Management service HTTP API handler registering all the available endpoints.
func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) { func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) {
claimsExtractor := jwtclaims.NewClaimsExtractor( claimsExtractor := jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience), jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),

View File

@@ -18,12 +18,12 @@ import (
// PostureChecksHandler is a handler that returns posture checks of the account. // PostureChecksHandler is a handler that returns posture checks of the account.
type PostureChecksHandler struct { type PostureChecksHandler struct {
accountManager server.AccountManager accountManager server.AccountManager
geolocationManager *geolocation.Geolocation geolocationManager geolocation.Geolocation
claimsExtractor *jwtclaims.ClaimsExtractor claimsExtractor *jwtclaims.ClaimsExtractor
} }
// NewPostureChecksHandler creates a new PostureChecks handler // NewPostureChecksHandler creates a new PostureChecks handler
func NewPostureChecksHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg AuthCfg) *PostureChecksHandler { func NewPostureChecksHandler(accountManager server.AccountManager, geolocationManager geolocation.Geolocation, authCfg AuthCfg) *PostureChecksHandler {
return &PostureChecksHandler{ return &PostureChecksHandler{
accountManager: accountManager, accountManager: accountManager,
geolocationManager: geolocationManager, geolocationManager: geolocationManager,

View File

@@ -70,7 +70,7 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
return claims.AccountId, claims.UserId, nil return claims.AccountId, claims.UserId, nil
}, },
}, },
geolocationManager: &geolocation.Geolocation{}, geolocationManager: &geolocation.GeolocationImpl{},
claimsExtractor: jwtclaims.NewClaimsExtractor( claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{ return jwtclaims.AuthorizationClaims{

View File

@@ -0,0 +1,159 @@
//go:build component
package http
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/telemetry"
)
const (
testAccountId = "testUserId"
testUserId = "testAccountId"
newKeyName = "newKey"
expiresIn = 3600
existingKeyName = "existingKey"
)
func Test_SetupKeys_Create_Success(t *testing.T) {
tt := []struct {
name string
expectedStatus int
expectedSetupKey *api.SetupKey
requestBody *api.CreateSetupKeyRequest
requestType string
requestPath string
}{
{
name: "Create Setup Key",
requestType: http.MethodPost,
requestPath: "/api/setup-keys",
requestBody: &api.CreateSetupKeyRequest{
AutoGroups: nil,
ExpiresIn: expiresIn,
Name: newKeyName,
Type: "reusable",
UsageLimit: 0,
},
expectedStatus: http.StatusOK,
expectedSetupKey: &api.SetupKey{
AutoGroups: []string{},
Ephemeral: false,
Expires: time.Time{},
Id: "",
Key: "",
LastUsed: time.Time{},
Name: newKeyName,
Revoked: false,
State: "valid",
Type: "reusable",
UpdatedAt: time.Now(),
UsageLimit: 0,
UsedTimes: 0,
Valid: true,
},
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
store, cleanup, err := server.NewTestStoreFromSQL(context.Background(), "testdata/setup_keys.sql", t.TempDir())
if err != nil {
t.Fatalf("Failed to create test store: %v", err)
}
t.Cleanup(cleanup)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
am := server.DefaultAccountManager{
Store: store,
EventStore: &activity.InMemoryEventStore{},
}
apiHandler, err := APIHandler(context.Background(), &am, &geolocation.GeolocationMock{}, &jwtclaims.JwtValidatorMock{}, metrics, AuthCfg{}, server.MocIntegratedValidator{})
if err != nil {
t.Fatalf("Failed to create API handler: %v", err)
}
body, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, bytes.NewBuffer(body))
req.Header.Set("Authorization", "Bearer "+"my.dummy.token")
apiHandler.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("I don't know what I expected; %v", err)
}
if status := recorder.Code; status != tc.expectedStatus {
t.Errorf("handler returned wrong status code: got %v want %v, content: %s",
status, tc.expectedStatus, string(content))
return
}
got := &api.SetupKey{}
if err = json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
validateCreatedKey(t, tc.expectedSetupKey, got)
key, err := am.GetSetupKey(context.Background(), testAccountId, testUserId, got.Id)
if err != nil {
return
}
validateCreatedKey(t, tc.expectedSetupKey, toResponseBody(key))
})
}
}
func validateCreatedKey(t *testing.T, expectedKey *api.SetupKey, got *api.SetupKey) {
t.Helper()
if got.Expires.After(time.Now().Add(-1*time.Minute)) && got.Expires.Before(time.Now().Add(expiresIn*time.Second)) {
got.Expires = time.Time{}
expectedKey.Expires = time.Time{}
}
if got.Id == "" {
t.Error("Expected key to have an ID")
}
got.Id = ""
if got.Key == "" {
t.Error("Expected key to have a key")
}
got.Key = ""
if got.UpdatedAt.After(time.Now().Add(-1*time.Minute)) && got.UpdatedAt.Before(time.Now().Add(+1*time.Minute)) {
got.UpdatedAt = time.Time{}
expectedKey.UpdatedAt = time.Time{}
}
assert.Equal(t, expectedKey, got)
}

View File

@@ -7,6 +7,8 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
) )
// UpdateIntegratedValidatorGroups updates the integrated validator groups for a specified account. // UpdateIntegratedValidatorGroups updates the integrated validator groups for a specified account.
@@ -76,3 +78,44 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID
func (am *DefaultAccountManager) GetValidatedPeers(account *Account) (map[string]struct{}, error) { func (am *DefaultAccountManager) GetValidatedPeers(account *Account) (map[string]struct{}, error) {
return am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra) return am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra)
} }
type MocIntegratedValidator struct {
ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error)
}
func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
return nil
}
func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) {
if a.ValidatePeerFunc != nil {
return a.ValidatePeerFunc(context.Background(), update, peer, userID, accountID, dnsDomain, peersGroup, extraSettings)
}
return update, false, nil
}
func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) {
validatedPeers := make(map[string]struct{})
for _, peer := range peers {
validatedPeers[peer.ID] = struct{}{}
}
return validatedPeers, nil
}
func (MocIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer {
return peer
}
func (MocIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) {
return false, false, nil
}
func (MocIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error {
return nil
}
func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) {
}
func (MocIntegratedValidator) Stop(_ context.Context) {
}

View File

@@ -72,13 +72,17 @@ type JSONWebKey struct {
X5c []string `json:"x5c"` X5c []string `json:"x5c"`
} }
// JWTValidator struct to handle token validation and parsing type JWTValidator interface {
type JWTValidator struct { ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error)
}
// JWTValidatorImpl struct to handle token validation and parsing
type JWTValidatorImpl struct {
options Options options Options
} }
// NewJWTValidator constructor // NewJWTValidator constructor
func NewJWTValidator(ctx context.Context, issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) (*JWTValidator, error) { func NewJWTValidator(ctx context.Context, issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) (JWTValidator, error) {
keys, err := getPemKeys(ctx, keysLocation) keys, err := getPemKeys(ctx, keysLocation)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -138,13 +142,13 @@ func NewJWTValidator(ctx context.Context, issuer string, audienceList []string,
options.UserProperty = "user" options.UserProperty = "user"
} }
return &JWTValidator{ return &JWTValidatorImpl{
options: options, options: options,
}, nil }, nil
} }
// ValidateAndParse validates the token and returns the parsed token // ValidateAndParse validates the token and returns the parsed token
func (m *JWTValidator) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) { func (m *JWTValidatorImpl) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) {
// If the token is empty... // If the token is empty...
if token == "" { if token == "" {
// Check if it was required // Check if it was required
@@ -311,3 +315,15 @@ func getMaxAgeFromCacheHeader(ctx context.Context, cacheControl string) int {
return 0 return 0
} }
type JwtValidatorMock struct{}
func (j *JwtValidatorMock) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) {
claimMaps := jwt.MapClaims{}
claimMaps[UserIDClaim] = "testUserId"
claimMaps[AccountIDSuffix] = "testAccountId"
claimMaps[DomainIDSuffix] = "test.com"
claimMaps[DomainCategorySuffix] = "private"
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
return jwtToken, nil
}

View File

@@ -10,13 +10,14 @@ import (
"github.com/eko/gocache/v3/cache" "github.com/eko/gocache/v3/cache"
cacheStore "github.com/eko/gocache/v3/store" cacheStore "github.com/eko/gocache/v3/store"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
gocache "github.com/patrickmn/go-cache" gocache "github.com/patrickmn/go-cache"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/integration_reference" "github.com/netbirdio/netbird/management/server/integration_reference"
@@ -52,7 +53,7 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
eventStore: &activity.InMemoryEventStore{}, EventStore: &activity.InMemoryEventStore{},
} }
pat, err := am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenName, mockExpiresIn) pat, err := am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenName, mockExpiresIn)
@@ -96,7 +97,7 @@ func TestUser_CreatePAT_ForDifferentUser(t *testing.T) {
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
eventStore: &activity.InMemoryEventStore{}, EventStore: &activity.InMemoryEventStore{},
} }
_, err = am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockTargetUserId, mockTokenName, mockExpiresIn) _, err = am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockTargetUserId, mockTokenName, mockExpiresIn)
@@ -118,7 +119,7 @@ func TestUser_CreatePAT_ForServiceUser(t *testing.T) {
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
eventStore: &activity.InMemoryEventStore{}, EventStore: &activity.InMemoryEventStore{},
} }
pat, err := am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockTargetUserId, mockTokenName, mockExpiresIn) pat, err := am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockTargetUserId, mockTokenName, mockExpiresIn)
@@ -141,7 +142,7 @@ func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) {
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
eventStore: &activity.InMemoryEventStore{}, EventStore: &activity.InMemoryEventStore{},
} }
_, err = am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenName, mockWrongExpiresIn) _, err = am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenName, mockWrongExpiresIn)
@@ -160,7 +161,7 @@ func TestUser_CreatePAT_WithEmptyName(t *testing.T) {
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
eventStore: &activity.InMemoryEventStore{}, EventStore: &activity.InMemoryEventStore{},
} }
_, err = am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockEmptyTokenName, mockExpiresIn) _, err = am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockEmptyTokenName, mockExpiresIn)
@@ -187,7 +188,7 @@ func TestUser_DeletePAT(t *testing.T) {
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
eventStore: &activity.InMemoryEventStore{}, EventStore: &activity.InMemoryEventStore{},
} }
err = am.DeletePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenID1) err = am.DeletePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenID1)
@@ -224,7 +225,7 @@ func TestUser_GetPAT(t *testing.T) {
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
eventStore: &activity.InMemoryEventStore{}, EventStore: &activity.InMemoryEventStore{},
} }
pat, err := am.GetPAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenID1) pat, err := am.GetPAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenID1)
@@ -261,7 +262,7 @@ func TestUser_GetAllPATs(t *testing.T) {
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
eventStore: &activity.InMemoryEventStore{}, EventStore: &activity.InMemoryEventStore{},
} }
pats, err := am.GetAllPATs(context.Background(), mockAccountID, mockUserID, mockUserID) pats, err := am.GetAllPATs(context.Background(), mockAccountID, mockUserID, mockUserID)
@@ -351,7 +352,7 @@ func TestUser_CreateServiceUser(t *testing.T) {
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
eventStore: &activity.InMemoryEventStore{}, EventStore: &activity.InMemoryEventStore{},
} }
user, err := am.createServiceUser(context.Background(), mockAccountID, mockUserID, mockRole, mockServiceUserName, false, []string{"group1", "group2"}) user, err := am.createServiceUser(context.Background(), mockAccountID, mockUserID, mockRole, mockServiceUserName, false, []string{"group1", "group2"})
@@ -392,7 +393,7 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) {
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
eventStore: &activity.InMemoryEventStore{}, EventStore: &activity.InMemoryEventStore{},
} }
user, err := am.CreateUser(context.Background(), mockAccountID, mockUserID, &UserInfo{ user, err := am.CreateUser(context.Background(), mockAccountID, mockUserID, &UserInfo{
@@ -434,7 +435,7 @@ func TestUser_CreateUser_RegularUser(t *testing.T) {
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
eventStore: &activity.InMemoryEventStore{}, EventStore: &activity.InMemoryEventStore{},
} }
_, err = am.CreateUser(context.Background(), mockAccountID, mockUserID, &UserInfo{ _, err = am.CreateUser(context.Background(), mockAccountID, mockUserID, &UserInfo{
@@ -459,7 +460,7 @@ func TestUser_InviteNewUser(t *testing.T) {
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
eventStore: &activity.InMemoryEventStore{}, EventStore: &activity.InMemoryEventStore{},
cacheLoading: map[string]chan struct{}{}, cacheLoading: map[string]chan struct{}{},
} }
@@ -559,7 +560,7 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) {
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
eventStore: &activity.InMemoryEventStore{}, EventStore: &activity.InMemoryEventStore{},
} }
err = am.DeleteUser(context.Background(), mockAccountID, mockUserID, mockServiceUserID) err = am.DeleteUser(context.Background(), mockAccountID, mockUserID, mockServiceUserID)
@@ -591,7 +592,7 @@ func TestUser_DeleteUser_SelfDelete(t *testing.T) {
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
eventStore: &activity.InMemoryEventStore{}, EventStore: &activity.InMemoryEventStore{},
} }
err = am.DeleteUser(context.Background(), mockAccountID, mockUserID, mockUserID) err = am.DeleteUser(context.Background(), mockAccountID, mockUserID, mockUserID)
@@ -639,7 +640,7 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
eventStore: &activity.InMemoryEventStore{}, EventStore: &activity.InMemoryEventStore{},
integratedPeerValidator: MocIntegratedValidator{}, integratedPeerValidator: MocIntegratedValidator{},
} }
@@ -743,7 +744,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
eventStore: &activity.InMemoryEventStore{}, EventStore: &activity.InMemoryEventStore{},
integratedPeerValidator: MocIntegratedValidator{}, integratedPeerValidator: MocIntegratedValidator{},
} }
@@ -845,7 +846,7 @@ func TestDefaultAccountManager_GetUser(t *testing.T) {
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
eventStore: &activity.InMemoryEventStore{}, EventStore: &activity.InMemoryEventStore{},
} }
claims := jwtclaims.AuthorizationClaims{ claims := jwtclaims.AuthorizationClaims{
@@ -876,7 +877,7 @@ func TestDefaultAccountManager_ListUsers(t *testing.T) {
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
eventStore: &activity.InMemoryEventStore{}, EventStore: &activity.InMemoryEventStore{},
} }
users, err := am.ListUsers(context.Background(), mockAccountID) users, err := am.ListUsers(context.Background(), mockAccountID)
@@ -958,7 +959,7 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) {
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
eventStore: &activity.InMemoryEventStore{}, EventStore: &activity.InMemoryEventStore{},
} }
users, err := am.ListUsers(context.Background(), mockAccountID) users, err := am.ListUsers(context.Background(), mockAccountID)
@@ -997,7 +998,7 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
eventStore: &activity.InMemoryEventStore{}, EventStore: &activity.InMemoryEventStore{},
idpManager: &idp.GoogleWorkspaceManager{}, // empty manager idpManager: &idp.GoogleWorkspaceManager{}, // empty manager
cacheLoading: map[string]chan struct{}{}, cacheLoading: map[string]chan struct{}{},
cacheManager: cache.New[[]*idp.UserData]( cacheManager: cache.New[[]*idp.UserData](
@@ -1056,7 +1057,7 @@ func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) {
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
eventStore: &activity.InMemoryEventStore{}, EventStore: &activity.InMemoryEventStore{},
} }
users, err := am.GetUsersFromAccount(context.Background(), mockAccountID, mockUserID) users, err := am.GetUsersFromAccount(context.Background(), mockAccountID, mockUserID)
@@ -1085,7 +1086,7 @@ func TestUser_GetUsersFromAccount_ForUser(t *testing.T) {
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
eventStore: &activity.InMemoryEventStore{}, EventStore: &activity.InMemoryEventStore{},
} }
users, err := am.GetUsersFromAccount(context.Background(), mockAccountID, mockServiceUserID) users, err := am.GetUsersFromAccount(context.Background(), mockAccountID, mockServiceUserID)