Merge branch 'main' into feature/mysql-support

This commit is contained in:
bcmmbaga 2025-01-02 17:41:54 +03:00
commit 2028cbd481
No known key found for this signature in database
GPG Key ID: 511EED5C928AD547
23 changed files with 1919 additions and 150 deletions

View File

@ -42,7 +42,7 @@ import (
nbContext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/groups"
httpapi "github.com/netbirdio/netbird/management/server/http"
nbhttp "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/jwtclaims"
@ -281,7 +281,7 @@ var (
routersManager := routers.NewManager(store, permissionsManager, accountManager)
networksManager := networks.NewManager(store, permissionsManager, resourcesManager, routersManager, accountManager)
httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, networksManager, resourcesManager, routersManager, groupsManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator)
httpAPIHandler, err := nbhttp.NewAPIHandler(ctx, accountManager, networksManager, resourcesManager, routersManager, groupsManager, geo, jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator)
if err != nil {
return fmt.Errorf("failed creating HTTP API handler: %v", err)
}

View File

@ -161,7 +161,7 @@ type DefaultAccountManager struct {
externalCacheManager ExternalCacheManager
ctx context.Context
eventStore activity.Store
geo *geolocation.Geolocation
geo geolocation.Geolocation
requestBuffer *AccountRequestBuffer
@ -244,7 +244,7 @@ func BuildManager(
singleAccountModeDomain string,
dnsDomain string,
eventStore activity.Store,
geo *geolocation.Geolocation,
geo geolocation.Geolocation,
userDeleteFromIDPEnabled bool,
integratedPeerValidator integrated_validator.IntegratedValidator,
metrics telemetry.AppMetrics,

View File

@ -28,7 +28,6 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
@ -39,47 +38,6 @@ import (
"github.com/netbirdio/netbird/route"
)
type MocIntegratedValidator struct {
ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error)
}
func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
return nil
}
func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) {
if a.ValidatePeerFunc != nil {
return a.ValidatePeerFunc(context.Background(), update, peer, userID, accountID, dnsDomain, peersGroup, extraSettings)
}
return update, false, nil
}
func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*types.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) {
validatedPeers := make(map[string]struct{})
for _, peer := range peers {
validatedPeers[peer.ID] = struct{}{}
}
return validatedPeers, nil
}
func (MocIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer {
return peer
}
func (MocIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) {
return false, false, nil
}
func (MocIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error {
return nil
}
func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) {
}
func (MocIntegratedValidator) Stop(_ context.Context) {
}
func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *types.Account, userID string) {
t.Helper()
peer := &nbpeer.Peer{

View File

@ -14,7 +14,14 @@ import (
log "github.com/sirupsen/logrus"
)
type Geolocation struct {
type Geolocation interface {
Lookup(ip net.IP) (*Record, error)
GetAllCountries() ([]Country, error)
GetCitiesByCountry(countryISOCode string) ([]City, error)
Stop() error
}
type geolocationImpl struct {
mmdbPath string
mux sync.RWMutex
db *maxminddb.Reader
@ -54,7 +61,7 @@ const (
geonamesdbPattern = "geonames_*.db"
)
func NewGeolocation(ctx context.Context, dataDir string, autoUpdate bool) (*Geolocation, error) {
func NewGeolocation(ctx context.Context, dataDir string, autoUpdate bool) (Geolocation, error) {
mmdbGlobPattern := filepath.Join(dataDir, mmdbPattern)
mmdbFile, err := getDatabaseFilename(ctx, geoLiteCityTarGZURL, mmdbGlobPattern, autoUpdate)
if err != nil {
@ -86,7 +93,7 @@ func NewGeolocation(ctx context.Context, dataDir string, autoUpdate bool) (*Geol
return nil, err
}
geo := &Geolocation{
geo := &geolocationImpl{
mmdbPath: mmdbPath,
mux: sync.RWMutex{},
db: db,
@ -113,7 +120,7 @@ func openDB(mmdbPath string) (*maxminddb.Reader, error) {
return db, nil
}
func (gl *Geolocation) Lookup(ip net.IP) (*Record, error) {
func (gl *geolocationImpl) Lookup(ip net.IP) (*Record, error) {
gl.mux.RLock()
defer gl.mux.RUnlock()
@ -127,7 +134,7 @@ func (gl *Geolocation) Lookup(ip net.IP) (*Record, error) {
}
// GetAllCountries retrieves a list of all countries.
func (gl *Geolocation) GetAllCountries() ([]Country, error) {
func (gl *geolocationImpl) GetAllCountries() ([]Country, error) {
allCountries, err := gl.locationDB.GetAllCountries()
if err != nil {
return nil, err
@ -143,7 +150,7 @@ func (gl *Geolocation) GetAllCountries() ([]Country, error) {
}
// GetCitiesByCountry retrieves a list of cities in a specific country based on the country's ISO code.
func (gl *Geolocation) GetCitiesByCountry(countryISOCode string) ([]City, error) {
func (gl *geolocationImpl) GetCitiesByCountry(countryISOCode string) ([]City, error) {
allCities, err := gl.locationDB.GetCitiesByCountry(countryISOCode)
if err != nil {
return nil, err
@ -158,7 +165,7 @@ func (gl *Geolocation) GetCitiesByCountry(countryISOCode string) ([]City, error)
return cities, nil
}
func (gl *Geolocation) Stop() error {
func (gl *geolocationImpl) Stop() error {
close(gl.stopCh)
if gl.db != nil {
if err := gl.db.Close(); err != nil {
@ -259,3 +266,21 @@ func cleanupMaxMindDatabases(ctx context.Context, dataDir string, mmdbFile strin
}
return nil
}
type Mock struct{}
func (g *Mock) Lookup(ip net.IP) (*Record, error) {
return &Record{}, nil
}
func (g *Mock) GetAllCountries() ([]Country, error) {
return []Country{}, nil
}
func (g *Mock) GetCitiesByCountry(countryISOCode string) ([]City, error) {
return []City{}, nil
}
func (g *Mock) Stop() error {
return nil
}

View File

@ -24,7 +24,7 @@ func TestGeoLite_Lookup(t *testing.T) {
db, err := openDB(filename)
assert.NoError(t, err)
geo := &Geolocation{
geo := &geolocationImpl{
mux: sync.RWMutex{},
db: db,
stopCh: make(chan struct{}),

View File

@ -38,7 +38,7 @@ type GRPCServer struct {
peersUpdateManager *PeersUpdateManager
config *Config
secretsManager SecretsManager
jwtValidator *jwtclaims.JWTValidator
jwtValidator jwtclaims.JWTValidator
jwtClaimsExtractor *jwtclaims.ClaimsExtractor
appMetrics telemetry.AppMetrics
ephemeralManager *EphemeralManager
@ -61,7 +61,7 @@ func NewServer(
return nil, err
}
var jwtValidator *jwtclaims.JWTValidator
var jwtValidator jwtclaims.JWTValidator
if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) {
jwtValidator, err = jwtclaims.NewJWTValidator(

View File

@ -35,15 +35,8 @@ import (
const apiPrefix = "/api"
type apiHandler struct {
Router *mux.Router
AccountManager s.AccountManager
geolocationManager *geolocation.Geolocation
AuthCfg configs.AuthCfg
}
// APIHandler creates the Management service HTTP API handler registering all the available endpoints.
func APIHandler(ctx context.Context, accountManager s.AccountManager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg configs.AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) {
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
func NewAPIHandler(ctx context.Context, accountManager s.AccountManager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg configs.AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) {
claimsExtractor := jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
@ -78,27 +71,20 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, networksMa
router := rootRouter.PathPrefix(prefix).Subrouter()
router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler, acMiddleware.Handler)
api := apiHandler{
Router: router,
AccountManager: accountManager,
geolocationManager: LocationManager,
AuthCfg: authCfg,
}
if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor, integratedValidator, appMetrics.GetMeter()); err != nil {
if _, err := integrations.RegisterHandlers(ctx, prefix, router, accountManager, claimsExtractor, integratedValidator, appMetrics.GetMeter()); err != nil {
return nil, fmt.Errorf("register integrations endpoints: %w", err)
}
accounts.AddEndpoints(api.AccountManager, authCfg, router)
peers.AddEndpoints(api.AccountManager, authCfg, router)
users.AddEndpoints(api.AccountManager, authCfg, router)
setup_keys.AddEndpoints(api.AccountManager, authCfg, router)
policies.AddEndpoints(api.AccountManager, api.geolocationManager, authCfg, router)
groups.AddEndpoints(api.AccountManager, authCfg, router)
routes.AddEndpoints(api.AccountManager, authCfg, router)
dns.AddEndpoints(api.AccountManager, authCfg, router)
events.AddEndpoints(api.AccountManager, authCfg, router)
networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, api.AccountManager, api.AccountManager.GetAccountIDFromToken, authCfg, router)
accounts.AddEndpoints(accountManager, authCfg, router)
peers.AddEndpoints(accountManager, authCfg, router)
users.AddEndpoints(accountManager, authCfg, router)
setup_keys.AddEndpoints(accountManager, authCfg, router)
policies.AddEndpoints(accountManager, LocationManager, authCfg, router)
groups.AddEndpoints(accountManager, authCfg, router)
routes.AddEndpoints(accountManager, authCfg, router)
dns.AddEndpoints(accountManager, authCfg, router)
events.AddEndpoints(accountManager, authCfg, router)
networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, accountManager.GetAccountIDFromToken, authCfg, router)
return rootRouter, nil
}

View File

@ -22,18 +22,18 @@ var (
// geolocationsHandler is a handler that returns locations.
type geolocationsHandler struct {
accountManager server.AccountManager
geolocationManager *geolocation.Geolocation
geolocationManager geolocation.Geolocation
claimsExtractor *jwtclaims.ClaimsExtractor
}
func addLocationsEndpoint(accountManager server.AccountManager, locationManager *geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) {
func addLocationsEndpoint(accountManager server.AccountManager, locationManager geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) {
locationHandler := newGeolocationsHandlerHandler(accountManager, locationManager, authCfg)
router.HandleFunc("/locations/countries", locationHandler.getAllCountries).Methods("GET", "OPTIONS")
router.HandleFunc("/locations/countries/{country}/cities", locationHandler.getCitiesByCountry).Methods("GET", "OPTIONS")
}
// newGeolocationsHandlerHandler creates a new Geolocations handler
func newGeolocationsHandlerHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg configs.AuthCfg) *geolocationsHandler {
func newGeolocationsHandlerHandler(accountManager server.AccountManager, geolocationManager geolocation.Geolocation, authCfg configs.AuthCfg) *geolocationsHandler {
return &geolocationsHandler{
accountManager: accountManager,
geolocationManager: geolocationManager,

View File

@ -23,7 +23,7 @@ type handler struct {
claimsExtractor *jwtclaims.ClaimsExtractor
}
func AddEndpoints(accountManager server.AccountManager, locationManager *geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) {
func AddEndpoints(accountManager server.AccountManager, locationManager geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) {
policiesHandler := newHandler(accountManager, authCfg)
router.HandleFunc("/policies", policiesHandler.getAllPolicies).Methods("GET", "OPTIONS")
router.HandleFunc("/policies", policiesHandler.createPolicy).Methods("POST", "OPTIONS")

View File

@ -19,11 +19,11 @@ import (
// postureChecksHandler is a handler that returns posture checks of the account.
type postureChecksHandler struct {
accountManager server.AccountManager
geolocationManager *geolocation.Geolocation
geolocationManager geolocation.Geolocation
claimsExtractor *jwtclaims.ClaimsExtractor
}
func addPostureCheckEndpoint(accountManager server.AccountManager, locationManager *geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) {
func addPostureCheckEndpoint(accountManager server.AccountManager, locationManager geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) {
postureCheckHandler := newPostureChecksHandler(accountManager, locationManager, authCfg)
router.HandleFunc("/posture-checks", postureCheckHandler.getAllPostureChecks).Methods("GET", "OPTIONS")
router.HandleFunc("/posture-checks", postureCheckHandler.createPostureCheck).Methods("POST", "OPTIONS")
@ -34,7 +34,7 @@ func addPostureCheckEndpoint(accountManager server.AccountManager, locationManag
}
// newPostureChecksHandler creates a new PostureChecks handler
func newPostureChecksHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg configs.AuthCfg) *postureChecksHandler {
func newPostureChecksHandler(accountManager server.AccountManager, geolocationManager geolocation.Geolocation, authCfg configs.AuthCfg) *postureChecksHandler {
return &postureChecksHandler{
accountManager: accountManager,
geolocationManager: geolocationManager,

View File

@ -70,7 +70,7 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *postureChecksH
return claims.AccountId, claims.UserId, nil
},
},
geolocationManager: &geolocation.Geolocation{},
geolocationManager: &geolocation.Mock{},
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{

View File

@ -93,7 +93,7 @@ func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) {
return
}
apiSetupKeys := toResponseBody(setupKey)
apiSetupKeys := ToResponseBody(setupKey)
// for the creation we need to send the plain key
apiSetupKeys.Key = setupKey.Key
@ -183,7 +183,7 @@ func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request) {
apiSetupKeys := make([]*api.SetupKey, 0)
for _, key := range setupKeys {
apiSetupKeys = append(apiSetupKeys, toResponseBody(key))
apiSetupKeys = append(apiSetupKeys, ToResponseBody(key))
}
util.WriteJSONObject(r.Context(), w, apiSetupKeys)
@ -216,14 +216,14 @@ func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request) {
func writeSuccess(ctx context.Context, w http.ResponseWriter, key *types.SetupKey) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(200)
err := json.NewEncoder(w).Encode(toResponseBody(key))
err := json.NewEncoder(w).Encode(ToResponseBody(key))
if err != nil {
util.WriteError(ctx, err, w)
return
}
}
func toResponseBody(key *types.SetupKey) *api.SetupKey {
func ToResponseBody(key *types.SetupKey) *api.SetupKey {
var state string
switch {
case key.IsExpired():

View File

@ -26,7 +26,6 @@ const (
newSetupKeyName = "New Setup Key"
updatedSetupKeyName = "KKKey"
notFoundSetupKeyID = "notFoundSetupKeyID"
testAccountID = "test_id"
)
func initSetupKeysTestMetaData(defaultKey *types.SetupKey, newKey *types.SetupKey, updatedSetupKey *types.SetupKey,
@ -81,7 +80,7 @@ func initSetupKeysTestMetaData(defaultKey *types.SetupKey, newKey *types.SetupKe
return jwtclaims.AuthorizationClaims{
UserId: user.Id,
Domain: "hotmail.com",
AccountId: testAccountID,
AccountId: "testAccountId",
}
}),
),
@ -102,7 +101,7 @@ func TestSetupKeysHandlers(t *testing.T) {
updatedDefaultSetupKey.Name = updatedSetupKeyName
updatedDefaultSetupKey.Revoked = true
expectedNewKey := toResponseBody(newSetupKey)
expectedNewKey := ToResponseBody(newSetupKey)
expectedNewKey.Key = plainKey
tt := []struct {
name string
@ -120,7 +119,7 @@ func TestSetupKeysHandlers(t *testing.T) {
requestPath: "/api/setup-keys",
expectedStatus: http.StatusOK,
expectedBody: true,
expectedSetupKeys: []*api.SetupKey{toResponseBody(defaultSetupKey)},
expectedSetupKeys: []*api.SetupKey{ToResponseBody(defaultSetupKey)},
},
{
name: "Get Existing Setup Key",
@ -128,7 +127,7 @@ func TestSetupKeysHandlers(t *testing.T) {
requestPath: "/api/setup-keys/" + existingSetupKeyID,
expectedStatus: http.StatusOK,
expectedBody: true,
expectedSetupKey: toResponseBody(defaultSetupKey),
expectedSetupKey: ToResponseBody(defaultSetupKey),
},
{
name: "Get Not Existing Setup Key",
@ -159,7 +158,7 @@ func TestSetupKeysHandlers(t *testing.T) {
))),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedSetupKey: toResponseBody(updatedDefaultSetupKey),
expectedSetupKey: ToResponseBody(updatedDefaultSetupKey),
},
{
name: "Delete Setup Key",
@ -228,7 +227,7 @@ func TestSetupKeysHandlers(t *testing.T) {
func assertKeys(t *testing.T, got *api.SetupKey, expected *api.SetupKey) {
t.Helper()
// this comparison is done manually because when converting to JSON dates formatted differently
// assert.Equal(t, got.UpdatedAt, tc.expectedSetupKey.UpdatedAt) //doesn't work
// assert.Equal(t, got.UpdatedAt, tc.expectedResponse.UpdatedAt) //doesn't work
assert.WithinDurationf(t, got.UpdatedAt, expected.UpdatedAt, 0, "")
assert.WithinDurationf(t, got.Expires, expected.Expires, 0, "")
assert.Equal(t, got.Name, expected.Name)

View File

@ -0,0 +1,226 @@
package benchmarks
import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"os"
"strconv"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
)
// Map to store peers, groups, users, and setupKeys by name
var benchCasesSetupKeys = map[string]testing_tools.BenchmarkCase{
"Setup Keys - XS": {Peers: 10000, Groups: 10000, Users: 10000, SetupKeys: 5},
"Setup Keys - S": {Peers: 5, Groups: 5, Users: 5, SetupKeys: 100},
"Setup Keys - M": {Peers: 100, Groups: 20, Users: 20, SetupKeys: 1000},
"Setup Keys - L": {Peers: 5, Groups: 5, Users: 5, SetupKeys: 5000},
"Peers - L": {Peers: 10000, Groups: 5, Users: 5, SetupKeys: 5000},
"Groups - L": {Peers: 5, Groups: 10000, Users: 5, SetupKeys: 5000},
"Users - L": {Peers: 5, Groups: 5, Users: 10000, SetupKeys: 5000},
"Setup Keys - XL": {Peers: 500, Groups: 50, Users: 100, SetupKeys: 25000},
}
func BenchmarkCreateSetupKey(b *testing.B) {
var expectedMetrics = map[string]testing_tools.PerformanceMetrics{
"Setup Keys - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
"Setup Keys - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
"Setup Keys - M": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
"Setup Keys - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
"Peers - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
"Groups - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
"Users - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
"Setup Keys - XL": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
}
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
recorder := httptest.NewRecorder()
for name, bc := range benchCasesSetupKeys {
b.Run(name, func(b *testing.B) {
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
requestBody := api.CreateSetupKeyRequest{
AutoGroups: []string{testing_tools.TestGroupId},
ExpiresIn: testing_tools.ExpiresIn,
Name: testing_tools.NewKeyName + strconv.Itoa(i),
Type: "reusable",
UsageLimit: 0,
}
// the time marshal will be recorded as well but for our use case that is ok
body, err := json.Marshal(requestBody)
assert.NoError(b, err)
req := testing_tools.BuildRequest(b, body, http.MethodPost, "/api/setup-keys", testing_tools.TestAdminId)
apiHandler.ServeHTTP(recorder, req)
}
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder)
})
}
}
func BenchmarkUpdateSetupKey(b *testing.B) {
var expectedMetrics = map[string]testing_tools.PerformanceMetrics{
"Setup Keys - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19},
"Setup Keys - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19},
"Setup Keys - M": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19},
"Setup Keys - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19},
"Peers - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19},
"Groups - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19},
"Users - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19},
"Setup Keys - XL": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19},
}
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
recorder := httptest.NewRecorder()
for name, bc := range benchCasesSetupKeys {
b.Run(name, func(b *testing.B) {
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
groupId := testing_tools.TestGroupId
if i%2 == 0 {
groupId = testing_tools.NewGroupId
}
requestBody := api.SetupKeyRequest{
AutoGroups: []string{groupId},
Revoked: false,
}
// the time marshal will be recorded as well but for our use case that is ok
body, err := json.Marshal(requestBody)
assert.NoError(b, err)
req := testing_tools.BuildRequest(b, body, http.MethodPut, "/api/setup-keys/"+testing_tools.TestKeyId, testing_tools.TestAdminId)
apiHandler.ServeHTTP(recorder, req)
}
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder)
})
}
}
func BenchmarkGetOneSetupKey(b *testing.B) {
var expectedMetrics = map[string]testing_tools.PerformanceMetrics{
"Setup Keys - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
"Setup Keys - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
"Setup Keys - M": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
"Setup Keys - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
"Peers - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
"Groups - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
"Users - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
"Setup Keys - XL": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
}
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
recorder := httptest.NewRecorder()
for name, bc := range benchCasesSetupKeys {
b.Run(name, func(b *testing.B) {
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/setup-keys/"+testing_tools.TestKeyId, testing_tools.TestAdminId)
apiHandler.ServeHTTP(recorder, req)
}
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder)
})
}
}
func BenchmarkGetAllSetupKeys(b *testing.B) {
var expectedMetrics = map[string]testing_tools.PerformanceMetrics{
"Setup Keys - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 12},
"Setup Keys - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 15},
"Setup Keys - M": {MinMsPerOpLocal: 5, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 5, MaxMsPerOpCICD: 40},
"Setup Keys - L": {MinMsPerOpLocal: 30, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 150},
"Peers - L": {MinMsPerOpLocal: 30, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 150},
"Groups - L": {MinMsPerOpLocal: 30, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 150},
"Users - L": {MinMsPerOpLocal: 30, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 150},
"Setup Keys - XL": {MinMsPerOpLocal: 140, MaxMsPerOpLocal: 220, MinMsPerOpCICD: 150, MaxMsPerOpCICD: 500},
}
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
recorder := httptest.NewRecorder()
for name, bc := range benchCasesSetupKeys {
b.Run(name, func(b *testing.B) {
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/setup-keys", testing_tools.TestAdminId)
apiHandler.ServeHTTP(recorder, req)
}
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder)
})
}
}
func BenchmarkDeleteSetupKey(b *testing.B) {
var expectedMetrics = map[string]testing_tools.PerformanceMetrics{
"Setup Keys - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
"Setup Keys - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
"Setup Keys - M": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
"Setup Keys - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
"Peers - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
"Groups - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
"Users - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
"Setup Keys - XL": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
}
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
recorder := httptest.NewRecorder()
for name, bc := range benchCasesSetupKeys {
b.Run(name, func(b *testing.B) {
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, 1000)
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
req := testing_tools.BuildRequest(b, nil, http.MethodDelete, "/api/setup-keys/"+"oldkey-"+strconv.Itoa(i), testing_tools.TestAdminId)
apiHandler.ServeHTTP(recorder, req)
}
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder)
})
}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,24 @@
CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`));
CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','[]',0,'');
INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,'');
CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,'0001-01-01 00:00:00+00:00','["testGroupId"]',1,0);
INSERT INTO setup_keys VALUES('revokedKeyId','testAccountId','revokedKey','testK****','existingKey','reusable','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',1,0,'0001-01-01 00:00:00+00:00','["testGroupId"]',3,0);
INSERT INTO setup_keys VALUES('expiredKeyId','testAccountId','expiredKey','testK****','existingKey','reusable','2021-08-19 20:46:20.000000000+00:00','1921-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,1,'0001-01-01 00:00:00+00:00','["testGroupId"]',5,1);

View File

@ -0,0 +1,307 @@
package testing_tools
import (
"bytes"
"context"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"os"
"strconv"
"testing"
"time"
"github.com/stretchr/testify/assert"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/groups"
nbhttp "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
)
const (
TestAccountId = "testAccountId"
TestPeerId = "testPeerId"
TestGroupId = "testGroupId"
TestKeyId = "testKeyId"
TestUserId = "testUserId"
TestAdminId = "testAdminId"
TestOwnerId = "testOwnerId"
TestServiceUserId = "testServiceUserId"
TestServiceAdminId = "testServiceAdminId"
BlockedUserId = "blockedUserId"
OtherUserId = "otherUserId"
InvalidToken = "invalidToken"
NewKeyName = "newKey"
NewGroupId = "newGroupId"
ExpiresIn = 3600
RevokedKeyId = "revokedKeyId"
ExpiredKeyId = "expiredKeyId"
ExistingKeyName = "existingKey"
)
type TB interface {
Cleanup(func())
Helper()
Errorf(format string, args ...any)
Fatalf(format string, args ...any)
TempDir() string
}
// BenchmarkCase defines a single benchmark test case
type BenchmarkCase struct {
Peers int
Groups int
Users int
SetupKeys int
}
// PerformanceMetrics holds the performance expectations
type PerformanceMetrics struct {
MinMsPerOpLocal float64
MaxMsPerOpLocal float64
MinMsPerOpCICD float64
MaxMsPerOpCICD float64
}
func BuildApiBlackBoxWithDBState(t TB, sqlFile string, expectedPeerUpdate *server.UpdateMessage) (http.Handler, server.AccountManager, chan struct{}) {
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), sqlFile, t.TempDir())
if err != nil {
t.Fatalf("Failed to create test store: %v", err)
}
t.Cleanup(cleanup)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
if err != nil {
t.Fatalf("Failed to create metrics: %v", err)
}
peersUpdateManager := server.NewPeersUpdateManager(nil)
updMsg := peersUpdateManager.CreateChannel(context.Background(), TestPeerId)
done := make(chan struct{})
go func() {
if expectedPeerUpdate != nil {
peerShouldReceiveUpdate(t, updMsg, expectedPeerUpdate)
} else {
peerShouldNotReceiveUpdate(t, updMsg)
}
close(done)
}()
geoMock := &geolocation.Mock{}
validatorMock := server.MocIntegratedValidator{}
am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics)
if err != nil {
t.Fatalf("Failed to create manager: %v", err)
}
networksManagerMock := networks.NewManagerMock()
resourcesManagerMock := resources.NewManagerMock()
routersManagerMock := routers.NewManagerMock()
groupsManagerMock := groups.NewManagerMock()
apiHandler, err := nbhttp.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, &jwtclaims.JwtValidatorMock{}, metrics, configs.AuthCfg{}, validatorMock)
if err != nil {
t.Fatalf("Failed to create API handler: %v", err)
}
return apiHandler, am, done
}
func peerShouldNotReceiveUpdate(t TB, updateMessage <-chan *server.UpdateMessage) {
t.Helper()
select {
case msg := <-updateMessage:
t.Errorf("Unexpected message received: %+v", msg)
case <-time.After(500 * time.Millisecond):
return
}
}
func peerShouldReceiveUpdate(t TB, updateMessage <-chan *server.UpdateMessage, expected *server.UpdateMessage) {
t.Helper()
select {
case msg := <-updateMessage:
if msg == nil {
t.Errorf("Received nil update message, expected valid message")
}
assert.Equal(t, expected, msg)
case <-time.After(500 * time.Millisecond):
t.Errorf("Timed out waiting for update message")
}
}
func BuildRequest(t TB, requestBody []byte, requestType, requestPath, user string) *http.Request {
t.Helper()
req := httptest.NewRequest(requestType, requestPath, bytes.NewBuffer(requestBody))
req.Header.Set("Authorization", "Bearer "+user)
return req
}
func ReadResponse(t *testing.T, recorder *httptest.ResponseRecorder, expectedStatus int, expectResponse bool) ([]byte, bool) {
t.Helper()
res := recorder.Result()
defer res.Body.Close()
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("Failed to read response body: %v", err)
}
if !expectResponse {
return nil, false
}
if status := recorder.Code; status != expectedStatus {
t.Fatalf("handler returned wrong status code: got %v want %v, content: %s",
status, expectedStatus, string(content))
}
return content, expectedStatus == http.StatusOK
}
func PopulateTestData(b *testing.B, am *server.DefaultAccountManager, peers, groups, users, setupKeys int) {
b.Helper()
ctx := context.Background()
account, err := am.GetAccount(ctx, TestAccountId)
if err != nil {
b.Fatalf("Failed to get account: %v", err)
}
// Create peers
for i := 0; i < peers; i++ {
peerKey, _ := wgtypes.GeneratePrivateKey()
peer := &nbpeer.Peer{
ID: fmt.Sprintf("oldpeer-%d", i),
DNSLabel: fmt.Sprintf("oldpeer-%d", i),
Key: peerKey.PublicKey().String(),
IP: net.ParseIP(fmt.Sprintf("100.64.%d.%d", i/256, i%256)),
Status: &nbpeer.PeerStatus{},
UserID: TestUserId,
}
account.Peers[peer.ID] = peer
}
// Create users
for i := 0; i < users; i++ {
user := &types.User{
Id: fmt.Sprintf("olduser-%d", i),
AccountID: account.Id,
Role: types.UserRoleUser,
}
account.Users[user.Id] = user
}
for i := 0; i < setupKeys; i++ {
key := &types.SetupKey{
Id: fmt.Sprintf("oldkey-%d", i),
AccountID: account.Id,
AutoGroups: []string{"someGroupID"},
ExpiresAt: time.Now().Add(ExpiresIn * time.Second),
Name: NewKeyName + strconv.Itoa(i),
Type: "reusable",
UsageLimit: 0,
}
account.SetupKeys[key.Id] = key
}
// Create groups and policies
account.Policies = make([]*types.Policy, 0, groups)
for i := 0; i < groups; i++ {
groupID := fmt.Sprintf("group-%d", i)
group := &types.Group{
ID: groupID,
Name: fmt.Sprintf("Group %d", i),
}
for j := 0; j < peers/groups; j++ {
peerIndex := i*(peers/groups) + j
group.Peers = append(group.Peers, fmt.Sprintf("peer-%d", peerIndex))
}
account.Groups[groupID] = group
// Create a policy for this group
policy := &types.Policy{
ID: fmt.Sprintf("policy-%d", i),
Name: fmt.Sprintf("Policy for Group %d", i),
Enabled: true,
Rules: []*types.PolicyRule{
{
ID: fmt.Sprintf("rule-%d", i),
Name: fmt.Sprintf("Rule for Group %d", i),
Enabled: true,
Sources: []string{groupID},
Destinations: []string{groupID},
Bidirectional: true,
Protocol: types.PolicyRuleProtocolALL,
Action: types.PolicyTrafficActionAccept,
},
},
}
account.Policies = append(account.Policies, policy)
}
account.PostureChecks = []*posture.Checks{
{
ID: "PostureChecksAll",
Name: "All",
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.0.1",
},
},
},
}
err = am.Store.SaveAccount(context.Background(), account)
if err != nil {
b.Fatalf("Failed to save account: %v", err)
}
}
func EvaluateBenchmarkResults(b *testing.B, name string, duration time.Duration, perfMetrics PerformanceMetrics, recorder *httptest.ResponseRecorder) {
b.Helper()
if recorder.Code != http.StatusOK {
b.Fatalf("Benchmark %s failed: unexpected status code %d", name, recorder.Code)
}
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
b.ReportMetric(msPerOp, "ms/op")
minExpected := perfMetrics.MinMsPerOpLocal
maxExpected := perfMetrics.MaxMsPerOpLocal
if os.Getenv("CI") == "true" {
minExpected = perfMetrics.MinMsPerOpCICD
maxExpected = perfMetrics.MaxMsPerOpCICD
}
if msPerOp < minExpected {
b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", name, msPerOp, minExpected)
}
if msPerOp > maxExpected {
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", name, msPerOp, maxExpected)
}
}

View File

@ -7,6 +7,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/account"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
)
@ -78,3 +79,45 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID
func (am *DefaultAccountManager) GetValidatedPeers(account *types.Account) (map[string]struct{}, error) {
return am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra)
}
type MocIntegratedValidator struct {
ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error)
}
func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
return nil
}
func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) {
if a.ValidatePeerFunc != nil {
return a.ValidatePeerFunc(context.Background(), update, peer, userID, accountID, dnsDomain, peersGroup, extraSettings)
}
return update, false, nil
}
func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*types.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) {
validatedPeers := make(map[string]struct{})
for _, peer := range peers {
validatedPeers[peer.ID] = struct{}{}
}
return validatedPeers, nil
}
func (MocIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer {
return peer
}
func (MocIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) {
return false, false, nil
}
func (MocIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error {
return nil
}
func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) {
// just a dummy
}
func (MocIntegratedValidator) Stop(_ context.Context) {
// just a dummy
}

View File

@ -72,15 +72,19 @@ type JSONWebKey struct {
X5c []string `json:"x5c"`
}
// JWTValidator struct to handle token validation and parsing
type JWTValidator struct {
type JWTValidator interface {
ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error)
}
// jwtValidatorImpl struct to handle token validation and parsing
type jwtValidatorImpl struct {
options Options
}
var keyNotFound = errors.New("unable to find appropriate key")
// NewJWTValidator constructor
func NewJWTValidator(ctx context.Context, issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) (*JWTValidator, error) {
func NewJWTValidator(ctx context.Context, issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) (JWTValidator, error) {
keys, err := getPemKeys(ctx, keysLocation)
if err != nil {
return nil, err
@ -146,13 +150,13 @@ func NewJWTValidator(ctx context.Context, issuer string, audienceList []string,
options.UserProperty = "user"
}
return &JWTValidator{
return &jwtValidatorImpl{
options: options,
}, nil
}
// ValidateAndParse validates the token and returns the parsed token
func (m *JWTValidator) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) {
func (m *jwtValidatorImpl) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) {
// If the token is empty...
if token == "" {
// Check if it was required
@ -318,3 +322,28 @@ func getMaxAgeFromCacheHeader(ctx context.Context, cacheControl string) int {
return 0
}
type JwtValidatorMock struct{}
func (j *JwtValidatorMock) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) {
claimMaps := jwt.MapClaims{}
switch token {
case "testUserId", "testAdminId", "testOwnerId", "testServiceUserId", "testServiceAdminId", "blockedUserId":
claimMaps[UserIDClaim] = token
claimMaps[AccountIDSuffix] = "testAccountId"
claimMaps[DomainIDSuffix] = "test.com"
claimMaps[DomainCategorySuffix] = "private"
case "otherUserId":
claimMaps[UserIDClaim] = "otherUserId"
claimMaps[AccountIDSuffix] = "otherAccountId"
claimMaps[DomainIDSuffix] = "other.com"
claimMaps[DomainCategorySuffix] = "private"
case "invalidToken":
return nil, errors.New("invalid token")
}
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
return jwtToken, nil
}

View File

@ -21,13 +21,10 @@ import (
"github.com/netbirdio/netbird/encryption"
mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util"
)
@ -448,43 +445,6 @@ var _ = Describe("Management service", func() {
})
})
type MocIntegratedValidator struct {
}
func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
return nil
}
func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) {
return update, false, nil
}
func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*types.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) {
validatedPeers := make(map[string]struct{})
for p := range peers {
validatedPeers[p] = struct{}{}
}
return validatedPeers, nil
}
func (MocIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer {
return peer
}
func (MocIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) {
return false, false, nil
}
func (MocIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error {
return nil
}
func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) {
}
func (MocIntegratedValidator) Stop(_ context.Context) {}
func loginPeerWithValidSetupKey(serverPubKey wgtypes.Key, key wgtypes.Key, client mgmtProto.ManagementServiceClient) *mgmtProto.LoginResponse {
defer GinkgoRecover()
@ -547,7 +507,7 @@ func startServer(config *server.Config, dataDir string, testFile string) (*grpc.
log.Fatalf("failed creating metrics: %v", err)
}
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, server.MocIntegratedValidator{}, metrics)
if err != nil {
log.Fatalf("failed creating a manager: %v", err)
}

View File

@ -32,6 +32,9 @@ type managerImpl struct {
routersManager routers.Manager
}
type mockManager struct {
}
func NewManager(store store.Store, permissionsManager permissions.Manager, resourceManager resources.Manager, routersManager routers.Manager, accountManager s.AccountManager) Manager {
return &managerImpl{
store: store,
@ -185,3 +188,27 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw
return nil
}
func NewManagerMock() Manager {
return &mockManager{}
}
func (m *mockManager) GetAllNetworks(ctx context.Context, accountID, userID string) ([]*types.Network, error) {
return []*types.Network{}, nil
}
func (m *mockManager) CreateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) {
return network, nil
}
func (m *mockManager) GetNetwork(ctx context.Context, accountID, userID, networkID string) (*types.Network, error) {
return &types.Network{}, nil
}
func (m *mockManager) UpdateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) {
return network, nil
}
func (m *mockManager) DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error {
return nil
}

View File

@ -34,6 +34,9 @@ type managerImpl struct {
accountManager s.AccountManager
}
type mockManager struct {
}
func NewManager(store store.Store, permissionsManager permissions.Manager, groupsManager groups.Manager, accountManager s.AccountManager) Manager {
return &managerImpl{
store: store,
@ -381,3 +384,39 @@ func (m *managerImpl) DeleteResourceInTransaction(ctx context.Context, transacti
return eventsToStore, nil
}
func NewManagerMock() Manager {
return &mockManager{}
}
func (m *mockManager) GetAllResourcesInNetwork(ctx context.Context, accountID, userID, networkID string) ([]*types.NetworkResource, error) {
return []*types.NetworkResource{}, nil
}
func (m *mockManager) GetAllResourcesInAccount(ctx context.Context, accountID, userID string) ([]*types.NetworkResource, error) {
return []*types.NetworkResource{}, nil
}
func (m *mockManager) GetAllResourceIDsInAccount(ctx context.Context, accountID, userID string) (map[string][]string, error) {
return map[string][]string{}, nil
}
func (m *mockManager) CreateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error) {
return &types.NetworkResource{}, nil
}
func (m *mockManager) GetResource(ctx context.Context, accountID, userID, networkID, resourceID string) (*types.NetworkResource, error) {
return &types.NetworkResource{}, nil
}
func (m *mockManager) UpdateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error) {
return &types.NetworkResource{}, nil
}
func (m *mockManager) DeleteResource(ctx context.Context, accountID, userID, networkID, resourceID string) error {
return nil
}
func (m *mockManager) DeleteResourceInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, resourceID string) ([]func(), error) {
return []func(){}, nil
}

View File

@ -75,7 +75,7 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, autoGroups); err != nil {
return err
return status.Errorf(status.InvalidArgument, "invalid auto groups: %v", err)
}
setupKey, plainKey = types.GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral)
@ -132,7 +132,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, keyToSave.AutoGroups); err != nil {
return err
return status.Errorf(status.InvalidArgument, "invalid auto groups: %v", err)
}
oldKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyToSave.Id)