[management] fix extend call and move config to types (#3575)

This PR fixes configuration inconsistencies and updates the store engine type usage throughout the management code. Key changes include:
- Replacing outdated server.Config references with types.Config and updating related flag variables (e.g. types.MgmtConfigPath).
- Converting engine constants (SqliteStoreEngine, PostgresStoreEngine, MysqlStoreEngine) to use types.Engine for consistent type–safety.
- Adjusting various test and migration code paths to correctly reference the new configuration and engine types.
This commit is contained in:
Maycon Santos 2025-03-27 13:04:50 +01:00 committed by GitHub
parent fceb3ca392
commit a4f04f5570
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 237 additions and 169 deletions

View File

@ -15,6 +15,7 @@ import (
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
@ -32,7 +33,7 @@ import (
func startTestingServices(t *testing.T) string { func startTestingServices(t *testing.T) string {
t.Helper() t.Helper()
config := &mgmt.Config{} config := &types.Config{}
_, err := util.ReadJson("../testdata/management.json", config) _, err := util.ReadJson("../testdata/management.json", config)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -67,7 +68,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
return s, lis return s, lis
} }
func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.Server, net.Listener) { func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc.Server, net.Listener) {
t.Helper() t.Helper()
lis, err := net.Listen("tcp", ":0") lis, err := net.Listen("tcp", ":0")

View File

@ -1400,15 +1400,15 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) {
func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) { func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) {
t.Helper() t.Helper()
config := &server.Config{ config := &types.Config{
Stuns: []*server.Host{}, Stuns: []*types.Host{},
TURNConfig: &server.TURNConfig{}, TURNConfig: &types.TURNConfig{},
Relay: &server.Relay{ Relay: &types.Relay{
Addresses: []string{"127.0.0.1:1234"}, Addresses: []string{"127.0.0.1:1234"},
CredentialsTTL: util.Duration{Duration: time.Hour}, CredentialsTTL: util.Duration{Duration: time.Hour},
Secret: "222222222222222222", Secret: "222222222222222222",
}, },
Signal: &server.Host{ Signal: &types.Host{
Proto: "http", Proto: "http",
URI: "localhost:10000", URI: "localhost:10000",
}, },

View File

@ -24,6 +24,7 @@ import (
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/signal/proto" "github.com/netbirdio/netbird/signal/proto"
signalServer "github.com/netbirdio/netbird/signal/server" signalServer "github.com/netbirdio/netbird/signal/server"
) )
@ -97,10 +98,10 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
t.Helper() t.Helper()
dataDir := t.TempDir() dataDir := t.TempDir()
config := &server.Config{ config := &types.Config{
Stuns: []*server.Host{}, Stuns: []*types.Host{},
TURNConfig: &server.TURNConfig{}, TURNConfig: &types.TURNConfig{},
Signal: &server.Host{ Signal: &types.Host{
Proto: "http", Proto: "http",
URI: signalAddr, URI: signalAddr,
}, },

2
go.mod
View File

@ -62,7 +62,7 @@ require (
github.com/miekg/dns v1.1.59 github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0 github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20250320152138-69b93e4ef939 github.com/netbirdio/management-integrations/integrations v0.0.0-20250325155416-f73a616e5408
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d
github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0 github.com/oschwald/maxminddb-golang v1.12.0

4
go.sum
View File

@ -490,8 +490,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250320152138-69b93e4ef939 h1:OsLDdb6ekNaCVSyD+omhio2DECfEqLjCA1zo4HrgGqU= github.com/netbirdio/management-integrations/integrations v0.0.0-20250325155416-f73a616e5408 h1:zkMfK8AX4ZEvOypT8xbnQEJwvU6HZ4wiiTkpBFCW504=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250320152138-69b93e4ef939/go.mod h1:3LvBPnW+i06K9fQr1SYwsbhvnxQHtIC8vvO4PjLmmy0= github.com/netbirdio/management-integrations/integrations v0.0.0-20250325155416-f73a616e5408/go.mod h1:3LvBPnW+i06K9fQr1SYwsbhvnxQHtIC8vvO4PjLmmy0=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=

View File

@ -50,7 +50,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
level, _ := log.ParseLevel("debug") level, _ := log.ParseLevel("debug")
log.SetLevel(level) log.SetLevel(level)
config := &mgmt.Config{} config := &types.Config{}
_, err := util.ReadJson("../server/testdata/management.json", config) _, err := util.ReadJson("../server/testdata/management.json", config)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

View File

@ -34,7 +34,9 @@ import (
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/server/peers" "github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/formatter/hook" "github.com/netbirdio/netbird/formatter/hook"
@ -70,7 +72,7 @@ var (
mgmtSingleAccModeDomain string mgmtSingleAccModeDomain string
certFile string certFile string
certKey string certKey string
config *server.Config config *types.Config
kaep = keepalive.EnforcementPolicy{ kaep = keepalive.EnforcementPolicy{
MinTime: 15 * time.Second, MinTime: 15 * time.Second,
@ -101,9 +103,9 @@ var (
// detect whether user specified a port // detect whether user specified a port
userPort := cmd.Flag("port").Changed userPort := cmd.Flag("port").Changed
config, err = loadMgmtConfig(ctx, mgmtConfig) config, err = loadMgmtConfig(ctx, types.MgmtConfigPath)
if err != nil { if err != nil {
return fmt.Errorf("failed reading provided config file: %s: %v", mgmtConfig, err) return fmt.Errorf("failed reading provided config file: %s: %v", types.MgmtConfigPath, err)
} }
if cmd.Flag(idpSignKeyRefreshEnabledFlagName).Changed { if cmd.Flag(idpSignKeyRefreshEnabledFlagName).Changed {
@ -183,7 +185,7 @@ var (
if config.DataStoreEncryptionKey != key { if config.DataStoreEncryptionKey != key {
log.WithContext(ctx).Infof("update config with activity store key") log.WithContext(ctx).Infof("update config with activity store key")
config.DataStoreEncryptionKey = key config.DataStoreEncryptionKey = key
err := updateMgmtConfig(ctx, mgmtConfig, config) err := updateMgmtConfig(ctx, types.MgmtConfigPath, config)
if err != nil { if err != nil {
return fmt.Errorf("failed to write out store encryption key: %s", err) return fmt.Errorf("failed to write out store encryption key: %s", err)
} }
@ -486,8 +488,8 @@ func handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler) http.Handle
}) })
} }
func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*server.Config, error) { func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*types.Config, error) {
loadedConfig := &server.Config{} loadedConfig := &types.Config{}
_, err := util.ReadJsonWithEnvSub(mgmtConfigPath, loadedConfig) _, err := util.ReadJsonWithEnvSub(mgmtConfigPath, loadedConfig)
if err != nil { if err != nil {
return nil, err return nil, err
@ -522,7 +524,7 @@ func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*server.Config,
oidcConfig.JwksURI, loadedConfig.HttpConfig.AuthKeysLocation) oidcConfig.JwksURI, loadedConfig.HttpConfig.AuthKeysLocation)
loadedConfig.HttpConfig.AuthKeysLocation = oidcConfig.JwksURI loadedConfig.HttpConfig.AuthKeysLocation = oidcConfig.JwksURI
if !(loadedConfig.DeviceAuthorizationFlow == nil || strings.ToLower(loadedConfig.DeviceAuthorizationFlow.Provider) == string(server.NONE)) { if !(loadedConfig.DeviceAuthorizationFlow == nil || strings.ToLower(loadedConfig.DeviceAuthorizationFlow.Provider) == string(types.NONE)) {
log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s", log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s",
oidcConfig.TokenEndpoint, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint) oidcConfig.TokenEndpoint, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint)
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint = oidcConfig.TokenEndpoint loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint = oidcConfig.TokenEndpoint
@ -539,7 +541,7 @@ func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*server.Config,
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Domain = u.Host loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Domain = u.Host
if loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Scope == "" { if loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Scope == "" {
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Scope = server.DefaultDeviceAuthFlowScope loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Scope = types.DefaultDeviceAuthFlowScope
} }
} }
@ -560,7 +562,7 @@ func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*server.Config,
return loadedConfig, err return loadedConfig, err
} }
func updateMgmtConfig(ctx context.Context, path string, config *server.Config) error { func updateMgmtConfig(ctx context.Context, path string, config *types.Config) error {
return util.DirectWriteJson(ctx, path, config) return util.DirectWriteJson(ctx, path, config)
} }
@ -636,7 +638,7 @@ func handleRebrand(cmd *cobra.Command) error {
} }
} }
} }
if mgmtConfig == defaultMgmtConfig { if types.MgmtConfigPath == defaultMgmtConfig {
if migrateToNetbird(oldDefaultMgmtConfig, defaultMgmtConfig) { if migrateToNetbird(oldDefaultMgmtConfig, defaultMgmtConfig) {
cmd.Printf("will copy Config dir %s and its content to %s\n", oldDefaultMgmtConfigDir, defaultMgmtConfigDir) cmd.Printf("will copy Config dir %s and its content to %s\n", oldDefaultMgmtConfigDir, defaultMgmtConfigDir)
err = cpDir(oldDefaultMgmtConfigDir, defaultMgmtConfigDir) err = cpDir(oldDefaultMgmtConfigDir, defaultMgmtConfigDir)

View File

@ -7,6 +7,7 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )
@ -19,7 +20,6 @@ const (
var ( var (
dnsDomain string dnsDomain string
mgmtDataDir string mgmtDataDir string
mgmtConfig string
logLevel string logLevel string
logFile string logFile string
disableMetrics bool disableMetrics bool
@ -56,7 +56,7 @@ func init() {
mgmtCmd.Flags().IntVar(&mgmtPort, "port", 80, "server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise") mgmtCmd.Flags().IntVar(&mgmtPort, "port", 80, "server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise")
mgmtCmd.Flags().IntVar(&mgmtMetricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics") mgmtCmd.Flags().IntVar(&mgmtMetricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
mgmtCmd.Flags().StringVar(&mgmtDataDir, "datadir", defaultMgmtDataDir, "server data directory location") mgmtCmd.Flags().StringVar(&mgmtDataDir, "datadir", defaultMgmtDataDir, "server data directory location")
mgmtCmd.Flags().StringVar(&mgmtConfig, "config", defaultMgmtConfig, "Netbird config file location. Config params specified via command line (e.g. datadir) have a precedence over configuration from this file") mgmtCmd.Flags().StringVar(&types.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location. Config params specified via command line (e.g. datadir) have a precedence over configuration from this file")
mgmtCmd.Flags().StringVar(&mgmtLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS") mgmtCmd.Flags().StringVar(&mgmtLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
mgmtCmd.Flags().StringVar(&mgmtSingleAccModeDomain, "single-account-mode-domain", defaultSingleAccModeDomain, "Enables single account mode. This means that all the users will be under the same account grouped by the specified domain. If the installation has more than one account, the property is ineffective. Enabled by default with the default domain "+defaultSingleAccModeDomain) mgmtCmd.Flags().StringVar(&mgmtSingleAccModeDomain, "single-account-mode-domain", defaultSingleAccModeDomain, "Enables single account mode. This means that all the users will be under the same account grouped by the specified domain. If the installation has more than one account, the property is ineffective. Enabled by default with the default domain "+defaultSingleAccModeDomain)
mgmtCmd.Flags().BoolVar(&disableSingleAccMode, "disable-single-account-mode", false, "If set to true, disables single account mode. The --single-account-mode-domain property will be ignored and every new user will have a separate NetBird account.") mgmtCmd.Flags().BoolVar(&disableSingleAccMode, "disable-single-account-mode", false, "If set to true, disables single account mode. The --single-account-mode-domain property will be ignored and every new user will have a separate NetBird account.")

View File

@ -2793,15 +2793,15 @@ func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) {
}) })
} }
type TB interface { //type TB interface {
Cleanup(func()) // Cleanup(func())
Helper() // Helper()
TempDir() string // TempDir() string
Errorf(format string, args ...interface{}) // Errorf(format string, args ...interface{})
Fatalf(format string, args ...interface{}) // Fatalf(format string, args ...interface{})
} //}
func createManager(t TB) (*DefaultAccountManager, error) { func createManager(t testing.TB) (*DefaultAccountManager, error) {
t.Helper() t.Helper()
store, err := createStore(t) store, err := createStore(t)
@ -2836,7 +2836,7 @@ func createManager(t TB) (*DefaultAccountManager, error) {
return manager, nil return manager, nil
} }
func createStore(t TB) (store.Store, error) { func createStore(t testing.TB) (store.Store, error) {
t.Helper() t.Helper()
dataDir := t.TempDir() dataDir := t.TempDir()
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", dataDir) store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", dataDir)

View File

@ -19,6 +19,7 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
@ -40,7 +41,7 @@ type GRPCServer struct {
wgKey wgtypes.Key wgKey wgtypes.Key
proto.UnimplementedManagementServiceServer proto.UnimplementedManagementServiceServer
peersUpdateManager *PeersUpdateManager peersUpdateManager *PeersUpdateManager
config *Config config *types.Config
secretsManager SecretsManager secretsManager SecretsManager
appMetrics telemetry.AppMetrics appMetrics telemetry.AppMetrics
ephemeralManager *EphemeralManager ephemeralManager *EphemeralManager
@ -51,7 +52,7 @@ type GRPCServer struct {
// NewServer creates a new Management server // NewServer creates a new Management server
func NewServer( func NewServer(
ctx context.Context, ctx context.Context,
config *Config, config *types.Config,
accountManager account.Manager, accountManager account.Manager,
settingsManager settings.Manager, settingsManager settings.Manager,
peersUpdateManager *PeersUpdateManager, peersUpdateManager *PeersUpdateManager,
@ -530,24 +531,24 @@ func (s *GRPCServer) processJwtToken(ctx context.Context, loginReq *proto.LoginR
return userID, nil return userID, nil
} }
func ToResponseProto(configProto Protocol) proto.HostConfig_Protocol { func ToResponseProto(configProto types.Protocol) proto.HostConfig_Protocol {
switch configProto { switch configProto {
case UDP: case types.UDP:
return proto.HostConfig_UDP return proto.HostConfig_UDP
case DTLS: case types.DTLS:
return proto.HostConfig_DTLS return proto.HostConfig_DTLS
case HTTP: case types.HTTP:
return proto.HostConfig_HTTP return proto.HostConfig_HTTP
case HTTPS: case types.HTTPS:
return proto.HostConfig_HTTPS return proto.HostConfig_HTTPS
case TCP: case types.TCP:
return proto.HostConfig_TCP return proto.HostConfig_TCP
default: default:
panic(fmt.Errorf("unexpected config protocol type %v", configProto)) panic(fmt.Errorf("unexpected config protocol type %v", configProto))
} }
} }
func toNetbirdConfig(config *Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig { func toNetbirdConfig(config *types.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
if config == nil { if config == nil {
return nil return nil
} }
@ -610,8 +611,6 @@ func toNetbirdConfig(config *Config, turnCredentials *Token, relayToken *Token,
Relay: relayCfg, Relay: relayCfg,
} }
integrationsConfig.ExtendNetBirdConfig(nbConfig, extraSettings)
return nbConfig return nbConfig
} }
@ -626,9 +625,8 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, dns
} }
} }
func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, dnsResolutionOnRoutingPeerEnabled bool, extraSettings *types.ExtraSettings) *proto.SyncResponse { func toSyncResponse(ctx context.Context, config *types.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, dnsResolutionOnRoutingPeerEnabled bool, extraSettings *types.ExtraSettings) *proto.SyncResponse {
response := &proto.SyncResponse{ response := &proto.SyncResponse{
NetbirdConfig: toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings),
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, dnsResolutionOnRoutingPeerEnabled), PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, dnsResolutionOnRoutingPeerEnabled),
NetworkMap: &proto.NetworkMap{ NetworkMap: &proto.NetworkMap{
Serial: networkMap.Network.CurrentSerial(), Serial: networkMap.Network.CurrentSerial(),
@ -638,6 +636,10 @@ func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turn
Checks: toProtocolChecks(ctx, checks), Checks: toProtocolChecks(ctx, checks),
} }
nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings)
integrationsConfig.ExtendNetBirdConfig(peer.ID, nbConfig, extraSettings)
response.NetbirdConfig = nbConfig
response.NetworkMap.PeerConfig = response.PeerConfig response.NetworkMap.PeerConfig = response.PeerConfig
allPeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers)) allPeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
@ -754,7 +756,7 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.
return nil, status.Error(codes.InvalidArgument, errMSG) return nil, status.Error(codes.InvalidArgument, errMSG)
} }
if s.config.DeviceAuthorizationFlow == nil || s.config.DeviceAuthorizationFlow.Provider == string(NONE) { if s.config.DeviceAuthorizationFlow == nil || s.config.DeviceAuthorizationFlow.Provider == string(types.NONE) {
return nil, status.Error(codes.NotFound, "no device authorization flow information available") return nil, status.Error(codes.NotFound, "no device authorization flow information available")
} }

View File

@ -93,21 +93,21 @@ func getServerKey(client mgmtProto.ManagementServiceClient) (*wgtypes.Key, error
func Test_SyncProtocol(t *testing.T) { func Test_SyncProtocol(t *testing.T) {
dir := t.TempDir() dir := t.TempDir()
mgmtServer, _, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &Config{ mgmtServer, _, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &types.Config{
Stuns: []*Host{{ Stuns: []*types.Host{{
Proto: "udp", Proto: "udp",
URI: "stun:stun.netbird.io:3468", URI: "stun:stun.netbird.io:3468",
}}, }},
TURNConfig: &TURNConfig{ TURNConfig: &types.TURNConfig{
TimeBasedCredentials: false, TimeBasedCredentials: false,
CredentialsTTL: util.Duration{}, CredentialsTTL: util.Duration{},
Secret: "whatever", Secret: "whatever",
Turns: []*Host{{ Turns: []*types.Host{{
Proto: "udp", Proto: "udp",
URI: "turn:stun.netbird.io:3468", URI: "turn:stun.netbird.io:3468",
}}, }},
}, },
Signal: &Host{ Signal: &types.Host{
Proto: "http", Proto: "http",
URI: "signal.netbird.io:10000", URI: "signal.netbird.io:10000",
}, },
@ -330,7 +330,7 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
inputFlow *DeviceAuthorizationFlow inputFlow *types.DeviceAuthorizationFlow
expectedFlow *mgmtProto.DeviceAuthorizationFlow expectedFlow *mgmtProto.DeviceAuthorizationFlow
expectedErrFunc require.ErrorAssertionFunc expectedErrFunc require.ErrorAssertionFunc
expectedErrMSG string expectedErrMSG string
@ -345,9 +345,9 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
}, },
{ {
name: "Testing Invalid Device Flow Provider Config", name: "Testing Invalid Device Flow Provider Config",
inputFlow: &DeviceAuthorizationFlow{ inputFlow: &types.DeviceAuthorizationFlow{
Provider: "NoNe", Provider: "NoNe",
ProviderConfig: ProviderConfig{ ProviderConfig: types.ProviderConfig{
ClientID: "test", ClientID: "test",
}, },
}, },
@ -356,9 +356,9 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
}, },
{ {
name: "Testing Full Device Flow Config", name: "Testing Full Device Flow Config",
inputFlow: &DeviceAuthorizationFlow{ inputFlow: &types.DeviceAuthorizationFlow{
Provider: "hosted", Provider: "hosted",
ProviderConfig: ProviderConfig{ ProviderConfig: types.ProviderConfig{
ClientID: "test", ClientID: "test",
}, },
}, },
@ -379,7 +379,7 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
mgmtServer := &GRPCServer{ mgmtServer := &GRPCServer{
wgKey: testingServerKey, wgKey: testingServerKey,
config: &Config{ config: &types.Config{
DeviceAuthorizationFlow: testCase.inputFlow, DeviceAuthorizationFlow: testCase.inputFlow,
}, },
} }
@ -410,7 +410,7 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
} }
} }
func startManagementForTest(t *testing.T, testFile string, config *Config) (*grpc.Server, *DefaultAccountManager, string, func(), error) { func startManagementForTest(t *testing.T, testFile string, config *types.Config) (*grpc.Server, *DefaultAccountManager, string, func(), error) {
t.Helper() t.Helper()
lis, err := net.Listen("tcp", "localhost:0") lis, err := net.Listen("tcp", "localhost:0")
if err != nil { if err != nil {
@ -506,21 +506,21 @@ func testSyncStatusRace(t *testing.T) {
t.Skip() t.Skip()
dir := t.TempDir() dir := t.TempDir()
mgmtServer, am, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &Config{ mgmtServer, am, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &types.Config{
Stuns: []*Host{{ Stuns: []*types.Host{{
Proto: "udp", Proto: "udp",
URI: "stun:stun.netbird.io:3468", URI: "stun:stun.netbird.io:3468",
}}, }},
TURNConfig: &TURNConfig{ TURNConfig: &types.TURNConfig{
TimeBasedCredentials: false, TimeBasedCredentials: false,
CredentialsTTL: util.Duration{}, CredentialsTTL: util.Duration{},
Secret: "whatever", Secret: "whatever",
Turns: []*Host{{ Turns: []*types.Host{{
Proto: "udp", Proto: "udp",
URI: "turn:stun.netbird.io:3468", URI: "turn:stun.netbird.io:3468",
}}, }},
}, },
Signal: &Host{ Signal: &types.Host{
Proto: "http", Proto: "http",
URI: "signal.netbird.io:10000", URI: "signal.netbird.io:10000",
}, },
@ -678,21 +678,21 @@ func Test_LoginPerformance(t *testing.T) {
t.Helper() t.Helper()
dir := t.TempDir() dir := t.TempDir()
mgmtServer, am, _, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &Config{ mgmtServer, am, _, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &types.Config{
Stuns: []*Host{{ Stuns: []*types.Host{{
Proto: "udp", Proto: "udp",
URI: "stun:stun.netbird.io:3468", URI: "stun:stun.netbird.io:3468",
}}, }},
TURNConfig: &TURNConfig{ TURNConfig: &types.TURNConfig{
TimeBasedCredentials: false, TimeBasedCredentials: false,
CredentialsTTL: util.Duration{}, CredentialsTTL: util.Duration{},
Secret: "whatever", Secret: "whatever",
Turns: []*Host{{ Turns: []*types.Host{{
Proto: "udp", Proto: "udp",
URI: "turn:stun.netbird.io:3468", URI: "turn:stun.netbird.io:3468",
}}, }},
}, },
Signal: &Host{ Signal: &types.Host{
Proto: "http", Proto: "http",
URI: "signal.netbird.io:10000", URI: "signal.netbird.io:10000",
}, },

View File

@ -58,7 +58,7 @@ func setupTest(t *testing.T) *testSuite {
t.Fatalf("failed to create temp directory: %v", err) t.Fatalf("failed to create temp directory: %v", err)
} }
config := &server.Config{} config := &types.Config{}
_, err = util.ReadJson("testdata/management.json", config) _, err = util.ReadJson("testdata/management.json", config)
if err != nil { if err != nil {
t.Fatalf("failed to read management.json: %v", err) t.Fatalf("failed to read management.json: %v", err)
@ -156,7 +156,7 @@ func createRawClient(t *testing.T, addr string) (mgmtProto.ManagementServiceClie
func startServer( func startServer(
t *testing.T, t *testing.T,
config *server.Config, config *types.Config,
dataDir string, dataDir string,
testFile string, testFile string,
) (*grpc.Server, net.Listener) { ) (*grpc.Server, net.Listener) {
@ -300,6 +300,10 @@ func TestSyncNewPeerConfiguration(t *testing.T) {
Protocol: mgmtProto.HostConfig_UDP, Protocol: mgmtProto.HostConfig_UDP,
} }
expectedRelayHost := &mgmtProto.RelayConfig{
Urls: []string{"rel://test.com:3535"},
}
assert.NotNil(t, resp.NetbirdConfig) assert.NotNil(t, resp.NetbirdConfig)
assert.Equal(t, resp.NetbirdConfig.Signal, expectedSignalConfig) assert.Equal(t, resp.NetbirdConfig.Signal, expectedSignalConfig)
assert.Contains(t, resp.NetbirdConfig.Stuns, expectedStunsConfig) assert.Contains(t, resp.NetbirdConfig.Stuns, expectedStunsConfig)
@ -307,6 +311,8 @@ func TestSyncNewPeerConfiguration(t *testing.T) {
actualTURN := resp.NetbirdConfig.Turns[0] actualTURN := resp.NetbirdConfig.Turns[0]
assert.Greater(t, len(actualTURN.User), 0) assert.Greater(t, len(actualTURN.User), 0)
assert.Equal(t, actualTURN.HostConfig, expectedTRUNHost) assert.Equal(t, actualTURN.HostConfig, expectedTRUNHost)
assert.Equal(t, len(resp.NetbirdConfig.Relay.Urls), 1)
assert.Equal(t, resp.NetbirdConfig.Relay.Urls, expectedRelayHost.Urls)
assert.Equal(t, len(resp.NetworkMap.OfflinePeers), 0) assert.Equal(t, len(resp.NetworkMap.OfflinePeers), 0)
} }

View File

@ -15,7 +15,6 @@ import (
"github.com/hashicorp/go-version" "github.com/hashicorp/go-version"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
nbversion "github.com/netbirdio/netbird/version" nbversion "github.com/netbirdio/netbird/version"
) )
@ -49,7 +48,7 @@ type properties map[string]interface{}
// DataSource metric data source // DataSource metric data source
type DataSource interface { type DataSource interface {
GetAllAccounts(ctx context.Context) []*types.Account GetAllAccounts(ctx context.Context) []*types.Account
GetStoreEngine() store.Engine GetStoreEngine() types.Engine
} }
// ConnManager peer connection manager that holds state for current active connections // ConnManager peer connection manager that holds state for current active connections

View File

@ -10,7 +10,6 @@ import (
networkTypes "github.com/netbirdio/netbird/management/server/networks/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@ -205,8 +204,8 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
} }
// GetStoreEngine returns FileStoreEngine // GetStoreEngine returns FileStoreEngine
func (mockDatasource) GetStoreEngine() store.Engine { func (mockDatasource) GetStoreEngine() types.Engine {
return store.FileStoreEngine return types.FileStoreEngine
} }
// TestGenerateProperties tests and validate the properties generation by using the mockDatasource for the Worker.generateProperties // TestGenerateProperties tests and validate the properties generation by using the mockDatasource for the Worker.generateProperties
@ -304,7 +303,7 @@ func TestGenerateProperties(t *testing.T) {
t.Errorf("expected 2 user_peers, got %d", properties["user_peers"]) t.Errorf("expected 2 user_peers, got %d", properties["user_peers"])
} }
if properties["store_engine"] != store.FileStoreEngine { if properties["store_engine"] != types.FileStoreEngine {
t.Errorf("expected JsonFile, got %s", properties["store_engine"]) t.Errorf("expected JsonFile, got %s", properties["store_engine"])
} }

View File

@ -1213,7 +1213,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
return return
} }
update := toSyncResponse(ctx, &Config{}, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled, extraSetting) update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled, extraSetting)
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
}(peer) }(peer)
} }
@ -1282,7 +1282,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI
return return
} }
update := toSyncResponse(ctx, &Config{}, peer, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled, extraSettings) update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled, extraSettings)
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
} }

