mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-04 22:10:56 +01:00
refactor jwtValidator and geo db to interfaces + first component test for setup-keys
This commit is contained in:
parent
1bbabf70b0
commit
832e168869
@ -264,7 +264,7 @@ var (
|
||||
KeysLocation: config.HttpConfig.AuthKeysLocation,
|
||||
}
|
||||
|
||||
httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator)
|
||||
httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, geo, jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed creating HTTP API handler: %v", err)
|
||||
}
|
||||
|
@ -166,8 +166,8 @@ type DefaultAccountManager struct {
|
||||
cacheManager cache.CacheInterface[[]*idp.UserData]
|
||||
externalCacheManager ExternalCacheManager
|
||||
ctx context.Context
|
||||
eventStore activity.Store
|
||||
geo *geolocation.Geolocation
|
||||
EventStore activity.Store
|
||||
geo geolocation.Geolocation
|
||||
|
||||
requestBuffer *AccountRequestBuffer
|
||||
|
||||
@ -1041,7 +1041,7 @@ func BuildManager(
|
||||
singleAccountModeDomain string,
|
||||
dnsDomain string,
|
||||
eventStore activity.Store,
|
||||
geo *geolocation.Geolocation,
|
||||
geo geolocation.Geolocation,
|
||||
userDeleteFromIDPEnabled bool,
|
||||
integratedPeerValidator integrated_validator.IntegratedValidator,
|
||||
metrics telemetry.AppMetrics,
|
||||
@ -1055,7 +1055,7 @@ func BuildManager(
|
||||
cacheMux: sync.Mutex{},
|
||||
cacheLoading: map[string]chan struct{}{},
|
||||
dnsDomain: dnsDomain,
|
||||
eventStore: eventStore,
|
||||
EventStore: eventStore,
|
||||
peerLoginExpiry: NewDefaultScheduler(),
|
||||
peerInactivityExpiry: NewDefaultScheduler(),
|
||||
userDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
|
||||
|
@ -30,7 +30,7 @@ func (am *DefaultAccountManager) GetEvents(ctx context.Context, accountID, userI
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view events")
|
||||
}
|
||||
|
||||
events, err := am.eventStore.Get(ctx, accountID, 0, 10000, true)
|
||||
events, err := am.EventStore.Get(ctx, accountID, 0, 10000, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -58,7 +58,7 @@ func (am *DefaultAccountManager) GetEvents(ctx context.Context, accountID, userI
|
||||
func (am *DefaultAccountManager) StoreEvent(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
|
||||
go func() {
|
||||
_, err := am.eventStore.Save(ctx, &activity.Event{
|
||||
_, err := am.EventStore.Save(ctx, &activity.Event{
|
||||
Timestamp: time.Now().UTC(),
|
||||
Activity: activityID,
|
||||
InitiatorID: initiatorID,
|
||||
|
@ -14,7 +14,7 @@ func generateAndStoreEvents(t *testing.T, manager *DefaultAccountManager, typ ac
|
||||
accountID string, count int) {
|
||||
t.Helper()
|
||||
for i := 0; i < count; i++ {
|
||||
_, err := manager.eventStore.Save(context.Background(), &activity.Event{
|
||||
_, err := manager.EventStore.Save(context.Background(), &activity.Event{
|
||||
Timestamp: time.Now().UTC(),
|
||||
Activity: typ,
|
||||
InitiatorID: initiatorID,
|
||||
@ -41,7 +41,7 @@ func TestDefaultAccountManager_GetEvents(t *testing.T) {
|
||||
return
|
||||
}
|
||||
assert.Len(t, events, 0)
|
||||
_ = manager.eventStore.Close(context.Background()) //nolint
|
||||
_ = manager.EventStore.Close(context.Background()) //nolint
|
||||
})
|
||||
|
||||
t.Run("get events", func(t *testing.T) {
|
||||
@ -52,7 +52,7 @@ func TestDefaultAccountManager_GetEvents(t *testing.T) {
|
||||
}
|
||||
|
||||
assert.Len(t, events, 10)
|
||||
_ = manager.eventStore.Close(context.Background()) //nolint
|
||||
_ = manager.EventStore.Close(context.Background()) //nolint
|
||||
})
|
||||
|
||||
t.Run("get events without duplicates", func(t *testing.T) {
|
||||
@ -62,6 +62,6 @@ func TestDefaultAccountManager_GetEvents(t *testing.T) {
|
||||
return
|
||||
}
|
||||
assert.Len(t, events, 1)
|
||||
_ = manager.eventStore.Close(context.Background()) //nolint
|
||||
_ = manager.EventStore.Close(context.Background()) //nolint
|
||||
})
|
||||
}
|
||||
|
@ -14,7 +14,14 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type Geolocation struct {
|
||||
type Geolocation interface {
|
||||
Lookup(ip net.IP) (*Record, error)
|
||||
GetAllCountries() ([]Country, error)
|
||||
GetCitiesByCountry(countryISOCode string) ([]City, error)
|
||||
Stop() error
|
||||
}
|
||||
|
||||
type GeolocationImpl struct {
|
||||
mmdbPath string
|
||||
mux sync.RWMutex
|
||||
db *maxminddb.Reader
|
||||
@ -54,7 +61,7 @@ const (
|
||||
geonamesdbPattern = "geonames_*.db"
|
||||
)
|
||||
|
||||
func NewGeolocation(ctx context.Context, dataDir string, autoUpdate bool) (*Geolocation, error) {
|
||||
func NewGeolocation(ctx context.Context, dataDir string, autoUpdate bool) (Geolocation, error) {
|
||||
mmdbGlobPattern := filepath.Join(dataDir, mmdbPattern)
|
||||
mmdbFile, err := getDatabaseFilename(ctx, geoLiteCityTarGZURL, mmdbGlobPattern, autoUpdate)
|
||||
if err != nil {
|
||||
@ -86,7 +93,7 @@ func NewGeolocation(ctx context.Context, dataDir string, autoUpdate bool) (*Geol
|
||||
return nil, err
|
||||
}
|
||||
|
||||
geo := &Geolocation{
|
||||
geo := &GeolocationImpl{
|
||||
mmdbPath: mmdbPath,
|
||||
mux: sync.RWMutex{},
|
||||
db: db,
|
||||
@ -113,7 +120,7 @@ func openDB(mmdbPath string) (*maxminddb.Reader, error) {
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func (gl *Geolocation) Lookup(ip net.IP) (*Record, error) {
|
||||
func (gl *GeolocationImpl) Lookup(ip net.IP) (*Record, error) {
|
||||
gl.mux.RLock()
|
||||
defer gl.mux.RUnlock()
|
||||
|
||||
@ -127,7 +134,7 @@ func (gl *Geolocation) Lookup(ip net.IP) (*Record, error) {
|
||||
}
|
||||
|
||||
// GetAllCountries retrieves a list of all countries.
|
||||
func (gl *Geolocation) GetAllCountries() ([]Country, error) {
|
||||
func (gl *GeolocationImpl) GetAllCountries() ([]Country, error) {
|
||||
allCountries, err := gl.locationDB.GetAllCountries()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -143,7 +150,7 @@ func (gl *Geolocation) GetAllCountries() ([]Country, error) {
|
||||
}
|
||||
|
||||
// GetCitiesByCountry retrieves a list of cities in a specific country based on the country's ISO code.
|
||||
func (gl *Geolocation) GetCitiesByCountry(countryISOCode string) ([]City, error) {
|
||||
func (gl *GeolocationImpl) GetCitiesByCountry(countryISOCode string) ([]City, error) {
|
||||
allCities, err := gl.locationDB.GetCitiesByCountry(countryISOCode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -158,7 +165,7 @@ func (gl *Geolocation) GetCitiesByCountry(countryISOCode string) ([]City, error)
|
||||
return cities, nil
|
||||
}
|
||||
|
||||
func (gl *Geolocation) Stop() error {
|
||||
func (gl *GeolocationImpl) Stop() error {
|
||||
close(gl.stopCh)
|
||||
if gl.db != nil {
|
||||
if err := gl.db.Close(); err != nil {
|
||||
@ -259,3 +266,21 @@ func cleanupMaxMindDatabases(ctx context.Context, dataDir string, mmdbFile strin
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type GeolocationMock struct{}
|
||||
|
||||
func (g *GeolocationMock) Lookup(ip net.IP) (*Record, error) {
|
||||
return &Record{}, nil
|
||||
}
|
||||
|
||||
func (g *GeolocationMock) GetAllCountries() ([]Country, error) {
|
||||
return []Country{}, nil
|
||||
}
|
||||
|
||||
func (g *GeolocationMock) GetCitiesByCountry(countryISOCode string) ([]City, error) {
|
||||
return []City{}, nil
|
||||
}
|
||||
|
||||
func (g *GeolocationMock) Stop() error {
|
||||
return nil
|
||||
}
|
||||
|
@ -35,7 +35,7 @@ type GRPCServer struct {
|
||||
peersUpdateManager *PeersUpdateManager
|
||||
config *Config
|
||||
secretsManager SecretsManager
|
||||
jwtValidator *jwtclaims.JWTValidator
|
||||
jwtValidator jwtclaims.JWTValidator
|
||||
jwtClaimsExtractor *jwtclaims.ClaimsExtractor
|
||||
appMetrics telemetry.AppMetrics
|
||||
ephemeralManager *EphemeralManager
|
||||
@ -57,7 +57,7 @@ func NewServer(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var jwtValidator *jwtclaims.JWTValidator
|
||||
var jwtValidator jwtclaims.JWTValidator
|
||||
|
||||
if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) {
|
||||
jwtValidator, err = jwtclaims.NewJWTValidator(
|
||||
|
@ -21,12 +21,12 @@ var (
|
||||
// GeolocationsHandler is a handler that returns locations.
|
||||
type GeolocationsHandler struct {
|
||||
accountManager server.AccountManager
|
||||
geolocationManager *geolocation.Geolocation
|
||||
geolocationManager geolocation.Geolocation
|
||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||
}
|
||||
|
||||
// NewGeolocationsHandlerHandler creates a new Geolocations handler
|
||||
func NewGeolocationsHandlerHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg AuthCfg) *GeolocationsHandler {
|
||||
func NewGeolocationsHandlerHandler(accountManager server.AccountManager, geolocationManager geolocation.Geolocation, authCfg AuthCfg) *GeolocationsHandler {
|
||||
return &GeolocationsHandler{
|
||||
accountManager: accountManager,
|
||||
geolocationManager: geolocationManager,
|
||||
|
@ -31,7 +31,7 @@ type AuthCfg struct {
|
||||
type apiHandler struct {
|
||||
Router *mux.Router
|
||||
AccountManager s.AccountManager
|
||||
geolocationManager *geolocation.Geolocation
|
||||
geolocationManager geolocation.Geolocation
|
||||
AuthCfg AuthCfg
|
||||
}
|
||||
|
||||
@ -40,7 +40,7 @@ type emptyObject struct {
|
||||
}
|
||||
|
||||
// APIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
||||
func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) {
|
||||
func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) {
|
||||
claimsExtractor := jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(authCfg.Audience),
|
||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||
|
@ -18,12 +18,12 @@ import (
|
||||
// PostureChecksHandler is a handler that returns posture checks of the account.
|
||||
type PostureChecksHandler struct {
|
||||
accountManager server.AccountManager
|
||||
geolocationManager *geolocation.Geolocation
|
||||
geolocationManager geolocation.Geolocation
|
||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||
}
|
||||
|
||||
// NewPostureChecksHandler creates a new PostureChecks handler
|
||||
func NewPostureChecksHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg AuthCfg) *PostureChecksHandler {
|
||||
func NewPostureChecksHandler(accountManager server.AccountManager, geolocationManager geolocation.Geolocation, authCfg AuthCfg) *PostureChecksHandler {
|
||||
return &PostureChecksHandler{
|
||||
accountManager: accountManager,
|
||||
geolocationManager: geolocationManager,
|
||||
|
@ -70,7 +70,7 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
|
||||
return claims.AccountId, claims.UserId, nil
|
||||
},
|
||||
},
|
||||
geolocationManager: &geolocation.Geolocation{},
|
||||
geolocationManager: &geolocation.GeolocationImpl{},
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||
return jwtclaims.AuthorizationClaims{
|
||||
|
159
management/server/http/setupkeys_component_test.go
Normal file
159
management/server/http/setupkeys_component_test.go
Normal file
@ -0,0 +1,159 @@
|
||||
//go:build component
|
||||
|
||||
package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
const (
|
||||
testAccountId = "testUserId"
|
||||
testUserId = "testAccountId"
|
||||
|
||||
newKeyName = "newKey"
|
||||
expiresIn = 3600
|
||||
|
||||
existingKeyName = "existingKey"
|
||||
)
|
||||
|
||||
func Test_SetupKeys_Create_Success(t *testing.T) {
|
||||
tt := []struct {
|
||||
name string
|
||||
expectedStatus int
|
||||
expectedSetupKey *api.SetupKey
|
||||
requestBody *api.CreateSetupKeyRequest
|
||||
requestType string
|
||||
requestPath string
|
||||
}{
|
||||
{
|
||||
name: "Create Setup Key",
|
||||
requestType: http.MethodPost,
|
||||
requestPath: "/api/setup-keys",
|
||||
requestBody: &api.CreateSetupKeyRequest{
|
||||
AutoGroups: nil,
|
||||
ExpiresIn: expiresIn,
|
||||
Name: newKeyName,
|
||||
Type: "reusable",
|
||||
UsageLimit: 0,
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedSetupKey: &api.SetupKey{
|
||||
AutoGroups: []string{},
|
||||
Ephemeral: false,
|
||||
Expires: time.Time{},
|
||||
Id: "",
|
||||
Key: "",
|
||||
LastUsed: time.Time{},
|
||||
Name: newKeyName,
|
||||
Revoked: false,
|
||||
State: "valid",
|
||||
Type: "reusable",
|
||||
UpdatedAt: time.Now(),
|
||||
UsageLimit: 0,
|
||||
UsedTimes: 0,
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
store, cleanup, err := server.NewTestStoreFromSQL(context.Background(), "testdata/setup_keys.sql", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test store: %v", err)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
am := server.DefaultAccountManager{
|
||||
Store: store,
|
||||
EventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
apiHandler, err := APIHandler(context.Background(), &am, &geolocation.GeolocationMock{}, &jwtclaims.JwtValidatorMock{}, metrics, AuthCfg{}, server.MocIntegratedValidator{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create API handler: %v", err)
|
||||
}
|
||||
|
||||
body, err := json.Marshal(tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal request body: %v", err)
|
||||
}
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, bytes.NewBuffer(body))
|
||||
req.Header.Set("Authorization", "Bearer "+"my.dummy.token")
|
||||
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
content, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("I don't know what I expected; %v", err)
|
||||
}
|
||||
|
||||
if status := recorder.Code; status != tc.expectedStatus {
|
||||
t.Errorf("handler returned wrong status code: got %v want %v, content: %s",
|
||||
status, tc.expectedStatus, string(content))
|
||||
return
|
||||
}
|
||||
|
||||
got := &api.SetupKey{}
|
||||
if err = json.Unmarshal(content, &got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
|
||||
validateCreatedKey(t, tc.expectedSetupKey, got)
|
||||
|
||||
key, err := am.GetSetupKey(context.Background(), testAccountId, testUserId, got.Id)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
validateCreatedKey(t, tc.expectedSetupKey, toResponseBody(key))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func validateCreatedKey(t *testing.T, expectedKey *api.SetupKey, got *api.SetupKey) {
|
||||
t.Helper()
|
||||
|
||||
if got.Expires.After(time.Now().Add(-1*time.Minute)) && got.Expires.Before(time.Now().Add(expiresIn*time.Second)) {
|
||||
got.Expires = time.Time{}
|
||||
expectedKey.Expires = time.Time{}
|
||||
}
|
||||
|
||||
if got.Id == "" {
|
||||
t.Error("Expected key to have an ID")
|
||||
}
|
||||
got.Id = ""
|
||||
|
||||
if got.Key == "" {
|
||||
t.Error("Expected key to have a key")
|
||||
}
|
||||
got.Key = ""
|
||||
|
||||
if got.UpdatedAt.After(time.Now().Add(-1*time.Minute)) && got.UpdatedAt.Before(time.Now().Add(+1*time.Minute)) {
|
||||
got.UpdatedAt = time.Time{}
|
||||
expectedKey.UpdatedAt = time.Time{}
|
||||
}
|
||||
|
||||
assert.Equal(t, expectedKey, got)
|
||||
}
|
@ -7,6 +7,8 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
)
|
||||
|
||||
// UpdateIntegratedValidatorGroups updates the integrated validator groups for a specified account.
|
||||
@ -76,3 +78,44 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID
|
||||
func (am *DefaultAccountManager) GetValidatedPeers(account *Account) (map[string]struct{}, error) {
|
||||
return am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra)
|
||||
}
|
||||
|
||||
type MocIntegratedValidator struct {
|
||||
ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error)
|
||||
}
|
||||
|
||||
func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) {
|
||||
if a.ValidatePeerFunc != nil {
|
||||
return a.ValidatePeerFunc(context.Background(), update, peer, userID, accountID, dnsDomain, peersGroup, extraSettings)
|
||||
}
|
||||
return update, false, nil
|
||||
}
|
||||
func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) {
|
||||
validatedPeers := make(map[string]struct{})
|
||||
for _, peer := range peers {
|
||||
validatedPeers[peer.ID] = struct{}{}
|
||||
}
|
||||
return validatedPeers, nil
|
||||
}
|
||||
|
||||
func (MocIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer {
|
||||
return peer
|
||||
}
|
||||
|
||||
func (MocIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) {
|
||||
return false, false, nil
|
||||
}
|
||||
|
||||
func (MocIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) {
|
||||
|
||||
}
|
||||
|
||||
func (MocIntegratedValidator) Stop(_ context.Context) {
|
||||
}
|
||||
|
@ -72,13 +72,17 @@ type JSONWebKey struct {
|
||||
X5c []string `json:"x5c"`
|
||||
}
|
||||
|
||||
// JWTValidator struct to handle token validation and parsing
|
||||
type JWTValidator struct {
|
||||
type JWTValidator interface {
|
||||
ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error)
|
||||
}
|
||||
|
||||
// JWTValidatorImpl struct to handle token validation and parsing
|
||||
type JWTValidatorImpl struct {
|
||||
options Options
|
||||
}
|
||||
|
||||
// NewJWTValidator constructor
|
||||
func NewJWTValidator(ctx context.Context, issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) (*JWTValidator, error) {
|
||||
func NewJWTValidator(ctx context.Context, issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) (JWTValidator, error) {
|
||||
keys, err := getPemKeys(ctx, keysLocation)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -138,13 +142,13 @@ func NewJWTValidator(ctx context.Context, issuer string, audienceList []string,
|
||||
options.UserProperty = "user"
|
||||
}
|
||||
|
||||
return &JWTValidator{
|
||||
return &JWTValidatorImpl{
|
||||
options: options,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ValidateAndParse validates the token and returns the parsed token
|
||||
func (m *JWTValidator) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) {
|
||||
func (m *JWTValidatorImpl) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) {
|
||||
// If the token is empty...
|
||||
if token == "" {
|
||||
// Check if it was required
|
||||
@ -311,3 +315,15 @@ func getMaxAgeFromCacheHeader(ctx context.Context, cacheControl string) int {
|
||||
return 0
|
||||
}
|
||||
|
||||
type JwtValidatorMock struct{}
|
||||
|
||||
func (j *JwtValidatorMock) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) {
|
||||
claimMaps := jwt.MapClaims{}
|
||||
claimMaps[UserIDClaim] = "testUserId"
|
||||
claimMaps[AccountIDSuffix] = "testAccountId"
|
||||
claimMaps[DomainIDSuffix] = "test.com"
|
||||
claimMaps[DomainCategorySuffix] = "private"
|
||||
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
|
||||
|
||||
return jwtToken, nil
|
||||
}
|
||||
|
@ -10,13 +10,14 @@ import (
|
||||
"github.com/eko/gocache/v3/cache"
|
||||
cacheStore "github.com/eko/gocache/v3/store"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/integration_reference"
|
||||
@ -52,7 +53,7 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
EventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
pat, err := am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenName, mockExpiresIn)
|
||||
@ -96,7 +97,7 @@ func TestUser_CreatePAT_ForDifferentUser(t *testing.T) {
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
EventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
_, err = am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockTargetUserId, mockTokenName, mockExpiresIn)
|
||||
@ -118,7 +119,7 @@ func TestUser_CreatePAT_ForServiceUser(t *testing.T) {
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
EventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
pat, err := am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockTargetUserId, mockTokenName, mockExpiresIn)
|
||||
@ -141,7 +142,7 @@ func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) {
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
EventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
_, err = am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenName, mockWrongExpiresIn)
|
||||
@ -160,7 +161,7 @@ func TestUser_CreatePAT_WithEmptyName(t *testing.T) {
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
EventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
_, err = am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockEmptyTokenName, mockExpiresIn)
|
||||
@ -187,7 +188,7 @@ func TestUser_DeletePAT(t *testing.T) {
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
EventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
err = am.DeletePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenID1)
|
||||
@ -224,7 +225,7 @@ func TestUser_GetPAT(t *testing.T) {
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
EventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
pat, err := am.GetPAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenID1)
|
||||
@ -261,7 +262,7 @@ func TestUser_GetAllPATs(t *testing.T) {
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
EventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
pats, err := am.GetAllPATs(context.Background(), mockAccountID, mockUserID, mockUserID)
|
||||
@ -351,7 +352,7 @@ func TestUser_CreateServiceUser(t *testing.T) {
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
EventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
user, err := am.createServiceUser(context.Background(), mockAccountID, mockUserID, mockRole, mockServiceUserName, false, []string{"group1", "group2"})
|
||||
@ -392,7 +393,7 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) {
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
EventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
user, err := am.CreateUser(context.Background(), mockAccountID, mockUserID, &UserInfo{
|
||||
@ -434,7 +435,7 @@ func TestUser_CreateUser_RegularUser(t *testing.T) {
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
EventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
_, err = am.CreateUser(context.Background(), mockAccountID, mockUserID, &UserInfo{
|
||||
@ -459,7 +460,7 @@ func TestUser_InviteNewUser(t *testing.T) {
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
EventStore: &activity.InMemoryEventStore{},
|
||||
cacheLoading: map[string]chan struct{}{},
|
||||
}
|
||||
|
||||
@ -559,7 +560,7 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) {
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
EventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
err = am.DeleteUser(context.Background(), mockAccountID, mockUserID, mockServiceUserID)
|
||||
@ -591,7 +592,7 @@ func TestUser_DeleteUser_SelfDelete(t *testing.T) {
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
EventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
err = am.DeleteUser(context.Background(), mockAccountID, mockUserID, mockUserID)
|
||||
@ -639,7 +640,7 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
EventStore: &activity.InMemoryEventStore{},
|
||||
integratedPeerValidator: MocIntegratedValidator{},
|
||||
}
|
||||
|
||||
@ -743,7 +744,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
EventStore: &activity.InMemoryEventStore{},
|
||||
integratedPeerValidator: MocIntegratedValidator{},
|
||||
}
|
||||
|
||||
@ -845,7 +846,7 @@ func TestDefaultAccountManager_GetUser(t *testing.T) {
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
EventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
claims := jwtclaims.AuthorizationClaims{
|
||||
@ -876,7 +877,7 @@ func TestDefaultAccountManager_ListUsers(t *testing.T) {
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
EventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
users, err := am.ListUsers(context.Background(), mockAccountID)
|
||||
@ -958,7 +959,7 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) {
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
EventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
users, err := am.ListUsers(context.Background(), mockAccountID)
|
||||
@ -997,7 +998,7 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
EventStore: &activity.InMemoryEventStore{},
|
||||
idpManager: &idp.GoogleWorkspaceManager{}, // empty manager
|
||||
cacheLoading: map[string]chan struct{}{},
|
||||
cacheManager: cache.New[[]*idp.UserData](
|
||||
@ -1056,7 +1057,7 @@ func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) {
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
EventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
users, err := am.GetUsersFromAccount(context.Background(), mockAccountID, mockUserID)
|
||||
@ -1085,7 +1086,7 @@ func TestUser_GetUsersFromAccount_ForUser(t *testing.T) {
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
EventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
users, err := am.GetUsersFromAccount(context.Background(), mockAccountID, mockServiceUserID)
|
||||
|
Loading…
Reference in New Issue
Block a user