[management] fix extend call and move config to types (#3575)

This PR fixes configuration inconsistencies and updates the store engine type usage throughout the management code. Key changes include:
- Replacing outdated server.Config references with types.Config and updating related flag variables (e.g. types.MgmtConfigPath).
- Converting engine constants (SqliteStoreEngine, PostgresStoreEngine, MysqlStoreEngine) to use types.Engine for consistent type–safety.
- Adjusting various test and migration code paths to correctly reference the new configuration and engine types.
This commit is contained in:
Maycon Santos 2025-03-27 13:04:50 +01:00 committed by GitHub
parent fceb3ca392
commit a4f04f5570
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 237 additions and 169 deletions

View File

@ -15,6 +15,7 @@ import (
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util"
@ -32,7 +33,7 @@ import (
func startTestingServices(t *testing.T) string {
t.Helper()
config := &mgmt.Config{}
config := &types.Config{}
_, err := util.ReadJson("../testdata/management.json", config)
if err != nil {
t.Fatal(err)
@ -67,7 +68,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
return s, lis
}
func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.Server, net.Listener) {
func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc.Server, net.Listener) {
t.Helper()
lis, err := net.Listen("tcp", ":0")

View File

@ -1400,15 +1400,15 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) {
func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) {
t.Helper()
config := &server.Config{
Stuns: []*server.Host{},
TURNConfig: &server.TURNConfig{},
Relay: &server.Relay{
config := &types.Config{
Stuns: []*types.Host{},
TURNConfig: &types.TURNConfig{},
Relay: &types.Relay{
Addresses: []string{"127.0.0.1:1234"},
CredentialsTTL: util.Duration{Duration: time.Hour},
Secret: "222222222222222222",
},
Signal: &server.Host{
Signal: &types.Host{
Proto: "http",
URI: "localhost:10000",
},

View File

@ -24,6 +24,7 @@ import (
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/signal/proto"
signalServer "github.com/netbirdio/netbird/signal/server"
)
@ -97,10 +98,10 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
t.Helper()
dataDir := t.TempDir()
config := &server.Config{
Stuns: []*server.Host{},
TURNConfig: &server.TURNConfig{},
Signal: &server.Host{
config := &types.Config{
Stuns: []*types.Host{},
TURNConfig: &types.TURNConfig{},
Signal: &types.Host{
Proto: "http",
URI: signalAddr,
},

2
go.mod
View File

@ -62,7 +62,7 @@ require (
github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20250320152138-69b93e4ef939
github.com/netbirdio/management-integrations/integrations v0.0.0-20250325155416-f73a616e5408
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0

4
go.sum
View File

@ -490,8 +490,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250320152138-69b93e4ef939 h1:OsLDdb6ekNaCVSyD+omhio2DECfEqLjCA1zo4HrgGqU=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250320152138-69b93e4ef939/go.mod h1:3LvBPnW+i06K9fQr1SYwsbhvnxQHtIC8vvO4PjLmmy0=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250325155416-f73a616e5408 h1:zkMfK8AX4ZEvOypT8xbnQEJwvU6HZ4wiiTkpBFCW504=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250325155416-f73a616e5408/go.mod h1:3LvBPnW+i06K9fQr1SYwsbhvnxQHtIC8vvO4PjLmmy0=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=

View File

@ -50,7 +50,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
level, _ := log.ParseLevel("debug")
log.SetLevel(level)
config := &mgmt.Config{}
config := &types.Config{}
_, err := util.ReadJson("../server/testdata/management.json", config)
if err != nil {
t.Fatal(err)

View File

@ -34,7 +34,9 @@ import (
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/formatter/hook"
@ -70,7 +72,7 @@ var (
mgmtSingleAccModeDomain string
certFile string
certKey string
config *server.Config
config *types.Config
kaep = keepalive.EnforcementPolicy{
MinTime: 15 * time.Second,
@ -101,9 +103,9 @@ var (
// detect whether user specified a port
userPort := cmd.Flag("port").Changed
config, err = loadMgmtConfig(ctx, mgmtConfig)
config, err = loadMgmtConfig(ctx, types.MgmtConfigPath)
if err != nil {
return fmt.Errorf("failed reading provided config file: %s: %v", mgmtConfig, err)
return fmt.Errorf("failed reading provided config file: %s: %v", types.MgmtConfigPath, err)
}
if cmd.Flag(idpSignKeyRefreshEnabledFlagName).Changed {
@ -183,7 +185,7 @@ var (
if config.DataStoreEncryptionKey != key {
log.WithContext(ctx).Infof("update config with activity store key")
config.DataStoreEncryptionKey = key
err := updateMgmtConfig(ctx, mgmtConfig, config)
err := updateMgmtConfig(ctx, types.MgmtConfigPath, config)
if err != nil {
return fmt.Errorf("failed to write out store encryption key: %s", err)
}
@ -486,8 +488,8 @@ func handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler) http.Handle
})
}
func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*server.Config, error) {
loadedConfig := &server.Config{}
func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*types.Config, error) {
loadedConfig := &types.Config{}
_, err := util.ReadJsonWithEnvSub(mgmtConfigPath, loadedConfig)
if err != nil {
return nil, err
@ -522,7 +524,7 @@ func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*server.Config,
oidcConfig.JwksURI, loadedConfig.HttpConfig.AuthKeysLocation)
loadedConfig.HttpConfig.AuthKeysLocation = oidcConfig.JwksURI
if !(loadedConfig.DeviceAuthorizationFlow == nil || strings.ToLower(loadedConfig.DeviceAuthorizationFlow.Provider) == string(server.NONE)) {
if !(loadedConfig.DeviceAuthorizationFlow == nil || strings.ToLower(loadedConfig.DeviceAuthorizationFlow.Provider) == string(types.NONE)) {
log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s",
oidcConfig.TokenEndpoint, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint)
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint = oidcConfig.TokenEndpoint
@ -539,7 +541,7 @@ func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*server.Config,
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Domain = u.Host
if loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Scope == "" {
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Scope = server.DefaultDeviceAuthFlowScope
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Scope = types.DefaultDeviceAuthFlowScope
}
}
@ -560,7 +562,7 @@ func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*server.Config,
return loadedConfig, err
}
func updateMgmtConfig(ctx context.Context, path string, config *server.Config) error {
func updateMgmtConfig(ctx context.Context, path string, config *types.Config) error {
return util.DirectWriteJson(ctx, path, config)
}
@ -636,7 +638,7 @@ func handleRebrand(cmd *cobra.Command) error {
}
}
}
if mgmtConfig == defaultMgmtConfig {
if types.MgmtConfigPath == defaultMgmtConfig {
if migrateToNetbird(oldDefaultMgmtConfig, defaultMgmtConfig) {
cmd.Printf("will copy Config dir %s and its content to %s\n", oldDefaultMgmtConfigDir, defaultMgmtConfigDir)
err = cpDir(oldDefaultMgmtConfigDir, defaultMgmtConfigDir)

View File

@ -7,6 +7,7 @@ import (
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/version"
)
@ -19,7 +20,6 @@ const (
var (
dnsDomain string
mgmtDataDir string
mgmtConfig string
logLevel string
logFile string
disableMetrics bool
@ -56,7 +56,7 @@ func init() {
mgmtCmd.Flags().IntVar(&mgmtPort, "port", 80, "server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise")
mgmtCmd.Flags().IntVar(&mgmtMetricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
mgmtCmd.Flags().StringVar(&mgmtDataDir, "datadir", defaultMgmtDataDir, "server data directory location")
mgmtCmd.Flags().StringVar(&mgmtConfig, "config", defaultMgmtConfig, "Netbird config file location. Config params specified via command line (e.g. datadir) have a precedence over configuration from this file")
mgmtCmd.Flags().StringVar(&types.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location. Config params specified via command line (e.g. datadir) have a precedence over configuration from this file")
mgmtCmd.Flags().StringVar(&mgmtLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
mgmtCmd.Flags().StringVar(&mgmtSingleAccModeDomain, "single-account-mode-domain", defaultSingleAccModeDomain, "Enables single account mode. This means that all the users will be under the same account grouped by the specified domain. If the installation has more than one account, the property is ineffective. Enabled by default with the default domain "+defaultSingleAccModeDomain)
mgmtCmd.Flags().BoolVar(&disableSingleAccMode, "disable-single-account-mode", false, "If set to true, disables single account mode. The --single-account-mode-domain property will be ignored and every new user will have a separate NetBird account.")

View File

@ -2793,15 +2793,15 @@ func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) {
})
}
type TB interface {
Cleanup(func())
Helper()
TempDir() string
Errorf(format string, args ...interface{})
Fatalf(format string, args ...interface{})
}
//type TB interface {
// Cleanup(func())
// Helper()
// TempDir() string
// Errorf(format string, args ...interface{})
// Fatalf(format string, args ...interface{})
//}
func createManager(t TB) (*DefaultAccountManager, error) {
func createManager(t testing.TB) (*DefaultAccountManager, error) {
t.Helper()
store, err := createStore(t)
@ -2836,7 +2836,7 @@ func createManager(t TB) (*DefaultAccountManager, error) {
return manager, nil
}
func createStore(t TB) (store.Store, error) {
func createStore(t testing.TB) (store.Store, error) {
t.Helper()
dataDir := t.TempDir()
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", dataDir)

View File

@ -19,6 +19,7 @@ import (
"google.golang.org/grpc/status"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/account"
@ -40,7 +41,7 @@ type GRPCServer struct {
wgKey wgtypes.Key
proto.UnimplementedManagementServiceServer
peersUpdateManager *PeersUpdateManager
config *Config
config *types.Config
secretsManager SecretsManager
appMetrics telemetry.AppMetrics
ephemeralManager *EphemeralManager
@ -51,7 +52,7 @@ type GRPCServer struct {
// NewServer creates a new Management server
func NewServer(
ctx context.Context,
config *Config,
config *types.Config,
accountManager account.Manager,
settingsManager settings.Manager,
peersUpdateManager *PeersUpdateManager,
@ -530,24 +531,24 @@ func (s *GRPCServer) processJwtToken(ctx context.Context, loginReq *proto.LoginR
return userID, nil
}
func ToResponseProto(configProto Protocol) proto.HostConfig_Protocol {
func ToResponseProto(configProto types.Protocol) proto.HostConfig_Protocol {
switch configProto {
case UDP:
case types.UDP:
return proto.HostConfig_UDP
case DTLS:
case types.DTLS:
return proto.HostConfig_DTLS
case HTTP:
case types.HTTP:
return proto.HostConfig_HTTP
case HTTPS:
case types.HTTPS:
return proto.HostConfig_HTTPS
case TCP:
case types.TCP:
return proto.HostConfig_TCP
default:
panic(fmt.Errorf("unexpected config protocol type %v", configProto))
}
}
func toNetbirdConfig(config *Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
func toNetbirdConfig(config *types.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
if config == nil {
return nil
}
@ -610,8 +611,6 @@ func toNetbirdConfig(config *Config, turnCredentials *Token, relayToken *Token,
Relay: relayCfg,
}
integrationsConfig.ExtendNetBirdConfig(nbConfig, extraSettings)
return nbConfig
}
@ -626,9 +625,8 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, dns
}
}
func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, dnsResolutionOnRoutingPeerEnabled bool, extraSettings *types.ExtraSettings) *proto.SyncResponse {
func toSyncResponse(ctx context.Context, config *types.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, dnsResolutionOnRoutingPeerEnabled bool, extraSettings *types.ExtraSettings) *proto.SyncResponse {
response := &proto.SyncResponse{
NetbirdConfig: toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings),
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, dnsResolutionOnRoutingPeerEnabled),
NetworkMap: &proto.NetworkMap{
Serial: networkMap.Network.CurrentSerial(),
@ -638,6 +636,10 @@ func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turn
Checks: toProtocolChecks(ctx, checks),
}
nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings)
integrationsConfig.ExtendNetBirdConfig(peer.ID, nbConfig, extraSettings)
response.NetbirdConfig = nbConfig
response.NetworkMap.PeerConfig = response.PeerConfig
allPeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
@ -754,7 +756,7 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.
return nil, status.Error(codes.InvalidArgument, errMSG)
}
if s.config.DeviceAuthorizationFlow == nil || s.config.DeviceAuthorizationFlow.Provider == string(NONE) {
if s.config.DeviceAuthorizationFlow == nil || s.config.DeviceAuthorizationFlow.Provider == string(types.NONE) {
return nil, status.Error(codes.NotFound, "no device authorization flow information available")
}

View File

@ -93,21 +93,21 @@ func getServerKey(client mgmtProto.ManagementServiceClient) (*wgtypes.Key, error
func Test_SyncProtocol(t *testing.T) {
dir := t.TempDir()
mgmtServer, _, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &Config{
Stuns: []*Host{{
mgmtServer, _, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &types.Config{
Stuns: []*types.Host{{
Proto: "udp",
URI: "stun:stun.netbird.io:3468",
}},
TURNConfig: &TURNConfig{
TURNConfig: &types.TURNConfig{
TimeBasedCredentials: false,
CredentialsTTL: util.Duration{},
Secret: "whatever",
Turns: []*Host{{
Turns: []*types.Host{{
Proto: "udp",
URI: "turn:stun.netbird.io:3468",
}},
},
Signal: &Host{
Signal: &types.Host{
Proto: "http",
URI: "signal.netbird.io:10000",
},
@ -330,7 +330,7 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
testCases := []struct {
name string
inputFlow *DeviceAuthorizationFlow
inputFlow *types.DeviceAuthorizationFlow
expectedFlow *mgmtProto.DeviceAuthorizationFlow
expectedErrFunc require.ErrorAssertionFunc
expectedErrMSG string
@ -345,9 +345,9 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
},
{
name: "Testing Invalid Device Flow Provider Config",
inputFlow: &DeviceAuthorizationFlow{
inputFlow: &types.DeviceAuthorizationFlow{
Provider: "NoNe",
ProviderConfig: ProviderConfig{
ProviderConfig: types.ProviderConfig{
ClientID: "test",
},
},
@ -356,9 +356,9 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
},
{
name: "Testing Full Device Flow Config",
inputFlow: &DeviceAuthorizationFlow{
inputFlow: &types.DeviceAuthorizationFlow{
Provider: "hosted",
ProviderConfig: ProviderConfig{
ProviderConfig: types.ProviderConfig{
ClientID: "test",
},
},
@ -379,7 +379,7 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
t.Run(testCase.name, func(t *testing.T) {
mgmtServer := &GRPCServer{
wgKey: testingServerKey,
config: &Config{
config: &types.Config{
DeviceAuthorizationFlow: testCase.inputFlow,
},
}
@ -410,7 +410,7 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
}
}
func startManagementForTest(t *testing.T, testFile string, config *Config) (*grpc.Server, *DefaultAccountManager, string, func(), error) {
func startManagementForTest(t *testing.T, testFile string, config *types.Config) (*grpc.Server, *DefaultAccountManager, string, func(), error) {
t.Helper()
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
@ -506,21 +506,21 @@ func testSyncStatusRace(t *testing.T) {
t.Skip()
dir := t.TempDir()
mgmtServer, am, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &Config{
Stuns: []*Host{{
mgmtServer, am, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &types.Config{
Stuns: []*types.Host{{
Proto: "udp",
URI: "stun:stun.netbird.io:3468",
}},
TURNConfig: &TURNConfig{
TURNConfig: &types.TURNConfig{
TimeBasedCredentials: false,
CredentialsTTL: util.Duration{},
Secret: "whatever",
Turns: []*Host{{
Turns: []*types.Host{{
Proto: "udp",
URI: "turn:stun.netbird.io:3468",
}},
},
Signal: &Host{
Signal: &types.Host{
Proto: "http",
URI: "signal.netbird.io:10000",
},
@ -678,21 +678,21 @@ func Test_LoginPerformance(t *testing.T) {
t.Helper()
dir := t.TempDir()
mgmtServer, am, _, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &Config{
Stuns: []*Host{{
mgmtServer, am, _, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &types.Config{
Stuns: []*types.Host{{
Proto: "udp",
URI: "stun:stun.netbird.io:3468",
}},
TURNConfig: &TURNConfig{
TURNConfig: &types.TURNConfig{
TimeBasedCredentials: false,
CredentialsTTL: util.Duration{},
Secret: "whatever",
Turns: []*Host{{
Turns: []*types.Host{{
Proto: "udp",
URI: "turn:stun.netbird.io:3468",
}},
},
Signal: &Host{
Signal: &types.Host{
Proto: "http",
URI: "signal.netbird.io:10000",
},

View File

@ -58,7 +58,7 @@ func setupTest(t *testing.T) *testSuite {
t.Fatalf("failed to create temp directory: %v", err)
}
config := &server.Config{}
config := &types.Config{}
_, err = util.ReadJson("testdata/management.json", config)
if err != nil {
t.Fatalf("failed to read management.json: %v", err)
@ -156,7 +156,7 @@ func createRawClient(t *testing.T, addr string) (mgmtProto.ManagementServiceClie
func startServer(
t *testing.T,
config *server.Config,
config *types.Config,
dataDir string,
testFile string,
) (*grpc.Server, net.Listener) {
@ -300,6 +300,10 @@ func TestSyncNewPeerConfiguration(t *testing.T) {
Protocol: mgmtProto.HostConfig_UDP,
}
expectedRelayHost := &mgmtProto.RelayConfig{
Urls: []string{"rel://test.com:3535"},
}
assert.NotNil(t, resp.NetbirdConfig)
assert.Equal(t, resp.NetbirdConfig.Signal, expectedSignalConfig)
assert.Contains(t, resp.NetbirdConfig.Stuns, expectedStunsConfig)
@ -307,6 +311,8 @@ func TestSyncNewPeerConfiguration(t *testing.T) {
actualTURN := resp.NetbirdConfig.Turns[0]
assert.Greater(t, len(actualTURN.User), 0)
assert.Equal(t, actualTURN.HostConfig, expectedTRUNHost)
assert.Equal(t, len(resp.NetbirdConfig.Relay.Urls), 1)
assert.Equal(t, resp.NetbirdConfig.Relay.Urls, expectedRelayHost.Urls)
assert.Equal(t, len(resp.NetworkMap.OfflinePeers), 0)
}

View File

@ -15,7 +15,6 @@ import (
"github.com/hashicorp/go-version"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
nbversion "github.com/netbirdio/netbird/version"
)
@ -49,7 +48,7 @@ type properties map[string]interface{}
// DataSource metric data source
type DataSource interface {
GetAllAccounts(ctx context.Context) []*types.Account
GetStoreEngine() store.Engine
GetStoreEngine() types.Engine
}
// ConnManager peer connection manager that holds state for current active connections

View File

@ -10,7 +10,6 @@ import (
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
)
@ -205,8 +204,8 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
}
// GetStoreEngine returns FileStoreEngine
func (mockDatasource) GetStoreEngine() store.Engine {
return store.FileStoreEngine
func (mockDatasource) GetStoreEngine() types.Engine {
return types.FileStoreEngine
}
// TestGenerateProperties tests and validate the properties generation by using the mockDatasource for the Worker.generateProperties
@ -304,7 +303,7 @@ func TestGenerateProperties(t *testing.T) {
t.Errorf("expected 2 user_peers, got %d", properties["user_peers"])
}
if properties["store_engine"] != store.FileStoreEngine {
if properties["store_engine"] != types.FileStoreEngine {
t.Errorf("expected JsonFile, got %s", properties["store_engine"])
}

View File

@ -1213,7 +1213,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
return
}
update := toSyncResponse(ctx, &Config{}, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled, extraSetting)
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled, extraSetting)
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
}(peer)
}
@ -1282,7 +1282,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI
return
}
update := toSyncResponse(ctx, &Config{}, peer, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled, extraSettings)
update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled, extraSettings)
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
}

View File

@ -723,7 +723,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
}
}
func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccountManager, string, string, error) {
func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccountManager, string, string, error) {
b.Helper()
manager, err := createManager(b)
@ -998,6 +998,53 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
}
}
func TestUpdateAccountPeers(t *testing.T) {
testCases := []struct {
name string
peers int
groups int
}{
{"Small", 50, 1},
{"Medium", 500, 1},
{"Large", 1000, 1},
}
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
manager, accountID, _, err := setupTestAccountManager(t, tc.peers, tc.groups)
if err != nil {
t.Fatalf("Failed to setup test account manager: %v", err)
}
ctx := context.Background()
account, err := manager.Store.GetAccount(ctx, accountID)
if err != nil {
t.Fatalf("Failed to get account: %v", err)
}
peerChannels := make(map[string]chan *UpdateMessage)
for peerID := range account.Peers {
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
}
manager.peersUpdateManager.peerChannels = peerChannels
manager.UpdateAccountPeers(ctx, account.Id)
for _, channel := range peerChannels {
update := <-channel
assert.Nil(t, update.Update.NetbirdConfig)
assert.Equal(t, tc.peers, len(update.NetworkMap.Peers))
assert.Equal(t, tc.peers*2, len(update.NetworkMap.FirewallRules))
}
})
}
}
func TestToSyncResponse(t *testing.T) {
_, ipnet, err := net.ParseCIDR("192.168.1.0/24")
if err != nil {
@ -1008,16 +1055,16 @@ func TestToSyncResponse(t *testing.T) {
t.Fatal(err)
}
config := &Config{
Signal: &Host{
config := &types.Config{
Signal: &types.Host{
Proto: "https",
URI: "signal.uri",
Username: "",
Password: "",
},
Stuns: []*Host{{URI: "stun.uri", Proto: UDP}},
TURNConfig: &TURNConfig{
Turns: []*Host{{URI: "turn.uri", Proto: UDP, Username: "turn-user", Password: "turn-pass"}},
Stuns: []*types.Host{{URI: "stun.uri", Proto: types.UDP}},
TURNConfig: &types.TURNConfig{
Turns: []*types.Host{{URI: "turn.uri", Proto: types.UDP, Username: "turn-user", Password: "turn-pass"}},
},
}
peer := &nbpeer.Peer{

View File

@ -260,6 +260,6 @@ func (s *FileStore) Close(ctx context.Context) error {
}
// GetStoreEngine returns FileStoreEngine
func (s *FileStore) GetStoreEngine() Engine {
return FileStoreEngine
func (s *FileStore) GetStoreEngine() types.Engine {
return types.FileStoreEngine
}

View File

@ -55,7 +55,7 @@ type SqlStore struct {
globalAccountLock sync.Mutex
metrics telemetry.AppMetrics
installationPK int
storeEngine Engine
storeEngine types.Engine
}
type installation struct {
@ -66,7 +66,7 @@ type installation struct {
type migrationFunc func(*gorm.DB) error
// NewSqlStore creates a new SqlStore instance.
func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine Engine, metrics telemetry.AppMetrics) (*SqlStore, error) {
func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, metrics telemetry.AppMetrics) (*SqlStore, error) {
sql, err := db.DB()
if err != nil {
return nil, err
@ -77,7 +77,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine Engine, metrics t
conns = runtime.NumCPU()
}
if storeEngine == SqliteStoreEngine {
if storeEngine == types.SqliteStoreEngine {
if err == nil {
log.WithContext(ctx).Warnf("setting NB_SQL_MAX_OPEN_CONNS is not supported for sqlite, using default value 1")
}
@ -105,7 +105,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine Engine, metrics t
}
func GetKeyQueryCondition(s *SqlStore) string {
if s.storeEngine == MysqlStoreEngine {
if s.storeEngine == types.MysqlStoreEngine {
return mysqlKeyQueryCondition
}
return keyQueryCondition
@ -970,7 +970,7 @@ func (s *SqlStore) Close(_ context.Context) error {
}
// GetStoreEngine returns underlying store engine
func (s *SqlStore) GetStoreEngine() Engine {
func (s *SqlStore) GetStoreEngine() types.Engine {
return s.storeEngine
}
@ -988,7 +988,7 @@ func NewSqliteStore(ctx context.Context, dataDir string, metrics telemetry.AppMe
return nil, err
}
return NewSqlStore(ctx, db, SqliteStoreEngine, metrics)
return NewSqlStore(ctx, db, types.SqliteStoreEngine, metrics)
}
// NewPostgresqlStore creates a new Postgres store.
@ -998,7 +998,7 @@ func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMe
return nil, err
}
return NewSqlStore(ctx, db, PostgresStoreEngine, metrics)
return NewSqlStore(ctx, db, types.PostgresStoreEngine, metrics)
}
// NewMysqlStore creates a new MySQL store.
@ -1008,7 +1008,7 @@ func NewMysqlStore(ctx context.Context, dsn string, metrics telemetry.AppMetrics
return nil, err
}
return NewSqlStore(ctx, db, MysqlStoreEngine, metrics)
return NewSqlStore(ctx, db, types.MysqlStoreEngine, metrics)
}
func getGormConfig() *gorm.Config {
@ -1517,9 +1517,9 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
query := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations)
switch s.storeEngine {
case PostgresStoreEngine:
case types.PostgresStoreEngine:
query = query.Order("json_array_length(peers::json) DESC")
case MysqlStoreEngine:
case types.MysqlStoreEngine:
query = query.Order("JSON_LENGTH(JSON_EXTRACT(peers, \"$\")) DESC")
default:
query = query.Order("json_array_length(peers) DESC")

View File

@ -297,7 +297,7 @@ func TestSqlite_DeleteAccount(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
@ -628,7 +628,7 @@ func TestMigrate(t *testing.T) {
}
// TODO: figure out why this fails on postgres
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
t.Cleanup(cleanUp)
@ -737,7 +737,7 @@ func TestPostgresql_NewStore(t *testing.T) {
t.Skip("skip CI tests on darwin and windows")
}
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
@ -752,7 +752,7 @@ func TestPostgresql_SaveAccount(t *testing.T) {
t.Skip("skip CI tests on darwin and windows")
}
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
@ -825,7 +825,7 @@ func TestPostgresql_DeleteAccount(t *testing.T) {
t.Skip("skip CI tests on darwin and windows")
}
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
@ -901,7 +901,7 @@ func TestPostgresql_TestGetAccountByPrivateDomain(t *testing.T) {
t.Skip("skip CI tests on darwin and windows")
}
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
@ -921,7 +921,7 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) {
t.Skip("skip CI tests on darwin and windows")
}
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
@ -935,7 +935,7 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) {
}
func TestSqlite_GetTakenIPs(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
defer cleanup()
if err != nil {
@ -980,7 +980,7 @@ func TestSqlite_GetTakenIPs(t *testing.T) {
}
func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
if err != nil {
return
@ -1022,7 +1022,7 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
}
func TestSqlite_GetAccountNetwork(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
if err != nil {
@ -1045,7 +1045,7 @@ func TestSqlite_GetAccountNetwork(t *testing.T) {
}
func TestSqlite_GetSetupKeyBySecret(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
if err != nil {
@ -1070,7 +1070,7 @@ func TestSqlite_GetSetupKeyBySecret(t *testing.T) {
}
func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
if err != nil {
@ -1106,7 +1106,7 @@ func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
}
func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
if err != nil {
@ -1211,7 +1211,7 @@ func TestSqlite_GetGroupByName(t *testing.T) {
}
func Test_DeleteSetupKeySuccessfully(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
@ -1227,7 +1227,7 @@ func Test_DeleteSetupKeySuccessfully(t *testing.T) {
}
func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)

View File

@ -164,7 +164,7 @@ type Store interface {
Close(ctx context.Context) error
// GetStoreEngine should return Engine of the current store implementation.
// This is also a method of metrics.DataSource interface.
GetStoreEngine() Engine
GetStoreEngine() types.Engine
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error)
@ -187,44 +187,37 @@ type Store interface {
GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error)
}
type Engine string
const (
FileStoreEngine Engine = "jsonfile"
SqliteStoreEngine Engine = "sqlite"
PostgresStoreEngine Engine = "postgres"
MysqlStoreEngine Engine = "mysql"
postgresDsnEnv = "NETBIRD_STORE_ENGINE_POSTGRES_DSN"
mysqlDsnEnv = "NETBIRD_STORE_ENGINE_MYSQL_DSN"
)
var supportedEngines = []Engine{SqliteStoreEngine, PostgresStoreEngine, MysqlStoreEngine}
var supportedEngines = []types.Engine{types.SqliteStoreEngine, types.PostgresStoreEngine, types.MysqlStoreEngine}
func getStoreEngineFromEnv() Engine {
func getStoreEngineFromEnv() types.Engine {
// NETBIRD_STORE_ENGINE supposed to be used in tests. Otherwise, rely on the config file.
kind, ok := os.LookupEnv("NETBIRD_STORE_ENGINE")
if !ok {
return ""
}
value := Engine(strings.ToLower(kind))
value := types.Engine(strings.ToLower(kind))
if slices.Contains(supportedEngines, value) {
return value
}
return SqliteStoreEngine
return types.SqliteStoreEngine
}
// getStoreEngine determines the store engine to use.
// If no engine is specified, it attempts to retrieve it from the environment.
// If still not specified, it defaults to using SQLite.
// Additionally, it handles the migration from a JSON store file to SQLite if applicable.
func getStoreEngine(ctx context.Context, dataDir string, kind Engine) Engine {
func getStoreEngine(ctx context.Context, dataDir string, kind types.Engine) types.Engine {
if kind == "" {
kind = getStoreEngineFromEnv()
if kind == "" {
kind = SqliteStoreEngine
kind = types.SqliteStoreEngine
// Migrate if it is the first run with a JSON file existing and no SQLite file present
jsonStoreFile := filepath.Join(dataDir, storeFileName)
@ -236,7 +229,7 @@ func getStoreEngine(ctx context.Context, dataDir string, kind Engine) Engine {
// Attempt to migrate from JSON store to SQLite
if err := MigrateFileStoreToSqlite(ctx, dataDir); err != nil {
log.WithContext(ctx).Errorf("failed to migrate filestore to SQLite: %v", err)
kind = FileStoreEngine
kind = types.FileStoreEngine
}
}
}
@ -246,7 +239,7 @@ func getStoreEngine(ctx context.Context, dataDir string, kind Engine) Engine {
}
// NewStore creates a new store based on the provided engine type, data directory, and telemetry metrics
func NewStore(ctx context.Context, kind Engine, dataDir string, metrics telemetry.AppMetrics) (Store, error) {
func NewStore(ctx context.Context, kind types.Engine, dataDir string, metrics telemetry.AppMetrics) (Store, error) {
kind = getStoreEngine(ctx, dataDir, kind)
if err := checkFileStoreEngine(kind, dataDir); err != nil {
@ -254,13 +247,13 @@ func NewStore(ctx context.Context, kind Engine, dataDir string, metrics telemetr
}
switch kind {
case SqliteStoreEngine:
case types.SqliteStoreEngine:
log.WithContext(ctx).Info("using SQLite store engine")
return NewSqliteStore(ctx, dataDir, metrics)
case PostgresStoreEngine:
case types.PostgresStoreEngine:
log.WithContext(ctx).Info("using Postgres store engine")
return newPostgresStore(ctx, metrics)
case MysqlStoreEngine:
case types.MysqlStoreEngine:
log.WithContext(ctx).Info("using MySQL store engine")
return newMysqlStore(ctx, metrics)
default:
@ -268,12 +261,12 @@ func NewStore(ctx context.Context, kind Engine, dataDir string, metrics telemetr
}
}
func checkFileStoreEngine(kind Engine, dataDir string) error {
if kind == FileStoreEngine {
func checkFileStoreEngine(kind types.Engine, dataDir string) error {
if kind == types.FileStoreEngine {
storeFile := filepath.Join(dataDir, storeFileName)
if util.FileExists(storeFile) {
return fmt.Errorf("%s is not supported. Please refer to the documentation for migrating to SQLite: "+
"https://docs.netbird.io/selfhosted/sqlite-store#migrating-from-json-store-to-sq-lite-store", FileStoreEngine)
"https://docs.netbird.io/selfhosted/sqlite-store#migrating-from-json-store-to-sq-lite-store", types.FileStoreEngine)
}
}
return nil
@ -326,7 +319,7 @@ func getMigrations(ctx context.Context) []migrationFunc {
func NewTestStoreFromSQL(ctx context.Context, filename string, dataDir string) (Store, func(), error) {
kind := getStoreEngineFromEnv()
if kind == "" {
kind = SqliteStoreEngine
kind = types.SqliteStoreEngine
}
storeStr := fmt.Sprintf("%s?cache=shared", storeSqliteFileName)
@ -348,7 +341,7 @@ func NewTestStoreFromSQL(ctx context.Context, filename string, dataDir string) (
}
}
store, err := NewSqlStore(ctx, db, SqliteStoreEngine, nil)
store, err := NewSqlStore(ctx, db, types.SqliteStoreEngine, nil)
if err != nil {
return nil, nil, fmt.Errorf("failed to create test store: %v", err)
}
@ -394,13 +387,13 @@ func addAllGroupToAccount(ctx context.Context, store Store) error {
return nil
}
func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind Engine) (Store, func(), error) {
func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind types.Engine) (Store, func(), error) {
var cleanup func()
var err error
switch kind {
case PostgresStoreEngine:
case types.PostgresStoreEngine:
store, cleanup, err = newReusedPostgresStore(ctx, store, kind)
case MysqlStoreEngine:
case types.MysqlStoreEngine:
store, cleanup, err = newReusedMysqlStore(ctx, store, kind)
default:
cleanup = func() {
@ -419,7 +412,7 @@ func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind Engine) (Store
return store, closeConnection, nil
}
func newReusedPostgresStore(ctx context.Context, store *SqlStore, kind Engine) (*SqlStore, func(), error) {
func newReusedPostgresStore(ctx context.Context, store *SqlStore, kind types.Engine) (*SqlStore, func(), error) {
if envDsn, ok := os.LookupEnv(postgresDsnEnv); !ok || envDsn == "" {
var err error
_, err = testutil.CreatePostgresTestContainer()
@ -451,7 +444,7 @@ func newReusedPostgresStore(ctx context.Context, store *SqlStore, kind Engine) (
return store, cleanup, nil
}
func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind Engine) (*SqlStore, func(), error) {
func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind types.Engine) (*SqlStore, func(), error) {
if envDsn, ok := os.LookupEnv(mysqlDsnEnv); !ok || envDsn == "" {
var err error
_, err = testutil.CreateMysqlTestContainer()
@ -483,7 +476,7 @@ func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind Engine) (*Sq
return store, cleanup, nil
}
func createRandomDB(dsn string, db *gorm.DB, engine Engine) (string, func(), error) {
func createRandomDB(dsn string, db *gorm.DB, engine types.Engine) (string, func(), error) {
dbName := fmt.Sprintf("test_db_%s", strings.ReplaceAll(uuid.New().String(), "-", "_"))
if err := db.Exec(fmt.Sprintf("CREATE DATABASE %s", dbName)).Error; err != nil {
@ -493,9 +486,9 @@ func createRandomDB(dsn string, db *gorm.DB, engine Engine) (string, func(), err
var err error
cleanup := func() {
switch engine {
case PostgresStoreEngine:
case types.PostgresStoreEngine:
err = db.Exec(fmt.Sprintf("DROP DATABASE %s WITH (FORCE)", dbName)).Error
case MysqlStoreEngine:
case types.MysqlStoreEngine:
// err = killMySQLConnections(dsn, dbName)
err = db.Exec(fmt.Sprintf("DROP DATABASE %s", dbName)).Error
}

View File

@ -19,6 +19,11 @@
"CredentialsTTL": "1h",
"Secret": "c29tZV9wYXNzd29yZA==",
"TimeBasedCredentials": true
},
"Relay":{
"Addresses":["rel://test.com:3535"],
"CredentialsTTL":"2h",
"Secret":"netbird"
},
"Signal": {
"Proto": "http",

View File

@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/types"
auth "github.com/netbirdio/netbird/relay/auth/hmac"
authv2 "github.com/netbirdio/netbird/relay/auth/hmac/v2"
@ -32,8 +33,8 @@ type SecretsManager interface {
// TimeBasedAuthSecretsManager generates credentials with TTL and using pre-shared secret known to TURN server
type TimeBasedAuthSecretsManager struct {
mux sync.Mutex
turnCfg *TURNConfig
relayCfg *Relay
turnCfg *types.TURNConfig
relayCfg *types.Relay
turnHmacToken *auth.TimedHMAC
relayHmacToken *authv2.Generator
updateManager *PeersUpdateManager
@ -44,7 +45,7 @@ type TimeBasedAuthSecretsManager struct {
type Token auth.Token
func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *TURNConfig, relayCfg *Relay, settingsManager settings.Manager) *TimeBasedAuthSecretsManager {
func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *types.TURNConfig, relayCfg *types.Relay, settingsManager settings.Manager) *TimeBasedAuthSecretsManager {
mgr := &TimeBasedAuthSecretsManager{
updateManager: updateManager,
turnCfg: turnCfg,
@ -221,7 +222,7 @@ func (m *TimeBasedAuthSecretsManager) pushNewTURNAndRelayTokens(ctx context.Cont
}
}
m.extendNetbirdConfig(ctx, accountID, update)
m.extendNetbirdConfig(ctx, peerID, accountID, update)
log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID)
m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update})
@ -245,17 +246,17 @@ func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, ac
},
}
m.extendNetbirdConfig(ctx, accountID, update)
m.extendNetbirdConfig(ctx, peerID, accountID, update)
log.WithContext(ctx).Debugf("sending new relay credentials to peer %s", peerID)
m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update})
}
func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, accountID string, update *proto.SyncResponse) {
func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, peerID, accountID string, update *proto.SyncResponse) {
extraSettings, err := m.settingsManager.GetExtraSettings(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get extra settings: %v", err)
}
integrationsConfig.ExtendNetBirdConfig(update.NetbirdConfig, extraSettings)
integrationsConfig.ExtendNetBirdConfig(peerID, update.NetbirdConfig, extraSettings)
}

View File

@ -19,8 +19,8 @@ import (
"github.com/netbirdio/netbird/util"
)
var TurnTestHost = &Host{
Proto: UDP,
var TurnTestHost = &types.Host{
Proto: types.UDP,
URI: "turn:turn.netbird.io:77777",
Username: "username",
Password: "",
@ -31,7 +31,7 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
secret := "some_secret"
peersManager := NewPeersUpdateManager(nil)
rc := &Relay{
rc := &types.Relay{
Addresses: []string{"localhost:0"},
CredentialsTTL: ttl,
Secret: secret,
@ -41,10 +41,10 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
tested := NewTimeBasedAuthSecretsManager(peersManager, &TURNConfig{
tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{
CredentialsTTL: ttl,
Secret: secret,
Turns: []*Host{TurnTestHost},
Turns: []*types.Host{TurnTestHost},
TimeBasedCredentials: true,
}, rc, settingsMockManager)
@ -81,7 +81,7 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
peer := "some_peer"
updateChannel := peersManager.CreateChannel(context.Background(), peer)
rc := &Relay{
rc := &types.Relay{
Addresses: []string{"localhost:0"},
CredentialsTTL: ttl,
Secret: secret,
@ -92,10 +92,10 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
settingsMockManager := settings.NewMockManager(ctrl)
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), "someAccountID").Return(&types.ExtraSettings{}, nil).AnyTimes()
tested := NewTimeBasedAuthSecretsManager(peersManager, &TURNConfig{
tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{
CredentialsTTL: ttl,
Secret: secret,
Turns: []*Host{TurnTestHost},
Turns: []*types.Host{TurnTestHost},
TimeBasedCredentials: true,
}, rc, settingsMockManager)
@ -184,7 +184,7 @@ func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
peersManager := NewPeersUpdateManager(nil)
peer := "some_peer"
rc := &Relay{
rc := &types.Relay{
Addresses: []string{"localhost:0"},
CredentialsTTL: ttl,
Secret: secret,
@ -194,10 +194,10 @@ func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
tested := NewTimeBasedAuthSecretsManager(peersManager, &TURNConfig{
tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{
CredentialsTTL: ttl,
Secret: secret,
Turns: []*Host{TurnTestHost},
Turns: []*types.Host{TurnTestHost},
TimeBasedCredentials: true,
}, rc, settingsMockManager)

View File

@ -1,10 +1,9 @@
package server
package types
import (
"net/netip"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/util"
)
@ -30,6 +29,8 @@ const (
DefaultDeviceAuthFlowScope string = "openid"
)
var MgmtConfigPath string
// Config of the Management service
type Config struct {
Stuns []*Host
@ -76,6 +77,7 @@ type TURNConfig struct {
Turns []*Host
}
// Relay configuration type
type Relay struct {
Addresses []string
CredentialsTTL util.Duration
@ -156,7 +158,7 @@ type ProviderConfig struct {
// StoreConfig contains Store configuration
type StoreConfig struct {
Engine store.Engine
Engine Engine
}
// ReverseProxy contains reverse proxy configuration in front of management.

View File

@ -0,0 +1,10 @@
package types
type Engine string
const (
PostgresStoreEngine Engine = "postgres"
FileStoreEngine Engine = "jsonfile"
SqliteStoreEngine Engine = "sqlite"
MysqlStoreEngine Engine = "mysql"
)