View File

@ -723,7 +723,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
} }
} }
func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccountManager, string, string, error) { func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccountManager, string, string, error) {
b.Helper() b.Helper()
manager, err := createManager(b) manager, err := createManager(b)
@ -998,6 +998,53 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
} }
} }
func TestUpdateAccountPeers(t *testing.T) {
testCases := []struct {
name string
peers int
groups int
}{
{"Small", 50, 1},
{"Medium", 500, 1},
{"Large", 1000, 1},
}
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
manager, accountID, _, err := setupTestAccountManager(t, tc.peers, tc.groups)
if err != nil {
t.Fatalf("Failed to setup test account manager: %v", err)
}
ctx := context.Background()
account, err := manager.Store.GetAccount(ctx, accountID)
if err != nil {
t.Fatalf("Failed to get account: %v", err)
}
peerChannels := make(map[string]chan *UpdateMessage)
for peerID := range account.Peers {
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
}
manager.peersUpdateManager.peerChannels = peerChannels
manager.UpdateAccountPeers(ctx, account.Id)
for _, channel := range peerChannels {
update := <-channel
assert.Nil(t, update.Update.NetbirdConfig)
assert.Equal(t, tc.peers, len(update.NetworkMap.Peers))
assert.Equal(t, tc.peers*2, len(update.NetworkMap.FirewallRules))
}
})
}
}
func TestToSyncResponse(t *testing.T) { func TestToSyncResponse(t *testing.T) {
_, ipnet, err := net.ParseCIDR("192.168.1.0/24") _, ipnet, err := net.ParseCIDR("192.168.1.0/24")
if err != nil { if err != nil {
@ -1008,16 +1055,16 @@ func TestToSyncResponse(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
config := &Config{ config := &types.Config{
Signal: &Host{ Signal: &types.Host{
Proto: "https", Proto: "https",
URI: "signal.uri", URI: "signal.uri",
Username: "", Username: "",
Password: "", Password: "",
}, },
Stuns: []*Host{{URI: "stun.uri", Proto: UDP}}, Stuns: []*types.Host{{URI: "stun.uri", Proto: types.UDP}},
TURNConfig: &TURNConfig{ TURNConfig: &types.TURNConfig{
Turns: []*Host{{URI: "turn.uri", Proto: UDP, Username: "turn-user", Password: "turn-pass"}}, Turns: []*types.Host{{URI: "turn.uri", Proto: types.UDP, Username: "turn-user", Password: "turn-pass"}},
}, },
} }
peer := &nbpeer.Peer{ peer := &nbpeer.Peer{

View File

@ -260,6 +260,6 @@ func (s *FileStore) Close(ctx context.Context) error {
} }
// GetStoreEngine returns FileStoreEngine // GetStoreEngine returns FileStoreEngine
func (s *FileStore) GetStoreEngine() Engine { func (s *FileStore) GetStoreEngine() types.Engine {
return FileStoreEngine return types.FileStoreEngine
} }

View File

@ -55,7 +55,7 @@ type SqlStore struct {
globalAccountLock sync.Mutex globalAccountLock sync.Mutex
metrics telemetry.AppMetrics metrics telemetry.AppMetrics
installationPK int installationPK int
storeEngine Engine storeEngine types.Engine
} }
type installation struct { type installation struct {
@ -66,7 +66,7 @@ type installation struct {
type migrationFunc func(*gorm.DB) error type migrationFunc func(*gorm.DB) error
// NewSqlStore creates a new SqlStore instance. // NewSqlStore creates a new SqlStore instance.
func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine Engine, metrics telemetry.AppMetrics) (*SqlStore, error) { func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, metrics telemetry.AppMetrics) (*SqlStore, error) {
sql, err := db.DB() sql, err := db.DB()
if err != nil { if err != nil {
return nil, err return nil, err
@ -77,7 +77,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine Engine, metrics t
conns = runtime.NumCPU() conns = runtime.NumCPU()
} }
if storeEngine == SqliteStoreEngine { if storeEngine == types.SqliteStoreEngine {
if err == nil { if err == nil {
log.WithContext(ctx).Warnf("setting NB_SQL_MAX_OPEN_CONNS is not supported for sqlite, using default value 1") log.WithContext(ctx).Warnf("setting NB_SQL_MAX_OPEN_CONNS is not supported for sqlite, using default value 1")
} }
@ -105,7 +105,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine Engine, metrics t
} }
func GetKeyQueryCondition(s *SqlStore) string { func GetKeyQueryCondition(s *SqlStore) string {
if s.storeEngine == MysqlStoreEngine { if s.storeEngine == types.MysqlStoreEngine {
return mysqlKeyQueryCondition return mysqlKeyQueryCondition
} }
return keyQueryCondition return keyQueryCondition
@ -970,7 +970,7 @@ func (s *SqlStore) Close(_ context.Context) error {
} }
// GetStoreEngine returns underlying store engine // GetStoreEngine returns underlying store engine
func (s *SqlStore) GetStoreEngine() Engine { func (s *SqlStore) GetStoreEngine() types.Engine {
return s.storeEngine return s.storeEngine
} }
@ -988,7 +988,7 @@ func NewSqliteStore(ctx context.Context, dataDir string, metrics telemetry.AppMe
return nil, err return nil, err
} }
return NewSqlStore(ctx, db, SqliteStoreEngine, metrics) return NewSqlStore(ctx, db, types.SqliteStoreEngine, metrics)
} }
// NewPostgresqlStore creates a new Postgres store. // NewPostgresqlStore creates a new Postgres store.
@ -998,7 +998,7 @@ func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMe
return nil, err return nil, err
} }
return NewSqlStore(ctx, db, PostgresStoreEngine, metrics) return NewSqlStore(ctx, db, types.PostgresStoreEngine, metrics)
} }
// NewMysqlStore creates a new MySQL store. // NewMysqlStore creates a new MySQL store.
@ -1008,7 +1008,7 @@ func NewMysqlStore(ctx context.Context, dsn string, metrics telemetry.AppMetrics
return nil, err return nil, err
} }
return NewSqlStore(ctx, db, MysqlStoreEngine, metrics) return NewSqlStore(ctx, db, types.MysqlStoreEngine, metrics)
} }
func getGormConfig() *gorm.Config { func getGormConfig() *gorm.Config {
@ -1517,9 +1517,9 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
query := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations) query := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations)
switch s.storeEngine { switch s.storeEngine {
case PostgresStoreEngine: case types.PostgresStoreEngine:
query = query.Order("json_array_length(peers::json) DESC") query = query.Order("json_array_length(peers::json) DESC")
case MysqlStoreEngine: case types.MysqlStoreEngine:
query = query.Order("JSON_LENGTH(JSON_EXTRACT(peers, \"$\")) DESC") query = query.Order("JSON_LENGTH(JSON_EXTRACT(peers, \"$\")) DESC")
default: default:
query = query.Order("json_array_length(peers) DESC") query = query.Order("json_array_length(peers) DESC")

View File

@ -297,7 +297,7 @@ func TestSqlite_DeleteAccount(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet") t.Skip("The SQLite store is not properly supported by Windows yet")
} }
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
assert.NoError(t, err) assert.NoError(t, err)
@ -628,7 +628,7 @@ func TestMigrate(t *testing.T) {
} }
// TODO: figure out why this fails on postgres // TODO: figure out why this fails on postgres
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
@ -737,7 +737,7 @@ func TestPostgresql_NewStore(t *testing.T) {
t.Skip("skip CI tests on darwin and windows") t.Skip("skip CI tests on darwin and windows")
} }
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
assert.NoError(t, err) assert.NoError(t, err)
@ -752,7 +752,7 @@ func TestPostgresql_SaveAccount(t *testing.T) {
t.Skip("skip CI tests on darwin and windows") t.Skip("skip CI tests on darwin and windows")
} }
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
assert.NoError(t, err) assert.NoError(t, err)
@ -825,7 +825,7 @@ func TestPostgresql_DeleteAccount(t *testing.T) {
t.Skip("skip CI tests on darwin and windows") t.Skip("skip CI tests on darwin and windows")
} }
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
assert.NoError(t, err) assert.NoError(t, err)
@ -901,7 +901,7 @@ func TestPostgresql_TestGetAccountByPrivateDomain(t *testing.T) {
t.Skip("skip CI tests on darwin and windows") t.Skip("skip CI tests on darwin and windows")
} }
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
assert.NoError(t, err) assert.NoError(t, err)
@ -921,7 +921,7 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) {
t.Skip("skip CI tests on darwin and windows") t.Skip("skip CI tests on darwin and windows")
} }
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
assert.NoError(t, err) assert.NoError(t, err)
@ -935,7 +935,7 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) {
} }
func TestSqlite_GetTakenIPs(t *testing.T) { func TestSqlite_GetTakenIPs(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
defer cleanup() defer cleanup()
if err != nil { if err != nil {
@ -980,7 +980,7 @@ func TestSqlite_GetTakenIPs(t *testing.T) {
} }
func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
if err != nil { if err != nil {
return return
@ -1022,7 +1022,7 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
} }
func TestSqlite_GetAccountNetwork(t *testing.T) { func TestSqlite_GetAccountNetwork(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup) t.Cleanup(cleanup)
if err != nil { if err != nil {
@ -1045,7 +1045,7 @@ func TestSqlite_GetAccountNetwork(t *testing.T) {
} }
func TestSqlite_GetSetupKeyBySecret(t *testing.T) { func TestSqlite_GetSetupKeyBySecret(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup) t.Cleanup(cleanup)
if err != nil { if err != nil {
@ -1070,7 +1070,7 @@ func TestSqlite_GetSetupKeyBySecret(t *testing.T) {
} }
func TestSqlite_incrementSetupKeyUsage(t *testing.T) { func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup) t.Cleanup(cleanup)
if err != nil { if err != nil {
@ -1106,7 +1106,7 @@ func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
} }
func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) { func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup) t.Cleanup(cleanup)
if err != nil { if err != nil {
@ -1211,7 +1211,7 @@ func TestSqlite_GetGroupByName(t *testing.T) {
} }
func Test_DeleteSetupKeySuccessfully(t *testing.T) { func Test_DeleteSetupKeySuccessfully(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup) t.Cleanup(cleanup)
require.NoError(t, err) require.NoError(t, err)
@ -1227,7 +1227,7 @@ func Test_DeleteSetupKeySuccessfully(t *testing.T) {
} }
func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) { func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup) t.Cleanup(cleanup)
require.NoError(t, err) require.NoError(t, err)

View File

@ -164,7 +164,7 @@ type Store interface {
Close(ctx context.Context) error Close(ctx context.Context) error
// GetStoreEngine should return Engine of the current store implementation. // GetStoreEngine should return Engine of the current store implementation.
// This is also a method of metrics.DataSource interface. // This is also a method of metrics.DataSource interface.
GetStoreEngine() Engine GetStoreEngine() types.Engine
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error) GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error)
@ -187,44 +187,37 @@ type Store interface {
GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error) GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error)
} }
type Engine string
const ( const (
FileStoreEngine Engine = "jsonfile"
SqliteStoreEngine Engine = "sqlite"
PostgresStoreEngine Engine = "postgres"
MysqlStoreEngine Engine = "mysql"
postgresDsnEnv = "NETBIRD_STORE_ENGINE_POSTGRES_DSN" postgresDsnEnv = "NETBIRD_STORE_ENGINE_POSTGRES_DSN"
mysqlDsnEnv = "NETBIRD_STORE_ENGINE_MYSQL_DSN" mysqlDsnEnv = "NETBIRD_STORE_ENGINE_MYSQL_DSN"
) )
var supportedEngines = []Engine{SqliteStoreEngine, PostgresStoreEngine, MysqlStoreEngine} var supportedEngines = []types.Engine{types.SqliteStoreEngine, types.PostgresStoreEngine, types.MysqlStoreEngine}
func getStoreEngineFromEnv() Engine { func getStoreEngineFromEnv() types.Engine {
// NETBIRD_STORE_ENGINE supposed to be used in tests. Otherwise, rely on the config file. // NETBIRD_STORE_ENGINE supposed to be used in tests. Otherwise, rely on the config file.
kind, ok := os.LookupEnv("NETBIRD_STORE_ENGINE") kind, ok := os.LookupEnv("NETBIRD_STORE_ENGINE")
if !ok { if !ok {
return "" return ""
} }
value := Engine(strings.ToLower(kind)) value := types.Engine(strings.ToLower(kind))
if slices.Contains(supportedEngines, value) { if slices.Contains(supportedEngines, value) {
return value return value
} }
return SqliteStoreEngine return types.SqliteStoreEngine
} }
// getStoreEngine determines the store engine to use. // getStoreEngine determines the store engine to use.
// If no engine is specified, it attempts to retrieve it from the environment. // If no engine is specified, it attempts to retrieve it from the environment.
// If still not specified, it defaults to using SQLite. // If still not specified, it defaults to using SQLite.
// Additionally, it handles the migration from a JSON store file to SQLite if applicable. // Additionally, it handles the migration from a JSON store file to SQLite if applicable.
func getStoreEngine(ctx context.Context, dataDir string, kind Engine) Engine { func getStoreEngine(ctx context.Context, dataDir string, kind types.Engine) types.Engine {
if kind == "" { if kind == "" {
kind = getStoreEngineFromEnv() kind = getStoreEngineFromEnv()
if kind == "" { if kind == "" {
kind = SqliteStoreEngine kind = types.SqliteStoreEngine
// Migrate if it is the first run with a JSON file existing and no SQLite file present // Migrate if it is the first run with a JSON file existing and no SQLite file present
jsonStoreFile := filepath.Join(dataDir, storeFileName) jsonStoreFile := filepath.Join(dataDir, storeFileName)
@ -236,7 +229,7 @@ func getStoreEngine(ctx context.Context, dataDir string, kind Engine) Engine {
// Attempt to migrate from JSON store to SQLite // Attempt to migrate from JSON store to SQLite
if err := MigrateFileStoreToSqlite(ctx, dataDir); err != nil { if err := MigrateFileStoreToSqlite(ctx, dataDir); err != nil {
log.WithContext(ctx).Errorf("failed to migrate filestore to SQLite: %v", err) log.WithContext(ctx).Errorf("failed to migrate filestore to SQLite: %v", err)
kind = FileStoreEngine kind = types.FileStoreEngine
} }
} }
} }
@ -246,7 +239,7 @@ func getStoreEngine(ctx context.Context, dataDir string, kind Engine) Engine {
} }
// NewStore creates a new store based on the provided engine type, data directory, and telemetry metrics // NewStore creates a new store based on the provided engine type, data directory, and telemetry metrics
func NewStore(ctx context.Context, kind Engine, dataDir string, metrics telemetry.AppMetrics) (Store, error) { func NewStore(ctx context.Context, kind types.Engine, dataDir string, metrics telemetry.AppMetrics) (Store, error) {
kind = getStoreEngine(ctx, dataDir, kind) kind = getStoreEngine(ctx, dataDir, kind)
if err := checkFileStoreEngine(kind, dataDir); err != nil { if err := checkFileStoreEngine(kind, dataDir); err != nil {
@ -254,13 +247,13 @@ func NewStore(ctx context.Context, kind Engine, dataDir string, metrics telemetr
} }
switch kind { switch kind {
case SqliteStoreEngine: case types.SqliteStoreEngine:
log.WithContext(ctx).Info("using SQLite store engine") log.WithContext(ctx).Info("using SQLite store engine")
return NewSqliteStore(ctx, dataDir, metrics) return NewSqliteStore(ctx, dataDir, metrics)
case PostgresStoreEngine: case types.PostgresStoreEngine:
log.WithContext(ctx).Info("using Postgres store engine") log.WithContext(ctx).Info("using Postgres store engine")
return newPostgresStore(ctx, metrics) return newPostgresStore(ctx, metrics)
case MysqlStoreEngine: case types.MysqlStoreEngine:
log.WithContext(ctx).Info("using MySQL store engine") log.WithContext(ctx).Info("using MySQL store engine")
return newMysqlStore(ctx, metrics) return newMysqlStore(ctx, metrics)
default: default:
@ -268,12 +261,12 @@ func NewStore(ctx context.Context, kind Engine, dataDir string, metrics telemetr
} }
} }
func checkFileStoreEngine(kind Engine, dataDir string) error { func checkFileStoreEngine(kind types.Engine, dataDir string) error {
if kind == FileStoreEngine { if kind == types.FileStoreEngine {
storeFile := filepath.Join(dataDir, storeFileName) storeFile := filepath.Join(dataDir, storeFileName)
if util.FileExists(storeFile) { if util.FileExists(storeFile) {
return fmt.Errorf("%s is not supported. Please refer to the documentation for migrating to SQLite: "+ return fmt.Errorf("%s is not supported. Please refer to the documentation for migrating to SQLite: "+
"https://docs.netbird.io/selfhosted/sqlite-store#migrating-from-json-store-to-sq-lite-store", FileStoreEngine) "https://docs.netbird.io/selfhosted/sqlite-store#migrating-from-json-store-to-sq-lite-store", types.FileStoreEngine)
} }
} }
return nil return nil
@ -326,7 +319,7 @@ func getMigrations(ctx context.Context) []migrationFunc {
func NewTestStoreFromSQL(ctx context.Context, filename string, dataDir string) (Store, func(), error) { func NewTestStoreFromSQL(ctx context.Context, filename string, dataDir string) (Store, func(), error) {
kind := getStoreEngineFromEnv() kind := getStoreEngineFromEnv()
if kind == "" { if kind == "" {
kind = SqliteStoreEngine kind = types.SqliteStoreEngine
} }
storeStr := fmt.Sprintf("%s?cache=shared", storeSqliteFileName) storeStr := fmt.Sprintf("%s?cache=shared", storeSqliteFileName)
@ -348,7 +341,7 @@ func NewTestStoreFromSQL(ctx context.Context, filename string, dataDir string) (
} }
} }
store, err := NewSqlStore(ctx, db, SqliteStoreEngine, nil) store, err := NewSqlStore(ctx, db, types.SqliteStoreEngine, nil)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("failed to create test store: %v", err) return nil, nil, fmt.Errorf("failed to create test store: %v", err)
} }
@ -394,13 +387,13 @@ func addAllGroupToAccount(ctx context.Context, store Store) error {
return nil return nil
} }
func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind Engine) (Store, func(), error) { func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind types.Engine) (Store, func(), error) {
var cleanup func() var cleanup func()
var err error var err error
switch kind { switch kind {
case PostgresStoreEngine: case types.PostgresStoreEngine:
store, cleanup, err = newReusedPostgresStore(ctx, store, kind) store, cleanup, err = newReusedPostgresStore(ctx, store, kind)
case MysqlStoreEngine: case types.MysqlStoreEngine:
store, cleanup, err = newReusedMysqlStore(ctx, store, kind) store, cleanup, err = newReusedMysqlStore(ctx, store, kind)
default: default:
cleanup = func() { cleanup = func() {
@ -419,7 +412,7 @@ func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind Engine) (Store
return store, closeConnection, nil return store, closeConnection, nil
} }
func newReusedPostgresStore(ctx context.Context, store *SqlStore, kind Engine) (*SqlStore, func(), error) { func newReusedPostgresStore(ctx context.Context, store *SqlStore, kind types.Engine) (*SqlStore, func(), error) {
if envDsn, ok := os.LookupEnv(postgresDsnEnv); !ok || envDsn == "" { if envDsn, ok := os.LookupEnv(postgresDsnEnv); !ok || envDsn == "" {
var err error var err error
_, err = testutil.CreatePostgresTestContainer() _, err = testutil.CreatePostgresTestContainer()
@ -451,7 +444,7 @@ func newReusedPostgresStore(ctx context.Context, store *SqlStore, kind Engine) (
return store, cleanup, nil return store, cleanup, nil
} }
func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind Engine) (*SqlStore, func(), error) { func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind types.Engine) (*SqlStore, func(), error) {
if envDsn, ok := os.LookupEnv(mysqlDsnEnv); !ok || envDsn == "" { if envDsn, ok := os.LookupEnv(mysqlDsnEnv); !ok || envDsn == "" {
var err error var err error
_, err = testutil.CreateMysqlTestContainer() _, err = testutil.CreateMysqlTestContainer()
@ -483,7 +476,7 @@ func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind Engine) (*Sq
return store, cleanup, nil return store, cleanup, nil
} }
func createRandomDB(dsn string, db *gorm.DB, engine Engine) (string, func(), error) { func createRandomDB(dsn string, db *gorm.DB, engine types.Engine) (string, func(), error) {
dbName := fmt.Sprintf("test_db_%s", strings.ReplaceAll(uuid.New().String(), "-", "_")) dbName := fmt.Sprintf("test_db_%s", strings.ReplaceAll(uuid.New().String(), "-", "_"))
if err := db.Exec(fmt.Sprintf("CREATE DATABASE %s", dbName)).Error; err != nil { if err := db.Exec(fmt.Sprintf("CREATE DATABASE %s", dbName)).Error; err != nil {
@ -493,9 +486,9 @@ func createRandomDB(dsn string, db *gorm.DB, engine Engine) (string, func(), err
var err error var err error
cleanup := func() { cleanup := func() {
switch engine { switch engine {
case PostgresStoreEngine: case types.PostgresStoreEngine:
err = db.Exec(fmt.Sprintf("DROP DATABASE %s WITH (FORCE)", dbName)).Error err = db.Exec(fmt.Sprintf("DROP DATABASE %s WITH (FORCE)", dbName)).Error
case MysqlStoreEngine: case types.MysqlStoreEngine:
// err = killMySQLConnections(dsn, dbName) // err = killMySQLConnections(dsn, dbName)
err = db.Exec(fmt.Sprintf("DROP DATABASE %s", dbName)).Error err = db.Exec(fmt.Sprintf("DROP DATABASE %s", dbName)).Error
} }

View File

@ -19,6 +19,11 @@
"CredentialsTTL": "1h", "CredentialsTTL": "1h",
"Secret": "c29tZV9wYXNzd29yZA==", "Secret": "c29tZV9wYXNzd29yZA==",
"TimeBasedCredentials": true "TimeBasedCredentials": true
},
"Relay":{
"Addresses":["rel://test.com:3535"],
"CredentialsTTL":"2h",
"Secret":"netbird"
}, },
"Signal": { "Signal": {
"Proto": "http", "Proto": "http",

View File

@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/types"
auth "github.com/netbirdio/netbird/relay/auth/hmac" auth "github.com/netbirdio/netbird/relay/auth/hmac"
authv2 "github.com/netbirdio/netbird/relay/auth/hmac/v2" authv2 "github.com/netbirdio/netbird/relay/auth/hmac/v2"
@ -32,8 +33,8 @@ type SecretsManager interface {
// TimeBasedAuthSecretsManager generates credentials with TTL and using pre-shared secret known to TURN server // TimeBasedAuthSecretsManager generates credentials with TTL and using pre-shared secret known to TURN server
type TimeBasedAuthSecretsManager struct { type TimeBasedAuthSecretsManager struct {
mux sync.Mutex mux sync.Mutex
turnCfg *TURNConfig turnCfg *types.TURNConfig
relayCfg *Relay relayCfg *types.Relay
turnHmacToken *auth.TimedHMAC turnHmacToken *auth.TimedHMAC
relayHmacToken *authv2.Generator relayHmacToken *authv2.Generator
updateManager *PeersUpdateManager updateManager *PeersUpdateManager
@ -44,7 +45,7 @@ type TimeBasedAuthSecretsManager struct {
type Token auth.Token type Token auth.Token
func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *TURNConfig, relayCfg *Relay, settingsManager settings.Manager) *TimeBasedAuthSecretsManager { func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *types.TURNConfig, relayCfg *types.Relay, settingsManager settings.Manager) *TimeBasedAuthSecretsManager {
mgr := &TimeBasedAuthSecretsManager{ mgr := &TimeBasedAuthSecretsManager{
updateManager: updateManager, updateManager: updateManager,
turnCfg: turnCfg, turnCfg: turnCfg,
@ -221,7 +222,7 @@ func (m *TimeBasedAuthSecretsManager) pushNewTURNAndRelayTokens(ctx context.Cont
} }
} }
m.extendNetbirdConfig(ctx, accountID, update) m.extendNetbirdConfig(ctx, peerID, accountID, update)
log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID) log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID)
m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update}) m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update})
@ -245,17 +246,17 @@ func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, ac
}, },
} }
m.extendNetbirdConfig(ctx, accountID, update) m.extendNetbirdConfig(ctx, peerID, accountID, update)
log.WithContext(ctx).Debugf("sending new relay credentials to peer %s", peerID) log.WithContext(ctx).Debugf("sending new relay credentials to peer %s", peerID)
m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update}) m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update})
} }
func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, accountID string, update *proto.SyncResponse) { func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, peerID, accountID string, update *proto.SyncResponse) {
extraSettings, err := m.settingsManager.GetExtraSettings(ctx, accountID) extraSettings, err := m.settingsManager.GetExtraSettings(ctx, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to get extra settings: %v", err) log.WithContext(ctx).Errorf("failed to get extra settings: %v", err)
} }
integrationsConfig.ExtendNetBirdConfig(update.NetbirdConfig, extraSettings) integrationsConfig.ExtendNetBirdConfig(peerID, update.NetbirdConfig, extraSettings)
} }

View File

@ -19,8 +19,8 @@ import (
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
var TurnTestHost = &Host{ var TurnTestHost = &types.Host{
Proto: UDP, Proto: types.UDP,
URI: "turn:turn.netbird.io:77777", URI: "turn:turn.netbird.io:77777",
Username: "username", Username: "username",
Password: "", Password: "",
@ -31,7 +31,7 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
secret := "some_secret" secret := "some_secret"
peersManager := NewPeersUpdateManager(nil) peersManager := NewPeersUpdateManager(nil)
rc := &Relay{ rc := &types.Relay{
Addresses: []string{"localhost:0"}, Addresses: []string{"localhost:0"},
CredentialsTTL: ttl, CredentialsTTL: ttl,
Secret: secret, Secret: secret,
@ -41,10 +41,10 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
t.Cleanup(ctrl.Finish) t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl) settingsMockManager := settings.NewMockManager(ctrl)
tested := NewTimeBasedAuthSecretsManager(peersManager, &TURNConfig{ tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{
CredentialsTTL: ttl, CredentialsTTL: ttl,
Secret: secret, Secret: secret,
Turns: []*Host{TurnTestHost}, Turns: []*types.Host{TurnTestHost},
TimeBasedCredentials: true, TimeBasedCredentials: true,
}, rc, settingsMockManager) }, rc, settingsMockManager)
@ -81,7 +81,7 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
peer := "some_peer" peer := "some_peer"
updateChannel := peersManager.CreateChannel(context.Background(), peer) updateChannel := peersManager.CreateChannel(context.Background(), peer)
rc := &Relay{ rc := &types.Relay{
Addresses: []string{"localhost:0"}, Addresses: []string{"localhost:0"},
CredentialsTTL: ttl, CredentialsTTL: ttl,
Secret: secret, Secret: secret,
@ -92,10 +92,10 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
settingsMockManager := settings.NewMockManager(ctrl) settingsMockManager := settings.NewMockManager(ctrl)
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), "someAccountID").Return(&types.ExtraSettings{}, nil).AnyTimes() settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), "someAccountID").Return(&types.ExtraSettings{}, nil).AnyTimes()
tested := NewTimeBasedAuthSecretsManager(peersManager, &TURNConfig{ tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{
CredentialsTTL: ttl, CredentialsTTL: ttl,
Secret: secret, Secret: secret,
Turns: []*Host{TurnTestHost}, Turns: []*types.Host{TurnTestHost},
TimeBasedCredentials: true, TimeBasedCredentials: true,
}, rc, settingsMockManager) }, rc, settingsMockManager)
@ -184,7 +184,7 @@ func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
peersManager := NewPeersUpdateManager(nil) peersManager := NewPeersUpdateManager(nil)
peer := "some_peer" peer := "some_peer"
rc := &Relay{ rc := &types.Relay{
Addresses: []string{"localhost:0"}, Addresses: []string{"localhost:0"},
CredentialsTTL: ttl, CredentialsTTL: ttl,
Secret: secret, Secret: secret,
@ -194,10 +194,10 @@ func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
t.Cleanup(ctrl.Finish) t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl) settingsMockManager := settings.NewMockManager(ctrl)
tested := NewTimeBasedAuthSecretsManager(peersManager, &TURNConfig{ tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{
CredentialsTTL: ttl, CredentialsTTL: ttl,
Secret: secret, Secret: secret,
Turns: []*Host{TurnTestHost}, Turns: []*types.Host{TurnTestHost},
TimeBasedCredentials: true, TimeBasedCredentials: true,
}, rc, settingsMockManager) }, rc, settingsMockManager)

View File

@ -1,10 +1,9 @@
package server package types
import ( import (
"net/netip" "net/netip"
"github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@ -30,6 +29,8 @@ const (
DefaultDeviceAuthFlowScope string = "openid" DefaultDeviceAuthFlowScope string = "openid"
) )
var MgmtConfigPath string
// Config of the Management service // Config of the Management service
type Config struct { type Config struct {
Stuns []*Host Stuns []*Host
@ -76,6 +77,7 @@ type TURNConfig struct {
Turns []*Host Turns []*Host
} }
// Relay configuration type
type Relay struct { type Relay struct {
Addresses []string Addresses []string
CredentialsTTL util.Duration CredentialsTTL util.Duration
@ -156,7 +158,7 @@ type ProviderConfig struct {
// StoreConfig contains Store configuration // StoreConfig contains Store configuration
type StoreConfig struct { type StoreConfig struct {
Engine store.Engine Engine Engine
} }
// ReverseProxy contains reverse proxy configuration in front of management. // ReverseProxy contains reverse proxy configuration in front of management.

View File

@ -0,0 +1,10 @@
package types
type Engine string
const (
PostgresStoreEngine Engine = "postgres"
FileStoreEngine Engine = "jsonfile"
SqliteStoreEngine Engine = "sqlite"
MysqlStoreEngine Engine = "mysql"
)