From a4f04f557073c241a3b297d6f59be5ce6112ec73 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Thu, 27 Mar 2025 13:04:50 +0100 Subject: [PATCH] [management] fix extend call and move config to types (#3575) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR fixes configuration inconsistencies and updates the store engine type usage throughout the management code. Key changes include: - Replacing outdated server.Config references with types.Config and updating related flag variables (e.g. types.MgmtConfigPath). - Converting engine constants (SqliteStoreEngine, PostgresStoreEngine, MysqlStoreEngine) to use types.Engine for consistent type–safety. - Adjusting various test and migration code paths to correctly reference the new configuration and engine types. --- client/cmd/testutil_test.go | 5 +- client/internal/engine_test.go | 10 ++-- client/server/server_test.go | 9 +-- go.mod | 2 +- go.sum | 4 +- management/client/client_test.go | 2 +- management/cmd/management.go | 22 ++++---- management/cmd/root.go | 4 +- management/server/account_test.go | 18 +++--- management/server/grpcserver.go | 32 ++++++----- management/server/management_proto_test.go | 44 +++++++-------- management/server/management_test.go | 10 +++- management/server/metrics/selfhosted.go | 3 +- management/server/metrics/selfhosted_test.go | 7 +-- management/server/peer.go | 4 +- management/server/peer_test.go | 59 ++++++++++++++++++-- management/server/store/file_store.go | 4 +- management/server/store/sql_store.go | 20 +++---- management/server/store/sql_store_test.go | 30 +++++----- management/server/store/store.go | 57 +++++++++---------- management/server/testdata/management.json | 5 ++ management/server/token_mgr.go | 15 ++--- management/server/token_mgr_test.go | 22 ++++---- management/server/{ => types}/config.go | 8 ++- management/server/types/store.go | 10 ++++ 25 files changed, 237 insertions(+), 169 deletions(-) rename management/server/{ => types}/config.go (98%) create mode 100644 management/server/types/store.go diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index 22b982f61..bcec2472f 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -15,6 +15,7 @@ import ( "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/util" @@ -32,7 +33,7 @@ import ( func startTestingServices(t *testing.T) string { t.Helper() - config := &mgmt.Config{} + config := &types.Config{} _, err := util.ReadJson("../testdata/management.json", config) if err != nil { t.Fatal(err) @@ -67,7 +68,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) { return s, lis } -func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.Server, net.Listener) { +func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc.Server, net.Listener) { t.Helper() lis, err := net.Listen("tcp", ":0") diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 56fef43e1..72e7c6d1c 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -1400,15 +1400,15 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) { func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) { t.Helper() - config := &server.Config{ - Stuns: []*server.Host{}, - TURNConfig: &server.TURNConfig{}, - Relay: &server.Relay{ + config := &types.Config{ + Stuns: []*types.Host{}, + TURNConfig: &types.TURNConfig{}, + Relay: &types.Relay{ Addresses: []string{"127.0.0.1:1234"}, CredentialsTTL: util.Duration{Duration: time.Hour}, Secret: "222222222222222222", }, - Signal: &server.Host{ + Signal: &types.Host{ Proto: "http", URI: "localhost:10000", }, diff --git a/client/server/server_test.go b/client/server/server_test.go index 1dd5fa3c9..5083a29f2 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -24,6 +24,7 @@ import ( "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/signal/proto" signalServer "github.com/netbirdio/netbird/signal/server" ) @@ -97,10 +98,10 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve t.Helper() dataDir := t.TempDir() - config := &server.Config{ - Stuns: []*server.Host{}, - TURNConfig: &server.TURNConfig{}, - Signal: &server.Host{ + config := &types.Config{ + Stuns: []*types.Host{}, + TURNConfig: &types.TURNConfig{}, + Signal: &types.Host{ Proto: "http", URI: signalAddr, }, diff --git a/go.mod b/go.mod index f1c514f9f..83804c265 100644 --- a/go.mod +++ b/go.mod @@ -62,7 +62,7 @@ require ( github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/integrations v0.0.0-20250320152138-69b93e4ef939 + github.com/netbirdio/management-integrations/integrations v0.0.0-20250325155416-f73a616e5408 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 diff --git a/go.sum b/go.sum index fb0189709..eeffe2bc4 100644 --- a/go.sum +++ b/go.sum @@ -490,8 +490,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= -github.com/netbirdio/management-integrations/integrations v0.0.0-20250320152138-69b93e4ef939 h1:OsLDdb6ekNaCVSyD+omhio2DECfEqLjCA1zo4HrgGqU= -github.com/netbirdio/management-integrations/integrations v0.0.0-20250320152138-69b93e4ef939/go.mod h1:3LvBPnW+i06K9fQr1SYwsbhvnxQHtIC8vvO4PjLmmy0= +github.com/netbirdio/management-integrations/integrations v0.0.0-20250325155416-f73a616e5408 h1:zkMfK8AX4ZEvOypT8xbnQEJwvU6HZ4wiiTkpBFCW504= +github.com/netbirdio/management-integrations/integrations v0.0.0-20250325155416-f73a616e5408/go.mod h1:3LvBPnW+i06K9fQr1SYwsbhvnxQHtIC8vvO4PjLmmy0= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28= diff --git a/management/client/client_test.go b/management/client/client_test.go index 65237754c..24204688d 100644 --- a/management/client/client_test.go +++ b/management/client/client_test.go @@ -50,7 +50,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { level, _ := log.ParseLevel("debug") log.SetLevel(level) - config := &mgmt.Config{} + config := &types.Config{} _, err := util.ReadJson("../server/testdata/management.json", config) if err != nil { t.Fatal(err) diff --git a/management/cmd/management.go b/management/cmd/management.go index 1b2216932..f0b8d5d12 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -34,7 +34,9 @@ import ( "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/server/peers" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/formatter/hook" @@ -70,7 +72,7 @@ var ( mgmtSingleAccModeDomain string certFile string certKey string - config *server.Config + config *types.Config kaep = keepalive.EnforcementPolicy{ MinTime: 15 * time.Second, @@ -101,9 +103,9 @@ var ( // detect whether user specified a port userPort := cmd.Flag("port").Changed - config, err = loadMgmtConfig(ctx, mgmtConfig) + config, err = loadMgmtConfig(ctx, types.MgmtConfigPath) if err != nil { - return fmt.Errorf("failed reading provided config file: %s: %v", mgmtConfig, err) + return fmt.Errorf("failed reading provided config file: %s: %v", types.MgmtConfigPath, err) } if cmd.Flag(idpSignKeyRefreshEnabledFlagName).Changed { @@ -183,7 +185,7 @@ var ( if config.DataStoreEncryptionKey != key { log.WithContext(ctx).Infof("update config with activity store key") config.DataStoreEncryptionKey = key - err := updateMgmtConfig(ctx, mgmtConfig, config) + err := updateMgmtConfig(ctx, types.MgmtConfigPath, config) if err != nil { return fmt.Errorf("failed to write out store encryption key: %s", err) } @@ -486,8 +488,8 @@ func handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler) http.Handle }) } -func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*server.Config, error) { - loadedConfig := &server.Config{} +func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*types.Config, error) { + loadedConfig := &types.Config{} _, err := util.ReadJsonWithEnvSub(mgmtConfigPath, loadedConfig) if err != nil { return nil, err @@ -522,7 +524,7 @@ func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*server.Config, oidcConfig.JwksURI, loadedConfig.HttpConfig.AuthKeysLocation) loadedConfig.HttpConfig.AuthKeysLocation = oidcConfig.JwksURI - if !(loadedConfig.DeviceAuthorizationFlow == nil || strings.ToLower(loadedConfig.DeviceAuthorizationFlow.Provider) == string(server.NONE)) { + if !(loadedConfig.DeviceAuthorizationFlow == nil || strings.ToLower(loadedConfig.DeviceAuthorizationFlow.Provider) == string(types.NONE)) { log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s", oidcConfig.TokenEndpoint, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint) loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint = oidcConfig.TokenEndpoint @@ -539,7 +541,7 @@ func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*server.Config, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Domain = u.Host if loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Scope == "" { - loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Scope = server.DefaultDeviceAuthFlowScope + loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Scope = types.DefaultDeviceAuthFlowScope } } @@ -560,7 +562,7 @@ func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*server.Config, return loadedConfig, err } -func updateMgmtConfig(ctx context.Context, path string, config *server.Config) error { +func updateMgmtConfig(ctx context.Context, path string, config *types.Config) error { return util.DirectWriteJson(ctx, path, config) } @@ -636,7 +638,7 @@ func handleRebrand(cmd *cobra.Command) error { } } } - if mgmtConfig == defaultMgmtConfig { + if types.MgmtConfigPath == defaultMgmtConfig { if migrateToNetbird(oldDefaultMgmtConfig, defaultMgmtConfig) { cmd.Printf("will copy Config dir %s and its content to %s\n", oldDefaultMgmtConfigDir, defaultMgmtConfigDir) err = cpDir(oldDefaultMgmtConfigDir, defaultMgmtConfigDir) diff --git a/management/cmd/root.go b/management/cmd/root.go index 86155a956..31271a8c6 100644 --- a/management/cmd/root.go +++ b/management/cmd/root.go @@ -7,6 +7,7 @@ import ( "github.com/spf13/cobra" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/version" ) @@ -19,7 +20,6 @@ const ( var ( dnsDomain string mgmtDataDir string - mgmtConfig string logLevel string logFile string disableMetrics bool @@ -56,7 +56,7 @@ func init() { mgmtCmd.Flags().IntVar(&mgmtPort, "port", 80, "server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise") mgmtCmd.Flags().IntVar(&mgmtMetricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics") mgmtCmd.Flags().StringVar(&mgmtDataDir, "datadir", defaultMgmtDataDir, "server data directory location") - mgmtCmd.Flags().StringVar(&mgmtConfig, "config", defaultMgmtConfig, "Netbird config file location. Config params specified via command line (e.g. datadir) have a precedence over configuration from this file") + mgmtCmd.Flags().StringVar(&types.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location. Config params specified via command line (e.g. datadir) have a precedence over configuration from this file") mgmtCmd.Flags().StringVar(&mgmtLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS") mgmtCmd.Flags().StringVar(&mgmtSingleAccModeDomain, "single-account-mode-domain", defaultSingleAccModeDomain, "Enables single account mode. This means that all the users will be under the same account grouped by the specified domain. If the installation has more than one account, the property is ineffective. Enabled by default with the default domain "+defaultSingleAccModeDomain) mgmtCmd.Flags().BoolVar(&disableSingleAccMode, "disable-single-account-mode", false, "If set to true, disables single account mode. The --single-account-mode-domain property will be ignored and every new user will have a separate NetBird account.") diff --git a/management/server/account_test.go b/management/server/account_test.go index 1cfcf127c..715cfab84 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -2793,15 +2793,15 @@ func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) { }) } -type TB interface { - Cleanup(func()) - Helper() - TempDir() string - Errorf(format string, args ...interface{}) - Fatalf(format string, args ...interface{}) -} +//type TB interface { +// Cleanup(func()) +// Helper() +// TempDir() string +// Errorf(format string, args ...interface{}) +// Fatalf(format string, args ...interface{}) +//} -func createManager(t TB) (*DefaultAccountManager, error) { +func createManager(t testing.TB) (*DefaultAccountManager, error) { t.Helper() store, err := createStore(t) @@ -2836,7 +2836,7 @@ func createManager(t TB) (*DefaultAccountManager, error) { return manager, nil } -func createStore(t TB) (store.Store, error) { +func createStore(t testing.TB) (store.Store, error) { t.Helper() dataDir := t.TempDir() store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", dataDir) diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 49b7b4a33..5e59b9df1 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -19,6 +19,7 @@ import ( "google.golang.org/grpc/status" integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" + "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/account" @@ -40,7 +41,7 @@ type GRPCServer struct { wgKey wgtypes.Key proto.UnimplementedManagementServiceServer peersUpdateManager *PeersUpdateManager - config *Config + config *types.Config secretsManager SecretsManager appMetrics telemetry.AppMetrics ephemeralManager *EphemeralManager @@ -51,7 +52,7 @@ type GRPCServer struct { // NewServer creates a new Management server func NewServer( ctx context.Context, - config *Config, + config *types.Config, accountManager account.Manager, settingsManager settings.Manager, peersUpdateManager *PeersUpdateManager, @@ -530,24 +531,24 @@ func (s *GRPCServer) processJwtToken(ctx context.Context, loginReq *proto.LoginR return userID, nil } -func ToResponseProto(configProto Protocol) proto.HostConfig_Protocol { +func ToResponseProto(configProto types.Protocol) proto.HostConfig_Protocol { switch configProto { - case UDP: + case types.UDP: return proto.HostConfig_UDP - case DTLS: + case types.DTLS: return proto.HostConfig_DTLS - case HTTP: + case types.HTTP: return proto.HostConfig_HTTP - case HTTPS: + case types.HTTPS: return proto.HostConfig_HTTPS - case TCP: + case types.TCP: return proto.HostConfig_TCP default: panic(fmt.Errorf("unexpected config protocol type %v", configProto)) } } -func toNetbirdConfig(config *Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig { +func toNetbirdConfig(config *types.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig { if config == nil { return nil } @@ -610,8 +611,6 @@ func toNetbirdConfig(config *Config, turnCredentials *Token, relayToken *Token, Relay: relayCfg, } - integrationsConfig.ExtendNetBirdConfig(nbConfig, extraSettings) - return nbConfig } @@ -626,10 +625,9 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, dns } } -func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, dnsResolutionOnRoutingPeerEnabled bool, extraSettings *types.ExtraSettings) *proto.SyncResponse { +func toSyncResponse(ctx context.Context, config *types.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, dnsResolutionOnRoutingPeerEnabled bool, extraSettings *types.ExtraSettings) *proto.SyncResponse { response := &proto.SyncResponse{ - NetbirdConfig: toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings), - PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, dnsResolutionOnRoutingPeerEnabled), + PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, dnsResolutionOnRoutingPeerEnabled), NetworkMap: &proto.NetworkMap{ Serial: networkMap.Network.CurrentSerial(), Routes: toProtocolRoutes(networkMap.Routes), @@ -638,6 +636,10 @@ func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turn Checks: toProtocolChecks(ctx, checks), } + nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings) + integrationsConfig.ExtendNetBirdConfig(peer.ID, nbConfig, extraSettings) + response.NetbirdConfig = nbConfig + response.NetworkMap.PeerConfig = response.PeerConfig allPeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers)) @@ -754,7 +756,7 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto. return nil, status.Error(codes.InvalidArgument, errMSG) } - if s.config.DeviceAuthorizationFlow == nil || s.config.DeviceAuthorizationFlow.Provider == string(NONE) { + if s.config.DeviceAuthorizationFlow == nil || s.config.DeviceAuthorizationFlow.Provider == string(types.NONE) { return nil, status.Error(codes.NotFound, "no device authorization flow information available") } diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index 04fd88359..d4933dd94 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -93,21 +93,21 @@ func getServerKey(client mgmtProto.ManagementServiceClient) (*wgtypes.Key, error func Test_SyncProtocol(t *testing.T) { dir := t.TempDir() - mgmtServer, _, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &Config{ - Stuns: []*Host{{ + mgmtServer, _, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &types.Config{ + Stuns: []*types.Host{{ Proto: "udp", URI: "stun:stun.netbird.io:3468", }}, - TURNConfig: &TURNConfig{ + TURNConfig: &types.TURNConfig{ TimeBasedCredentials: false, CredentialsTTL: util.Duration{}, Secret: "whatever", - Turns: []*Host{{ + Turns: []*types.Host{{ Proto: "udp", URI: "turn:stun.netbird.io:3468", }}, }, - Signal: &Host{ + Signal: &types.Host{ Proto: "http", URI: "signal.netbird.io:10000", }, @@ -330,7 +330,7 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) { testCases := []struct { name string - inputFlow *DeviceAuthorizationFlow + inputFlow *types.DeviceAuthorizationFlow expectedFlow *mgmtProto.DeviceAuthorizationFlow expectedErrFunc require.ErrorAssertionFunc expectedErrMSG string @@ -345,9 +345,9 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) { }, { name: "Testing Invalid Device Flow Provider Config", - inputFlow: &DeviceAuthorizationFlow{ + inputFlow: &types.DeviceAuthorizationFlow{ Provider: "NoNe", - ProviderConfig: ProviderConfig{ + ProviderConfig: types.ProviderConfig{ ClientID: "test", }, }, @@ -356,9 +356,9 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) { }, { name: "Testing Full Device Flow Config", - inputFlow: &DeviceAuthorizationFlow{ + inputFlow: &types.DeviceAuthorizationFlow{ Provider: "hosted", - ProviderConfig: ProviderConfig{ + ProviderConfig: types.ProviderConfig{ ClientID: "test", }, }, @@ -379,7 +379,7 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) { t.Run(testCase.name, func(t *testing.T) { mgmtServer := &GRPCServer{ wgKey: testingServerKey, - config: &Config{ + config: &types.Config{ DeviceAuthorizationFlow: testCase.inputFlow, }, } @@ -410,7 +410,7 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) { } } -func startManagementForTest(t *testing.T, testFile string, config *Config) (*grpc.Server, *DefaultAccountManager, string, func(), error) { +func startManagementForTest(t *testing.T, testFile string, config *types.Config) (*grpc.Server, *DefaultAccountManager, string, func(), error) { t.Helper() lis, err := net.Listen("tcp", "localhost:0") if err != nil { @@ -506,21 +506,21 @@ func testSyncStatusRace(t *testing.T) { t.Skip() dir := t.TempDir() - mgmtServer, am, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &Config{ - Stuns: []*Host{{ + mgmtServer, am, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &types.Config{ + Stuns: []*types.Host{{ Proto: "udp", URI: "stun:stun.netbird.io:3468", }}, - TURNConfig: &TURNConfig{ + TURNConfig: &types.TURNConfig{ TimeBasedCredentials: false, CredentialsTTL: util.Duration{}, Secret: "whatever", - Turns: []*Host{{ + Turns: []*types.Host{{ Proto: "udp", URI: "turn:stun.netbird.io:3468", }}, }, - Signal: &Host{ + Signal: &types.Host{ Proto: "http", URI: "signal.netbird.io:10000", }, @@ -678,21 +678,21 @@ func Test_LoginPerformance(t *testing.T) { t.Helper() dir := t.TempDir() - mgmtServer, am, _, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &Config{ - Stuns: []*Host{{ + mgmtServer, am, _, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &types.Config{ + Stuns: []*types.Host{{ Proto: "udp", URI: "stun:stun.netbird.io:3468", }}, - TURNConfig: &TURNConfig{ + TURNConfig: &types.TURNConfig{ TimeBasedCredentials: false, CredentialsTTL: util.Duration{}, Secret: "whatever", - Turns: []*Host{{ + Turns: []*types.Host{{ Proto: "udp", URI: "turn:stun.netbird.io:3468", }}, }, - Signal: &Host{ + Signal: &types.Host{ Proto: "http", URI: "signal.netbird.io:10000", }, diff --git a/management/server/management_test.go b/management/server/management_test.go index 9cad3ab9d..689a05623 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -58,7 +58,7 @@ func setupTest(t *testing.T) *testSuite { t.Fatalf("failed to create temp directory: %v", err) } - config := &server.Config{} + config := &types.Config{} _, err = util.ReadJson("testdata/management.json", config) if err != nil { t.Fatalf("failed to read management.json: %v", err) @@ -156,7 +156,7 @@ func createRawClient(t *testing.T, addr string) (mgmtProto.ManagementServiceClie func startServer( t *testing.T, - config *server.Config, + config *types.Config, dataDir string, testFile string, ) (*grpc.Server, net.Listener) { @@ -300,6 +300,10 @@ func TestSyncNewPeerConfiguration(t *testing.T) { Protocol: mgmtProto.HostConfig_UDP, } + expectedRelayHost := &mgmtProto.RelayConfig{ + Urls: []string{"rel://test.com:3535"}, + } + assert.NotNil(t, resp.NetbirdConfig) assert.Equal(t, resp.NetbirdConfig.Signal, expectedSignalConfig) assert.Contains(t, resp.NetbirdConfig.Stuns, expectedStunsConfig) @@ -307,6 +311,8 @@ func TestSyncNewPeerConfiguration(t *testing.T) { actualTURN := resp.NetbirdConfig.Turns[0] assert.Greater(t, len(actualTURN.User), 0) assert.Equal(t, actualTURN.HostConfig, expectedTRUNHost) + assert.Equal(t, len(resp.NetbirdConfig.Relay.Urls), 1) + assert.Equal(t, resp.NetbirdConfig.Relay.Urls, expectedRelayHost.Urls) assert.Equal(t, len(resp.NetworkMap.OfflinePeers), 0) } diff --git a/management/server/metrics/selfhosted.go b/management/server/metrics/selfhosted.go index 03cb21af1..9a3b22e51 100644 --- a/management/server/metrics/selfhosted.go +++ b/management/server/metrics/selfhosted.go @@ -15,7 +15,6 @@ import ( "github.com/hashicorp/go-version" log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" nbversion "github.com/netbirdio/netbird/version" ) @@ -49,7 +48,7 @@ type properties map[string]interface{} // DataSource metric data source type DataSource interface { GetAllAccounts(ctx context.Context) []*types.Account - GetStoreEngine() store.Engine + GetStoreEngine() types.Engine } // ConnManager peer connection manager that holds state for current active connections diff --git a/management/server/metrics/selfhosted_test.go b/management/server/metrics/selfhosted_test.go index 4894c1ac4..de6686400 100644 --- a/management/server/metrics/selfhosted_test.go +++ b/management/server/metrics/selfhosted_test.go @@ -10,7 +10,6 @@ import ( networkTypes "github.com/netbirdio/netbird/management/server/networks/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" - "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" ) @@ -205,8 +204,8 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account { } // GetStoreEngine returns FileStoreEngine -func (mockDatasource) GetStoreEngine() store.Engine { - return store.FileStoreEngine +func (mockDatasource) GetStoreEngine() types.Engine { + return types.FileStoreEngine } // TestGenerateProperties tests and validate the properties generation by using the mockDatasource for the Worker.generateProperties @@ -304,7 +303,7 @@ func TestGenerateProperties(t *testing.T) { t.Errorf("expected 2 user_peers, got %d", properties["user_peers"]) } - if properties["store_engine"] != store.FileStoreEngine { + if properties["store_engine"] != types.FileStoreEngine { t.Errorf("expected JsonFile, got %s", properties["store_engine"]) } diff --git a/management/server/peer.go b/management/server/peer.go index d976ce68e..4e70fe6e3 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -1213,7 +1213,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account return } - update := toSyncResponse(ctx, &Config{}, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled, extraSetting) + update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled, extraSetting) am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) }(peer) } @@ -1282,7 +1282,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI return } - update := toSyncResponse(ctx, &Config{}, peer, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled, extraSettings) + update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled, extraSettings) am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) } diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 64bf5a73b..0b91ff37d 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -723,7 +723,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) { } } -func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccountManager, string, string, error) { +func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccountManager, string, string, error) { b.Helper() manager, err := createManager(b) @@ -998,6 +998,53 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { } } +func TestUpdateAccountPeers(t *testing.T) { + testCases := []struct { + name string + peers int + groups int + }{ + {"Small", 50, 1}, + {"Medium", 500, 1}, + {"Large", 1000, 1}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + manager, accountID, _, err := setupTestAccountManager(t, tc.peers, tc.groups) + if err != nil { + t.Fatalf("Failed to setup test account manager: %v", err) + } + + ctx := context.Background() + + account, err := manager.Store.GetAccount(ctx, accountID) + if err != nil { + t.Fatalf("Failed to get account: %v", err) + } + + peerChannels := make(map[string]chan *UpdateMessage) + + for peerID := range account.Peers { + peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize) + } + + manager.peersUpdateManager.peerChannels = peerChannels + manager.UpdateAccountPeers(ctx, account.Id) + + for _, channel := range peerChannels { + update := <-channel + assert.Nil(t, update.Update.NetbirdConfig) + assert.Equal(t, tc.peers, len(update.NetworkMap.Peers)) + assert.Equal(t, tc.peers*2, len(update.NetworkMap.FirewallRules)) + } + }) + } +} + func TestToSyncResponse(t *testing.T) { _, ipnet, err := net.ParseCIDR("192.168.1.0/24") if err != nil { @@ -1008,16 +1055,16 @@ func TestToSyncResponse(t *testing.T) { t.Fatal(err) } - config := &Config{ - Signal: &Host{ + config := &types.Config{ + Signal: &types.Host{ Proto: "https", URI: "signal.uri", Username: "", Password: "", }, - Stuns: []*Host{{URI: "stun.uri", Proto: UDP}}, - TURNConfig: &TURNConfig{ - Turns: []*Host{{URI: "turn.uri", Proto: UDP, Username: "turn-user", Password: "turn-pass"}}, + Stuns: []*types.Host{{URI: "stun.uri", Proto: types.UDP}}, + TURNConfig: &types.TURNConfig{ + Turns: []*types.Host{{URI: "turn.uri", Proto: types.UDP, Username: "turn-user", Password: "turn-pass"}}, }, } peer := &nbpeer.Peer{ diff --git a/management/server/store/file_store.go b/management/server/store/file_store.go index 4c9134e41..3b95164f5 100644 --- a/management/server/store/file_store.go +++ b/management/server/store/file_store.go @@ -260,6 +260,6 @@ func (s *FileStore) Close(ctx context.Context) error { } // GetStoreEngine returns FileStoreEngine -func (s *FileStore) GetStoreEngine() Engine { - return FileStoreEngine +func (s *FileStore) GetStoreEngine() types.Engine { + return types.FileStoreEngine } diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 01823c797..9bdf51bd9 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -55,7 +55,7 @@ type SqlStore struct { globalAccountLock sync.Mutex metrics telemetry.AppMetrics installationPK int - storeEngine Engine + storeEngine types.Engine } type installation struct { @@ -66,7 +66,7 @@ type installation struct { type migrationFunc func(*gorm.DB) error // NewSqlStore creates a new SqlStore instance. -func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine Engine, metrics telemetry.AppMetrics) (*SqlStore, error) { +func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, metrics telemetry.AppMetrics) (*SqlStore, error) { sql, err := db.DB() if err != nil { return nil, err @@ -77,7 +77,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine Engine, metrics t conns = runtime.NumCPU() } - if storeEngine == SqliteStoreEngine { + if storeEngine == types.SqliteStoreEngine { if err == nil { log.WithContext(ctx).Warnf("setting NB_SQL_MAX_OPEN_CONNS is not supported for sqlite, using default value 1") } @@ -105,7 +105,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine Engine, metrics t } func GetKeyQueryCondition(s *SqlStore) string { - if s.storeEngine == MysqlStoreEngine { + if s.storeEngine == types.MysqlStoreEngine { return mysqlKeyQueryCondition } return keyQueryCondition @@ -970,7 +970,7 @@ func (s *SqlStore) Close(_ context.Context) error { } // GetStoreEngine returns underlying store engine -func (s *SqlStore) GetStoreEngine() Engine { +func (s *SqlStore) GetStoreEngine() types.Engine { return s.storeEngine } @@ -988,7 +988,7 @@ func NewSqliteStore(ctx context.Context, dataDir string, metrics telemetry.AppMe return nil, err } - return NewSqlStore(ctx, db, SqliteStoreEngine, metrics) + return NewSqlStore(ctx, db, types.SqliteStoreEngine, metrics) } // NewPostgresqlStore creates a new Postgres store. @@ -998,7 +998,7 @@ func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMe return nil, err } - return NewSqlStore(ctx, db, PostgresStoreEngine, metrics) + return NewSqlStore(ctx, db, types.PostgresStoreEngine, metrics) } // NewMysqlStore creates a new MySQL store. @@ -1008,7 +1008,7 @@ func NewMysqlStore(ctx context.Context, dsn string, metrics telemetry.AppMetrics return nil, err } - return NewSqlStore(ctx, db, MysqlStoreEngine, metrics) + return NewSqlStore(ctx, db, types.MysqlStoreEngine, metrics) } func getGormConfig() *gorm.Config { @@ -1517,9 +1517,9 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren query := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations) switch s.storeEngine { - case PostgresStoreEngine: + case types.PostgresStoreEngine: query = query.Order("json_array_length(peers::json) DESC") - case MysqlStoreEngine: + case types.MysqlStoreEngine: query = query.Order("JSON_LENGTH(JSON_EXTRACT(peers, \"$\")) DESC") default: query = query.Order("json_array_length(peers) DESC") diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 0d67ca719..589e727e9 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -297,7 +297,7 @@ func TestSqlite_DeleteAccount(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine)) store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -628,7 +628,7 @@ func TestMigrate(t *testing.T) { } // TODO: figure out why this fails on postgres - t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine)) store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) t.Cleanup(cleanUp) @@ -737,7 +737,7 @@ func TestPostgresql_NewStore(t *testing.T) { t.Skip("skip CI tests on darwin and windows") } - t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) + t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine)) store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -752,7 +752,7 @@ func TestPostgresql_SaveAccount(t *testing.T) { t.Skip("skip CI tests on darwin and windows") } - t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) + t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine)) store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -825,7 +825,7 @@ func TestPostgresql_DeleteAccount(t *testing.T) { t.Skip("skip CI tests on darwin and windows") } - t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) + t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine)) store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -901,7 +901,7 @@ func TestPostgresql_TestGetAccountByPrivateDomain(t *testing.T) { t.Skip("skip CI tests on darwin and windows") } - t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) + t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine)) store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -921,7 +921,7 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) { t.Skip("skip CI tests on darwin and windows") } - t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) + t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine)) store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -935,7 +935,7 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) { } func TestSqlite_GetTakenIPs(t *testing.T) { - t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine)) store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) defer cleanup() if err != nil { @@ -980,7 +980,7 @@ func TestSqlite_GetTakenIPs(t *testing.T) { } func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { - t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine)) store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) if err != nil { return @@ -1022,7 +1022,7 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { } func TestSqlite_GetAccountNetwork(t *testing.T) { - t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine)) store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) if err != nil { @@ -1045,7 +1045,7 @@ func TestSqlite_GetAccountNetwork(t *testing.T) { } func TestSqlite_GetSetupKeyBySecret(t *testing.T) { - t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine)) store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) if err != nil { @@ -1070,7 +1070,7 @@ func TestSqlite_GetSetupKeyBySecret(t *testing.T) { } func TestSqlite_incrementSetupKeyUsage(t *testing.T) { - t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine)) store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) if err != nil { @@ -1106,7 +1106,7 @@ func TestSqlite_incrementSetupKeyUsage(t *testing.T) { } func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) { - t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine)) store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) if err != nil { @@ -1211,7 +1211,7 @@ func TestSqlite_GetGroupByName(t *testing.T) { } func Test_DeleteSetupKeySuccessfully(t *testing.T) { - t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine)) store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1227,7 +1227,7 @@ func Test_DeleteSetupKeySuccessfully(t *testing.T) { } func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) { - t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + t.Setenv("NETBIRD_STORE_ENGINE", string(types.SqliteStoreEngine)) store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) diff --git a/management/server/store/store.go b/management/server/store/store.go index 9ff0c5636..1975f11b2 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -164,7 +164,7 @@ type Store interface { Close(ctx context.Context) error // GetStoreEngine should return Engine of the current store implementation. // This is also a method of metrics.DataSource interface. - GetStoreEngine() Engine + GetStoreEngine() types.Engine ExecuteInTransaction(ctx context.Context, f func(store Store) error) error GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error) @@ -187,44 +187,37 @@ type Store interface { GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error) } -type Engine string - const ( - FileStoreEngine Engine = "jsonfile" - SqliteStoreEngine Engine = "sqlite" - PostgresStoreEngine Engine = "postgres" - MysqlStoreEngine Engine = "mysql" - postgresDsnEnv = "NETBIRD_STORE_ENGINE_POSTGRES_DSN" mysqlDsnEnv = "NETBIRD_STORE_ENGINE_MYSQL_DSN" ) -var supportedEngines = []Engine{SqliteStoreEngine, PostgresStoreEngine, MysqlStoreEngine} +var supportedEngines = []types.Engine{types.SqliteStoreEngine, types.PostgresStoreEngine, types.MysqlStoreEngine} -func getStoreEngineFromEnv() Engine { +func getStoreEngineFromEnv() types.Engine { // NETBIRD_STORE_ENGINE supposed to be used in tests. Otherwise, rely on the config file. kind, ok := os.LookupEnv("NETBIRD_STORE_ENGINE") if !ok { return "" } - value := Engine(strings.ToLower(kind)) + value := types.Engine(strings.ToLower(kind)) if slices.Contains(supportedEngines, value) { return value } - return SqliteStoreEngine + return types.SqliteStoreEngine } // getStoreEngine determines the store engine to use. // If no engine is specified, it attempts to retrieve it from the environment. // If still not specified, it defaults to using SQLite. // Additionally, it handles the migration from a JSON store file to SQLite if applicable. -func getStoreEngine(ctx context.Context, dataDir string, kind Engine) Engine { +func getStoreEngine(ctx context.Context, dataDir string, kind types.Engine) types.Engine { if kind == "" { kind = getStoreEngineFromEnv() if kind == "" { - kind = SqliteStoreEngine + kind = types.SqliteStoreEngine // Migrate if it is the first run with a JSON file existing and no SQLite file present jsonStoreFile := filepath.Join(dataDir, storeFileName) @@ -236,7 +229,7 @@ func getStoreEngine(ctx context.Context, dataDir string, kind Engine) Engine { // Attempt to migrate from JSON store to SQLite if err := MigrateFileStoreToSqlite(ctx, dataDir); err != nil { log.WithContext(ctx).Errorf("failed to migrate filestore to SQLite: %v", err) - kind = FileStoreEngine + kind = types.FileStoreEngine } } } @@ -246,7 +239,7 @@ func getStoreEngine(ctx context.Context, dataDir string, kind Engine) Engine { } // NewStore creates a new store based on the provided engine type, data directory, and telemetry metrics -func NewStore(ctx context.Context, kind Engine, dataDir string, metrics telemetry.AppMetrics) (Store, error) { +func NewStore(ctx context.Context, kind types.Engine, dataDir string, metrics telemetry.AppMetrics) (Store, error) { kind = getStoreEngine(ctx, dataDir, kind) if err := checkFileStoreEngine(kind, dataDir); err != nil { @@ -254,13 +247,13 @@ func NewStore(ctx context.Context, kind Engine, dataDir string, metrics telemetr } switch kind { - case SqliteStoreEngine: + case types.SqliteStoreEngine: log.WithContext(ctx).Info("using SQLite store engine") return NewSqliteStore(ctx, dataDir, metrics) - case PostgresStoreEngine: + case types.PostgresStoreEngine: log.WithContext(ctx).Info("using Postgres store engine") return newPostgresStore(ctx, metrics) - case MysqlStoreEngine: + case types.MysqlStoreEngine: log.WithContext(ctx).Info("using MySQL store engine") return newMysqlStore(ctx, metrics) default: @@ -268,12 +261,12 @@ func NewStore(ctx context.Context, kind Engine, dataDir string, metrics telemetr } } -func checkFileStoreEngine(kind Engine, dataDir string) error { - if kind == FileStoreEngine { +func checkFileStoreEngine(kind types.Engine, dataDir string) error { + if kind == types.FileStoreEngine { storeFile := filepath.Join(dataDir, storeFileName) if util.FileExists(storeFile) { return fmt.Errorf("%s is not supported. Please refer to the documentation for migrating to SQLite: "+ - "https://docs.netbird.io/selfhosted/sqlite-store#migrating-from-json-store-to-sq-lite-store", FileStoreEngine) + "https://docs.netbird.io/selfhosted/sqlite-store#migrating-from-json-store-to-sq-lite-store", types.FileStoreEngine) } } return nil @@ -326,7 +319,7 @@ func getMigrations(ctx context.Context) []migrationFunc { func NewTestStoreFromSQL(ctx context.Context, filename string, dataDir string) (Store, func(), error) { kind := getStoreEngineFromEnv() if kind == "" { - kind = SqliteStoreEngine + kind = types.SqliteStoreEngine } storeStr := fmt.Sprintf("%s?cache=shared", storeSqliteFileName) @@ -348,7 +341,7 @@ func NewTestStoreFromSQL(ctx context.Context, filename string, dataDir string) ( } } - store, err := NewSqlStore(ctx, db, SqliteStoreEngine, nil) + store, err := NewSqlStore(ctx, db, types.SqliteStoreEngine, nil) if err != nil { return nil, nil, fmt.Errorf("failed to create test store: %v", err) } @@ -394,13 +387,13 @@ func addAllGroupToAccount(ctx context.Context, store Store) error { return nil } -func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind Engine) (Store, func(), error) { +func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind types.Engine) (Store, func(), error) { var cleanup func() var err error switch kind { - case PostgresStoreEngine: + case types.PostgresStoreEngine: store, cleanup, err = newReusedPostgresStore(ctx, store, kind) - case MysqlStoreEngine: + case types.MysqlStoreEngine: store, cleanup, err = newReusedMysqlStore(ctx, store, kind) default: cleanup = func() { @@ -419,7 +412,7 @@ func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind Engine) (Store return store, closeConnection, nil } -func newReusedPostgresStore(ctx context.Context, store *SqlStore, kind Engine) (*SqlStore, func(), error) { +func newReusedPostgresStore(ctx context.Context, store *SqlStore, kind types.Engine) (*SqlStore, func(), error) { if envDsn, ok := os.LookupEnv(postgresDsnEnv); !ok || envDsn == "" { var err error _, err = testutil.CreatePostgresTestContainer() @@ -451,7 +444,7 @@ func newReusedPostgresStore(ctx context.Context, store *SqlStore, kind Engine) ( return store, cleanup, nil } -func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind Engine) (*SqlStore, func(), error) { +func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind types.Engine) (*SqlStore, func(), error) { if envDsn, ok := os.LookupEnv(mysqlDsnEnv); !ok || envDsn == "" { var err error _, err = testutil.CreateMysqlTestContainer() @@ -483,7 +476,7 @@ func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind Engine) (*Sq return store, cleanup, nil } -func createRandomDB(dsn string, db *gorm.DB, engine Engine) (string, func(), error) { +func createRandomDB(dsn string, db *gorm.DB, engine types.Engine) (string, func(), error) { dbName := fmt.Sprintf("test_db_%s", strings.ReplaceAll(uuid.New().String(), "-", "_")) if err := db.Exec(fmt.Sprintf("CREATE DATABASE %s", dbName)).Error; err != nil { @@ -493,9 +486,9 @@ func createRandomDB(dsn string, db *gorm.DB, engine Engine) (string, func(), err var err error cleanup := func() { switch engine { - case PostgresStoreEngine: + case types.PostgresStoreEngine: err = db.Exec(fmt.Sprintf("DROP DATABASE %s WITH (FORCE)", dbName)).Error - case MysqlStoreEngine: + case types.MysqlStoreEngine: // err = killMySQLConnections(dsn, dbName) err = db.Exec(fmt.Sprintf("DROP DATABASE %s", dbName)).Error } diff --git a/management/server/testdata/management.json b/management/server/testdata/management.json index f797a7d2b..1a48fbace 100644 --- a/management/server/testdata/management.json +++ b/management/server/testdata/management.json @@ -20,6 +20,11 @@ "Secret": "c29tZV9wYXNzd29yZA==", "TimeBasedCredentials": true }, + "Relay":{ + "Addresses":["rel://test.com:3535"], + "CredentialsTTL":"2h", + "Secret":"netbird" + }, "Signal": { "Proto": "http", "URI": "signal.netbird.io:10000", diff --git a/management/server/token_mgr.go b/management/server/token_mgr.go index f8238aa16..ee9fee376 100644 --- a/management/server/token_mgr.go +++ b/management/server/token_mgr.go @@ -13,6 +13,7 @@ import ( "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/types" auth "github.com/netbirdio/netbird/relay/auth/hmac" authv2 "github.com/netbirdio/netbird/relay/auth/hmac/v2" @@ -32,8 +33,8 @@ type SecretsManager interface { // TimeBasedAuthSecretsManager generates credentials with TTL and using pre-shared secret known to TURN server type TimeBasedAuthSecretsManager struct { mux sync.Mutex - turnCfg *TURNConfig - relayCfg *Relay + turnCfg *types.TURNConfig + relayCfg *types.Relay turnHmacToken *auth.TimedHMAC relayHmacToken *authv2.Generator updateManager *PeersUpdateManager @@ -44,7 +45,7 @@ type TimeBasedAuthSecretsManager struct { type Token auth.Token -func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *TURNConfig, relayCfg *Relay, settingsManager settings.Manager) *TimeBasedAuthSecretsManager { +func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *types.TURNConfig, relayCfg *types.Relay, settingsManager settings.Manager) *TimeBasedAuthSecretsManager { mgr := &TimeBasedAuthSecretsManager{ updateManager: updateManager, turnCfg: turnCfg, @@ -221,7 +222,7 @@ func (m *TimeBasedAuthSecretsManager) pushNewTURNAndRelayTokens(ctx context.Cont } } - m.extendNetbirdConfig(ctx, accountID, update) + m.extendNetbirdConfig(ctx, peerID, accountID, update) log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID) m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update}) @@ -245,17 +246,17 @@ func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, ac }, } - m.extendNetbirdConfig(ctx, accountID, update) + m.extendNetbirdConfig(ctx, peerID, accountID, update) log.WithContext(ctx).Debugf("sending new relay credentials to peer %s", peerID) m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update}) } -func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, accountID string, update *proto.SyncResponse) { +func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, peerID, accountID string, update *proto.SyncResponse) { extraSettings, err := m.settingsManager.GetExtraSettings(ctx, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to get extra settings: %v", err) } - integrationsConfig.ExtendNetBirdConfig(update.NetbirdConfig, extraSettings) + integrationsConfig.ExtendNetBirdConfig(peerID, update.NetbirdConfig, extraSettings) } diff --git a/management/server/token_mgr_test.go b/management/server/token_mgr_test.go index c07e40418..b2184717d 100644 --- a/management/server/token_mgr_test.go +++ b/management/server/token_mgr_test.go @@ -19,8 +19,8 @@ import ( "github.com/netbirdio/netbird/util" ) -var TurnTestHost = &Host{ - Proto: UDP, +var TurnTestHost = &types.Host{ + Proto: types.UDP, URI: "turn:turn.netbird.io:77777", Username: "username", Password: "", @@ -31,7 +31,7 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) { secret := "some_secret" peersManager := NewPeersUpdateManager(nil) - rc := &Relay{ + rc := &types.Relay{ Addresses: []string{"localhost:0"}, CredentialsTTL: ttl, Secret: secret, @@ -41,10 +41,10 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) { t.Cleanup(ctrl.Finish) settingsMockManager := settings.NewMockManager(ctrl) - tested := NewTimeBasedAuthSecretsManager(peersManager, &TURNConfig{ + tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{ CredentialsTTL: ttl, Secret: secret, - Turns: []*Host{TurnTestHost}, + Turns: []*types.Host{TurnTestHost}, TimeBasedCredentials: true, }, rc, settingsMockManager) @@ -81,7 +81,7 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) { peer := "some_peer" updateChannel := peersManager.CreateChannel(context.Background(), peer) - rc := &Relay{ + rc := &types.Relay{ Addresses: []string{"localhost:0"}, CredentialsTTL: ttl, Secret: secret, @@ -92,10 +92,10 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) { settingsMockManager := settings.NewMockManager(ctrl) settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), "someAccountID").Return(&types.ExtraSettings{}, nil).AnyTimes() - tested := NewTimeBasedAuthSecretsManager(peersManager, &TURNConfig{ + tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{ CredentialsTTL: ttl, Secret: secret, - Turns: []*Host{TurnTestHost}, + Turns: []*types.Host{TurnTestHost}, TimeBasedCredentials: true, }, rc, settingsMockManager) @@ -184,7 +184,7 @@ func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) { peersManager := NewPeersUpdateManager(nil) peer := "some_peer" - rc := &Relay{ + rc := &types.Relay{ Addresses: []string{"localhost:0"}, CredentialsTTL: ttl, Secret: secret, @@ -194,10 +194,10 @@ func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) { t.Cleanup(ctrl.Finish) settingsMockManager := settings.NewMockManager(ctrl) - tested := NewTimeBasedAuthSecretsManager(peersManager, &TURNConfig{ + tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{ CredentialsTTL: ttl, Secret: secret, - Turns: []*Host{TurnTestHost}, + Turns: []*types.Host{TurnTestHost}, TimeBasedCredentials: true, }, rc, settingsMockManager) diff --git a/management/server/config.go b/management/server/types/config.go similarity index 98% rename from management/server/config.go rename to management/server/types/config.go index ce2ff4d16..d2e418264 100644 --- a/management/server/config.go +++ b/management/server/types/config.go @@ -1,10 +1,9 @@ -package server +package types import ( "net/netip" "github.com/netbirdio/netbird/management/server/idp" - "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/util" ) @@ -30,6 +29,8 @@ const ( DefaultDeviceAuthFlowScope string = "openid" ) +var MgmtConfigPath string + // Config of the Management service type Config struct { Stuns []*Host @@ -76,6 +77,7 @@ type TURNConfig struct { Turns []*Host } +// Relay configuration type type Relay struct { Addresses []string CredentialsTTL util.Duration @@ -156,7 +158,7 @@ type ProviderConfig struct { // StoreConfig contains Store configuration type StoreConfig struct { - Engine store.Engine + Engine Engine } // ReverseProxy contains reverse proxy configuration in front of management. diff --git a/management/server/types/store.go b/management/server/types/store.go new file mode 100644 index 000000000..2ca4383b2 --- /dev/null +++ b/management/server/types/store.go @@ -0,0 +1,10 @@ +package types + +type Engine string + +const ( + PostgresStoreEngine Engine = "postgres" + FileStoreEngine Engine = "jsonfile" + SqliteStoreEngine Engine = "sqlite" + MysqlStoreEngine Engine = "mysql" +)