[management] Use DI containers for server bootstrapping (#4343)

This commit is contained in:
Pascal Fischer
2025-08-15 17:14:48 +02:00
committed by GitHub
parent ab853ac2a5
commit b3056d0937
18 changed files with 894 additions and 519 deletions

View File

@@ -10,6 +10,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.opentelemetry.io/otel" "go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
@@ -35,7 +36,7 @@ import (
func startTestingServices(t *testing.T) string { func startTestingServices(t *testing.T) string {
t.Helper() t.Helper()
config := &types.Config{} config := &config.Config{}
_, err := util.ReadJson("../testdata/management.json", config) _, err := util.ReadJson("../testdata/management.json", config)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -70,7 +71,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
return s, lis return s, lis
} }
func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc.Server, net.Listener) { func startManagement(t *testing.T, config *config.Config, testFile string) (*grpc.Server, net.Listener) {
t.Helper() t.Helper()
lis, err := net.Listen("tcp", ":0") lis, err := net.Listen("tcp", ":0")

View File

@@ -27,6 +27,7 @@ import (
"golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
@@ -44,8 +45,6 @@ import (
"github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
mgmt "github.com/netbirdio/netbird/shared/management/client"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
@@ -55,8 +54,10 @@ import (
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/monotime" "github.com/netbirdio/netbird/monotime"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
mgmt "github.com/netbirdio/netbird/shared/management/client"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
signal "github.com/netbirdio/netbird/shared/signal/client" signal "github.com/netbirdio/netbird/shared/signal/client"
"github.com/netbirdio/netbird/shared/signal/proto" "github.com/netbirdio/netbird/shared/signal/proto"
signalServer "github.com/netbirdio/netbird/signal/server" signalServer "github.com/netbirdio/netbird/signal/server"
@@ -1514,15 +1515,15 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) {
func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) { func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) {
t.Helper() t.Helper()
config := &types.Config{ config := &config.Config{
Stuns: []*types.Host{}, Stuns: []*config.Host{},
TURNConfig: &types.TURNConfig{}, TURNConfig: &config.TURNConfig{},
Relay: &types.Relay{ Relay: &config.Relay{
Addresses: []string{"127.0.0.1:1234"}, Addresses: []string{"127.0.0.1:1234"},
CredentialsTTL: util.Duration{Duration: time.Hour}, CredentialsTTL: util.Duration{Duration: time.Hour},
Secret: "222222222222222222", Secret: "222222222222222222",
}, },
Signal: &types.Host{ Signal: &config.Host{
Proto: "http", Proto: "http",
URI: "localhost:10000", URI: "localhost:10000",
}, },

View File

@@ -14,6 +14,7 @@ import (
"go.opentelemetry.io/otel" "go.opentelemetry.io/otel"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -32,7 +33,6 @@ import (
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto" mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/signal/proto" "github.com/netbirdio/netbird/shared/signal/proto"
signalServer "github.com/netbirdio/netbird/signal/server" signalServer "github.com/netbirdio/netbird/signal/server"
@@ -267,10 +267,10 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
t.Helper() t.Helper()
dataDir := t.TempDir() dataDir := t.TempDir()
config := &types.Config{ config := &config.Config{
Stuns: []*types.Host{}, Stuns: []*config.Host{},
TURNConfig: &types.TURNConfig{}, TURNConfig: &config.TURNConfig{},
Signal: &types.Host{ Signal: &config.Host{
Proto: "http", Proto: "http",
URI: signalAddr, URI: signalAddr,
}, },

View File

@@ -2,88 +2,40 @@ package cmd
import ( import (
"context" "context"
"crypto/tls"
"encoding/json" "encoding/json"
"errors" "errors"
"flag" "flag"
"fmt" "fmt"
"io" "io"
"io/fs" "io/fs"
"net"
"net/http" "net/http"
"net/netip"
"net/url" "net/url"
"os" "os"
"os/signal"
"path" "path"
"slices"
"strings" "strings"
"time" "syscall"
"github.com/google/uuid"
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
"github.com/miekg/dns" "github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"golang.org/x/crypto/acme/autocert"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/formatter/hook" "github.com/netbirdio/netbird/formatter/hook"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/management/internals/server"
"github.com/netbirdio/netbird/management/server" nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/auth"
nbContext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
nbhttp "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/metrics"
"github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version"
) )
// ManagementLegacyPort is the port that was used before by the Management gRPC server. var newServer = func(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain string, mgmtPort int, mgmtMetricsPort int, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled bool) server.Server {
// It is used for backward compatibility now. return server.NewServer(config, dnsDomain, mgmtSingleAccModeDomain, mgmtPort, mgmtMetricsPort, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled)
const ManagementLegacyPort = 33073 }
func SetNewServer(fn func(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain string, mgmtPort int, mgmtMetricsPort int, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled bool) server.Server) {
newServer = fn
}
var ( var (
mgmtPort int config *nbconfig.Config
mgmtMetricsPort int
mgmtLetsencryptDomain string
mgmtSingleAccModeDomain string
certFile string
certKey string
config *types.Config
kaep = keepalive.EnforcementPolicy{
MinTime: 15 * time.Second,
PermitWithoutStream: true,
}
kasp = keepalive.ServerParameters{
MaxConnectionIdle: 15 * time.Second,
MaxConnectionAgeGrace: 5 * time.Second,
Time: 5 * time.Second,
Timeout: 2 * time.Second,
}
mgmtCmd = &cobra.Command{ mgmtCmd = &cobra.Command{
Use: "management", Use: "management",
@@ -102,9 +54,9 @@ var (
// detect whether user specified a port // detect whether user specified a port
userPort := cmd.Flag("port").Changed userPort := cmd.Flag("port").Changed
config, err = loadMgmtConfig(ctx, types.MgmtConfigPath) config, err = loadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
if err != nil { if err != nil {
return fmt.Errorf("failed reading provided config file: %s: %v", types.MgmtConfigPath, err) return fmt.Errorf("failed reading provided config file: %s: %v", nbconfig.MgmtConfigPath, err)
} }
if cmd.Flag(idpSignKeyRefreshEnabledFlagName).Changed { if cmd.Flag(idpSignKeyRefreshEnabledFlagName).Changed {
@@ -151,356 +103,38 @@ var (
return fmt.Errorf("failed creating datadir: %s: %v", config.Datadir, err) return fmt.Errorf("failed creating datadir: %s: %v", config.Datadir, err)
} }
} }
appMetrics, err := telemetry.NewDefaultAppMetrics(cmd.Context())
if err != nil {
return err
}
err = appMetrics.Expose(ctx, mgmtMetricsPort, "/metrics")
if err != nil {
return err
}
integrationMetrics, err := integrations.InitIntegrationMetrics(ctx, appMetrics)
if err != nil {
return err
}
store, err := store.NewStore(ctx, config.StoreConfig.Engine, config.Datadir, appMetrics, false)
if err != nil {
return fmt.Errorf("failed creating Store: %s: %v", config.Datadir, err)
}
peersUpdateManager := server.NewPeersUpdateManager(appMetrics)
var idpManager idp.Manager
if config.IdpManagerConfig != nil {
idpManager, err = idp.NewManager(ctx, *config.IdpManagerConfig, appMetrics)
if err != nil {
return fmt.Errorf("failed retrieving a new idp manager with err: %v", err)
}
}
if disableSingleAccMode { if disableSingleAccMode {
mgmtSingleAccModeDomain = "" mgmtSingleAccModeDomain = ""
} }
eventStore, key, err := integrations.InitEventStore(ctx, config.Datadir, config.DataStoreEncryptionKey, integrationMetrics)
if err != nil {
return fmt.Errorf("initialize database: %s", err)
}
if config.DataStoreEncryptionKey != key { srv := newServer(config, dnsDomain, mgmtSingleAccModeDomain, mgmtPort, mgmtMetricsPort, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled)
log.WithContext(ctx).Infof("update config with activity store key") go func() {
config.DataStoreEncryptionKey = key if err := srv.Start(cmd.Context()); err != nil {
err := updateMgmtConfig(ctx, types.MgmtConfigPath, config) log.Fatalf("Server error: %v", err)
}
}()
stopChan := make(chan os.Signal, 1)
signal.Notify(stopChan, os.Interrupt, syscall.SIGTERM)
select {
case <-stopChan:
log.Info("Received shutdown signal, stopping server...")
err = srv.Stop()
if err != nil { if err != nil {
return fmt.Errorf("write out store encryption key: %s", err) log.Errorf("Failed to stop server gracefully: %v", err)
} }
case err := <-srv.Errors():
log.Fatalf("Server stopped unexpectedly: %v", err)
} }
geo, err := geolocation.NewGeolocation(ctx, config.Datadir, !disableGeoliteUpdate)
if err != nil {
log.WithContext(ctx).Warnf("could not initialize geolocation service. proceeding without geolocation support: %v", err)
} else {
log.WithContext(ctx).Infof("geolocation service has been initialized from %s", config.Datadir)
}
integratedPeerValidator, err := integrations.NewIntegratedValidator(ctx, eventStore)
if err != nil {
return fmt.Errorf("initialize integrated peer validator: %v", err)
}
permissionsManager := integrations.InitPermissionsManager(store)
userManager := users.NewManager(store)
extraSettingsManager := integrations.NewManager(eventStore)
settingsManager := settings.NewManager(store, userManager, extraSettingsManager, permissionsManager)
peersManager := peers.NewManager(store, permissionsManager)
proxyController := integrations.NewController(store)
accountManager, err := server.BuildManager(ctx, store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain,
dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator, appMetrics, proxyController, settingsManager, permissionsManager, config.DisableDefaultPolicy)
if err != nil {
return fmt.Errorf("build default manager: %v", err)
}
groupsManager := groups.NewManager(store, permissionsManager, accountManager)
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsManager, groupsManager)
trustedPeers := config.ReverseProxy.TrustedPeers
defaultTrustedPeers := []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0"), netip.MustParsePrefix("::/0")}
if len(trustedPeers) == 0 || slices.Equal[[]netip.Prefix](trustedPeers, defaultTrustedPeers) {
log.WithContext(ctx).Warn("TrustedPeers are configured to default value '0.0.0.0/0', '::/0'. This allows connection IP spoofing.")
trustedPeers = defaultTrustedPeers
}
trustedHTTPProxies := config.ReverseProxy.TrustedHTTPProxies
trustedProxiesCount := config.ReverseProxy.TrustedHTTPProxiesCount
if len(trustedHTTPProxies) > 0 && trustedProxiesCount > 0 {
log.WithContext(ctx).Warn("TrustedHTTPProxies and TrustedHTTPProxiesCount both are configured. " +
"This is not recommended way to extract X-Forwarded-For. Consider using one of these options.")
}
realipOpts := []realip.Option{
realip.WithTrustedPeers(trustedPeers),
realip.WithTrustedProxies(trustedHTTPProxies),
realip.WithTrustedProxiesCount(trustedProxiesCount),
realip.WithHeaders([]string{realip.XForwardedFor, realip.XRealIp}),
}
gRPCOpts := []grpc.ServerOption{
grpc.KeepaliveEnforcementPolicy(kaep),
grpc.KeepaliveParams(kasp),
grpc.ChainUnaryInterceptor(realip.UnaryServerInterceptorOpts(realipOpts...), unaryInterceptor),
grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor),
}
var certManager *autocert.Manager
var tlsConfig *tls.Config
tlsEnabled := false
if config.HttpConfig.LetsEncryptDomain != "" {
certManager, err = encryption.CreateCertManager(config.Datadir, config.HttpConfig.LetsEncryptDomain)
if err != nil {
return fmt.Errorf("failed creating LetsEncrypt cert manager: %v", err)
}
transportCredentials := credentials.NewTLS(certManager.TLSConfig())
gRPCOpts = append(gRPCOpts, grpc.Creds(transportCredentials))
tlsEnabled = true
} else if config.HttpConfig.CertFile != "" && config.HttpConfig.CertKey != "" {
tlsConfig, err = loadTLSConfig(config.HttpConfig.CertFile, config.HttpConfig.CertKey)
if err != nil {
log.WithContext(ctx).Errorf("cannot load TLS credentials: %v", err)
return err
}
transportCredentials := credentials.NewTLS(tlsConfig)
gRPCOpts = append(gRPCOpts, grpc.Creds(transportCredentials))
tlsEnabled = true
}
authManager := auth.NewManager(store,
config.HttpConfig.AuthIssuer,
config.HttpConfig.AuthAudience,
config.HttpConfig.AuthKeysLocation,
config.HttpConfig.AuthUserIDClaim,
config.GetAuthAudiences(),
config.HttpConfig.IdpSignKeyRefreshEnabled)
resourcesManager := resources.NewManager(store, permissionsManager, groupsManager, accountManager)
routersManager := routers.NewManager(store, permissionsManager, accountManager)
networksManager := networks.NewManager(store, permissionsManager, resourcesManager, routersManager, accountManager)
httpAPIHandler, err := nbhttp.NewAPIHandler(ctx, accountManager, networksManager, resourcesManager, routersManager, groupsManager, geo, authManager, appMetrics, integratedPeerValidator, proxyController, permissionsManager, peersManager, settingsManager)
if err != nil {
return fmt.Errorf("failed creating HTTP API handler: %v", err)
}
ephemeralManager := server.NewEphemeralManager(store, accountManager)
ephemeralManager.LoadInitialPeers(ctx)
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
srv, err := server.NewServer(ctx, config, accountManager, settingsManager, peersUpdateManager, secretsManager, appMetrics, ephemeralManager, authManager, integratedPeerValidator)
if err != nil {
return fmt.Errorf("failed creating gRPC API handler: %v", err)
}
mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv)
installationID, err := getInstallationID(ctx, store)
if err != nil {
log.WithContext(ctx).Errorf("cannot load TLS credentials: %v", err)
return err
}
if !disableMetrics {
idpManager := "disabled"
if config.IdpManagerConfig != nil && config.IdpManagerConfig.ManagerType != "" {
idpManager = config.IdpManagerConfig.ManagerType
}
metricsWorker := metrics.NewWorker(ctx, installationID, store, peersUpdateManager, idpManager)
go metricsWorker.Run(ctx)
}
var compatListener net.Listener
if mgmtPort != ManagementLegacyPort {
// The Management gRPC server was running on port 33073 previously. Old agents that are already connected to it
// are using port 33073. For compatibility purposes we keep running a 2nd gRPC server on port 33073.
compatListener, err = serveGRPC(ctx, gRPCAPIHandler, ManagementLegacyPort)
if err != nil {
return err
}
log.WithContext(ctx).Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String())
}
rootHandler := handlerFunc(gRPCAPIHandler, httpAPIHandler)
var listener net.Listener
if certManager != nil {
// a call to certManager.Listener() always creates a new listener so we do it once
cml := certManager.Listener()
if mgmtPort == 443 {
// CertManager, HTTP and gRPC API all on the same port
rootHandler = certManager.HTTPHandler(rootHandler)
listener = cml
} else {
listener, err = tls.Listen("tcp", fmt.Sprintf(":%d", mgmtPort), certManager.TLSConfig())
if err != nil {
return fmt.Errorf("failed creating TLS listener on port %d: %v", mgmtPort, err)
}
log.WithContext(ctx).Infof("running HTTP server (LetsEncrypt challenge handler): %s", cml.Addr().String())
serveHTTP(ctx, cml, certManager.HTTPHandler(nil))
}
} else if tlsConfig != nil {
listener, err = tls.Listen("tcp", fmt.Sprintf(":%d", mgmtPort), tlsConfig)
if err != nil {
return fmt.Errorf("failed creating TLS listener on port %d: %v", mgmtPort, err)
}
} else {
listener, err = net.Listen("tcp", fmt.Sprintf(":%d", mgmtPort))
if err != nil {
return fmt.Errorf("failed creating TCP listener on port %d: %v", mgmtPort, err)
}
}
log.WithContext(ctx).Infof("management server version %s", version.NetbirdVersion())
log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", listener.Addr().String())
serveGRPCWithHTTP(ctx, listener, rootHandler, tlsEnabled)
update := version.NewUpdate("nb/management")
update.SetDaemonVersion(version.NetbirdVersion())
update.SetOnUpdateListener(func() {
log.WithContext(ctx).Infof("your management version, \"%s\", is outdated, a new management version is available. Learn more here: https://github.com/netbirdio/netbird/releases", version.NetbirdVersion())
})
defer update.StopWatch()
SetupCloseHandler()
<-stopCh
integratedPeerValidator.Stop(ctx)
if geo != nil {
_ = geo.Stop()
}
ephemeralManager.Stop()
_ = appMetrics.Close()
_ = listener.Close()
if certManager != nil {
_ = certManager.Listener().Close()
}
gRPCAPIHandler.Stop()
_ = store.Close(ctx)
_ = eventStore.Close(ctx)
log.WithContext(ctx).Infof("stopped Management Service")
return nil return nil
}, },
} }
) )
func unaryInterceptor( func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*nbconfig.Config, error) {
ctx context.Context, loadedConfig := &nbconfig.Config{}
req interface{},
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (interface{}, error) {
reqID := uuid.New().String()
//nolint
ctx = context.WithValue(ctx, hook.ExecutionContextKey, hook.GRPCSource)
//nolint
ctx = context.WithValue(ctx, nbContext.RequestIDKey, reqID)
return handler(ctx, req)
}
func streamInterceptor(
srv interface{},
ss grpc.ServerStream,
info *grpc.StreamServerInfo,
handler grpc.StreamHandler,
) error {
reqID := uuid.New().String()
wrapped := grpcMiddleware.WrapServerStream(ss)
//nolint
ctx := context.WithValue(ss.Context(), hook.ExecutionContextKey, hook.GRPCSource)
//nolint
wrapped.WrappedContext = context.WithValue(ctx, nbContext.RequestIDKey, reqID)
return handler(srv, wrapped)
}
func notifyStop(ctx context.Context, msg string) {
select {
case stopCh <- 1:
log.WithContext(ctx).Error(msg)
default:
// stop has been already called, nothing to report
}
}
func getInstallationID(ctx context.Context, store store.Store) (string, error) {
installationID := store.GetInstallationID()
if installationID != "" {
return installationID, nil
}
installationID = strings.ToUpper(uuid.New().String())
err := store.SaveInstallationID(ctx, installationID)
if err != nil {
return "", err
}
return installationID, nil
}
func serveGRPC(ctx context.Context, grpcServer *grpc.Server, port int) (net.Listener, error) {
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err != nil {
return nil, err
}
go func() {
err := grpcServer.Serve(listener)
if err != nil {
notifyStop(ctx, fmt.Sprintf("failed running gRPC server on port %d: %v", port, err))
}
}()
return listener, nil
}
func serveHTTP(ctx context.Context, httpListener net.Listener, handler http.Handler) {
go func() {
err := http.Serve(httpListener, handler)
if err != nil {
notifyStop(ctx, fmt.Sprintf("failed running HTTP server: %v", err))
}
}()
}
func serveGRPCWithHTTP(ctx context.Context, listener net.Listener, handler http.Handler, tlsEnabled bool) {
go func() {
var err error
if tlsEnabled {
err = http.Serve(listener, handler)
} else {
// the following magic is needed to support HTTP2 without TLS
// and still share a single port between gRPC and HTTP APIs
h1s := &http.Server{
Handler: h2c.NewHandler(handler, &http2.Server{}),
}
err = h1s.Serve(listener)
}
if err != nil {
select {
case stopCh <- 1:
log.WithContext(ctx).Errorf("failed to serve HTTP and gRPC server: %v", err)
default:
// stop has been already called, nothing to report
}
}
}()
}
func handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
grpcHeader := strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc") ||
strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc+proto")
if request.ProtoMajor == 2 && grpcHeader {
gRPCHandler.ServeHTTP(writer, request)
} else {
httpHandler.ServeHTTP(writer, request)
}
})
}
func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*types.Config, error) {
loadedConfig := &types.Config{}
_, err := util.ReadJsonWithEnvSub(mgmtConfigPath, loadedConfig) _, err := util.ReadJsonWithEnvSub(mgmtConfigPath, loadedConfig)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -535,7 +169,7 @@ func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*types.Config,
oidcConfig.JwksURI, loadedConfig.HttpConfig.AuthKeysLocation) oidcConfig.JwksURI, loadedConfig.HttpConfig.AuthKeysLocation)
loadedConfig.HttpConfig.AuthKeysLocation = oidcConfig.JwksURI loadedConfig.HttpConfig.AuthKeysLocation = oidcConfig.JwksURI
if !(loadedConfig.DeviceAuthorizationFlow == nil || strings.ToLower(loadedConfig.DeviceAuthorizationFlow.Provider) == string(types.NONE)) { if !(loadedConfig.DeviceAuthorizationFlow == nil || strings.ToLower(loadedConfig.DeviceAuthorizationFlow.Provider) == string(nbconfig.NONE)) {
log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s", log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s",
oidcConfig.TokenEndpoint, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint) oidcConfig.TokenEndpoint, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint)
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint = oidcConfig.TokenEndpoint loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint = oidcConfig.TokenEndpoint
@@ -552,7 +186,7 @@ func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*types.Config,
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Domain = u.Host loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Domain = u.Host
if loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Scope == "" { if loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Scope == "" {
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Scope = types.DefaultDeviceAuthFlowScope loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Scope = nbconfig.DefaultDeviceAuthFlowScope
} }
} }
@@ -573,10 +207,6 @@ func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*types.Config,
return loadedConfig, err return loadedConfig, err
} }
func updateMgmtConfig(ctx context.Context, path string, config *types.Config) error {
return util.DirectWriteJson(ctx, path, config)
}
// OIDCConfigResponse used for parsing OIDC config response // OIDCConfigResponse used for parsing OIDC config response
type OIDCConfigResponse struct { type OIDCConfigResponse struct {
Issuer string `json:"issuer"` Issuer string `json:"issuer"`
@@ -619,25 +249,6 @@ func fetchOIDCConfig(ctx context.Context, oidcEndpoint string) (OIDCConfigRespon
return config, nil return config, nil
} }
func loadTLSConfig(certFile string, certKey string) (*tls.Config, error) {
// Load server's certificate and private key
serverCert, err := tls.LoadX509KeyPair(certFile, certKey)
if err != nil {
return nil, err
}
// NewDefaultAppMetrics the credentials and return it
config := &tls.Config{
Certificates: []tls.Certificate{serverCert},
ClientAuth: tls.NoClientCert,
NextProtos: []string{
"h2", "http/1.1", // enable HTTP/2
},
}
return config, nil
}
func handleRebrand(cmd *cobra.Command) error { func handleRebrand(cmd *cobra.Command) error {
var err error var err error
if logFile == defaultLogFile { if logFile == defaultLogFile {
@@ -649,7 +260,7 @@ func handleRebrand(cmd *cobra.Command) error {
} }
} }
} }
if types.MgmtConfigPath == defaultMgmtConfig { if nbconfig.MgmtConfigPath == defaultMgmtConfig {
if migrateToNetbird(oldDefaultMgmtConfig, defaultMgmtConfig) { if migrateToNetbird(oldDefaultMgmtConfig, defaultMgmtConfig) {
cmd.Printf("will copy Config dir %s and its content to %s\n", oldDefaultMgmtConfigDir, defaultMgmtConfigDir) cmd.Printf("will copy Config dir %s and its content to %s\n", oldDefaultMgmtConfigDir, defaultMgmtConfigDir)
err = cpDir(oldDefaultMgmtConfigDir, defaultMgmtConfigDir) err = cpDir(oldDefaultMgmtConfigDir, defaultMgmtConfigDir)

View File

@@ -2,12 +2,10 @@ package cmd
import ( import (
"fmt" "fmt"
"os"
"os/signal"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/netbirdio/netbird/management/server/types" nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )
@@ -27,6 +25,12 @@ var (
disableGeoliteUpdate bool disableGeoliteUpdate bool
idpSignKeyRefreshEnabled bool idpSignKeyRefreshEnabled bool
userDeleteFromIDPEnabled bool userDeleteFromIDPEnabled bool
mgmtPort int
mgmtMetricsPort int
mgmtLetsencryptDomain string
mgmtSingleAccModeDomain string
certFile string
certKey string
rootCmd = &cobra.Command{ rootCmd = &cobra.Command{
Use: "netbird-mgmt", Use: "netbird-mgmt",
@@ -42,8 +46,6 @@ var (
Long: "", Long: "",
SilenceUsage: true, SilenceUsage: true,
} }
// Execution control channel for stopCh signal
stopCh chan int
) )
// Execute executes the root command. // Execute executes the root command.
@@ -52,11 +54,10 @@ func Execute() error {
} }
func init() { func init() {
stopCh = make(chan int)
mgmtCmd.Flags().IntVar(&mgmtPort, "port", 80, "server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise") mgmtCmd.Flags().IntVar(&mgmtPort, "port", 80, "server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise")
mgmtCmd.Flags().IntVar(&mgmtMetricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics") mgmtCmd.Flags().IntVar(&mgmtMetricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
mgmtCmd.Flags().StringVar(&mgmtDataDir, "datadir", defaultMgmtDataDir, "server data directory location") mgmtCmd.Flags().StringVar(&mgmtDataDir, "datadir", defaultMgmtDataDir, "server data directory location")
mgmtCmd.Flags().StringVar(&types.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location. Config params specified via command line (e.g. datadir) have a precedence over configuration from this file") mgmtCmd.Flags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location. Config params specified via command line (e.g. datadir) have a precedence over configuration from this file")
mgmtCmd.Flags().StringVar(&mgmtLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS") mgmtCmd.Flags().StringVar(&mgmtLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
mgmtCmd.Flags().StringVar(&mgmtSingleAccModeDomain, "single-account-mode-domain", defaultSingleAccModeDomain, "Enables single account mode. This means that all the users will be under the same account grouped by the specified domain. If the installation has more than one account, the property is ineffective. Enabled by default with the default domain "+defaultSingleAccModeDomain) mgmtCmd.Flags().StringVar(&mgmtSingleAccModeDomain, "single-account-mode-domain", defaultSingleAccModeDomain, "Enables single account mode. This means that all the users will be under the same account grouped by the specified domain. If the installation has more than one account, the property is ineffective. Enabled by default with the default domain "+defaultSingleAccModeDomain)
mgmtCmd.Flags().BoolVar(&disableSingleAccMode, "disable-single-account-mode", false, "If set to true, disables single account mode. The --single-account-mode-domain property will be ignored and every new user will have a separate NetBird account.") mgmtCmd.Flags().BoolVar(&disableSingleAccMode, "disable-single-account-mode", false, "If set to true, disables single account mode. The --single-account-mode-domain property will be ignored and every new user will have a separate NetBird account.")
@@ -80,15 +81,3 @@ func init() {
rootCmd.AddCommand(migrationCmd) rootCmd.AddCommand(migrationCmd)
} }
// SetupCloseHandler handles SIGTERM signal and exits with success
func SetupCloseHandler() {
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt)
go func() {
for range c {
fmt.Println("\r- Ctrl+C pressed in Terminal")
stopCh <- 0
}
}()
}

View File

@@ -0,0 +1,204 @@
package server
// @note this file includes all the lower level dependencies, db, http and grpc BaseServer, metrics, logger, etc.
import (
"context"
"crypto/tls"
"net/http"
"net/netip"
"slices"
"time"
"github.com/google/uuid"
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/formatter/hook"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
nbContext "github.com/netbirdio/netbird/management/server/context"
nbhttp "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
)
var (
kaep = keepalive.EnforcementPolicy{
MinTime: 15 * time.Second,
PermitWithoutStream: true,
}
kasp = keepalive.ServerParameters{
MaxConnectionIdle: 15 * time.Second,
MaxConnectionAgeGrace: 5 * time.Second,
Time: 5 * time.Second,
Timeout: 2 * time.Second,
}
)
func (s *BaseServer) Metrics() telemetry.AppMetrics {
return Create(s, func() telemetry.AppMetrics {
appMetrics, err := telemetry.NewDefaultAppMetrics(context.Background())
if err != nil {
log.Fatalf("error while creating app metrics: %s", err)
}
return appMetrics
})
}
func (s *BaseServer) Store() store.Store {
return Create(s, func() store.Store {
store, err := store.NewStore(context.Background(), s.config.StoreConfig.Engine, s.config.Datadir, s.Metrics(), false)
if err != nil {
log.Fatalf("failed to create store: %v", err)
}
return store
})
}
func (s *BaseServer) EventStore() activity.Store {
return Create(s, func() activity.Store {
integrationMetrics, err := integrations.InitIntegrationMetrics(context.Background(), s.Metrics())
if err != nil {
log.Fatalf("failed to initialize integration metrics: %v", err)
}
eventStore, key, err := integrations.InitEventStore(context.Background(), s.config.Datadir, s.config.DataStoreEncryptionKey, integrationMetrics)
if err != nil {
log.Fatalf("failed to initialize event store: %v", err)
}
if s.config.DataStoreEncryptionKey != key {
log.WithContext(context.Background()).Infof("update config with activity store key")
s.config.DataStoreEncryptionKey = key
err := updateMgmtConfig(context.Background(), nbconfig.MgmtConfigPath, s.config)
if err != nil {
log.Fatalf("failed to update config with activity store: %v", err)
}
}
return eventStore
})
}
func (s *BaseServer) APIHandler() http.Handler {
return Create(s, func() http.Handler {
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager())
if err != nil {
log.Fatalf("failed to create API handler: %v", err)
}
return httpAPIHandler
})
}
func (s *BaseServer) GRPCServer() *grpc.Server {
return Create(s, func() *grpc.Server {
trustedPeers := s.config.ReverseProxy.TrustedPeers
defaultTrustedPeers := []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0"), netip.MustParsePrefix("::/0")}
if len(trustedPeers) == 0 || slices.Equal[[]netip.Prefix](trustedPeers, defaultTrustedPeers) {
log.WithContext(context.Background()).Warn("TrustedPeers are configured to default value '0.0.0.0/0', '::/0'. This allows connection IP spoofing.")
trustedPeers = defaultTrustedPeers
}
trustedHTTPProxies := s.config.ReverseProxy.TrustedHTTPProxies
trustedProxiesCount := s.config.ReverseProxy.TrustedHTTPProxiesCount
if len(trustedHTTPProxies) > 0 && trustedProxiesCount > 0 {
log.WithContext(context.Background()).Warn("TrustedHTTPProxies and TrustedHTTPProxiesCount both are configured. " +
"This is not recommended way to extract X-Forwarded-For. Consider using one of these options.")
}
realipOpts := []realip.Option{
realip.WithTrustedPeers(trustedPeers),
realip.WithTrustedProxies(trustedHTTPProxies),
realip.WithTrustedProxiesCount(trustedProxiesCount),
realip.WithHeaders([]string{realip.XForwardedFor, realip.XRealIp}),
}
gRPCOpts := []grpc.ServerOption{
grpc.KeepaliveEnforcementPolicy(kaep),
grpc.KeepaliveParams(kasp),
grpc.ChainUnaryInterceptor(realip.UnaryServerInterceptorOpts(realipOpts...), unaryInterceptor),
grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor),
}
if s.config.HttpConfig.LetsEncryptDomain != "" {
certManager, err := encryption.CreateCertManager(s.config.Datadir, s.config.HttpConfig.LetsEncryptDomain)
if err != nil {
log.Fatalf("failed to create certificate manager: %v", err)
}
transportCredentials := credentials.NewTLS(certManager.TLSConfig())
gRPCOpts = append(gRPCOpts, grpc.Creds(transportCredentials))
} else if s.config.HttpConfig.CertFile != "" && s.config.HttpConfig.CertKey != "" {
tlsConfig, err := loadTLSConfig(s.config.HttpConfig.CertFile, s.config.HttpConfig.CertKey)
if err != nil {
log.Fatalf("cannot load TLS credentials: %v", err)
}
transportCredentials := credentials.NewTLS(tlsConfig)
gRPCOpts = append(gRPCOpts, grpc.Creds(transportCredentials))
}
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
srv, err := server.NewServer(context.Background(), s.config, s.AccountManager(), s.SettingsManager(), s.PeersUpdateManager(), s.SecretsManager(), s.Metrics(), s.EphemeralManager(), s.AuthManager(), s.IntegratedValidator())
if err != nil {
log.Fatalf("failed to create management server: %v", err)
}
mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv)
return gRPCAPIHandler
})
}
func loadTLSConfig(certFile string, certKey string) (*tls.Config, error) {
// Load server's certificate and private key
serverCert, err := tls.LoadX509KeyPair(certFile, certKey)
if err != nil {
return nil, err
}
// NewDefaultAppMetrics the credentials and return it
config := &tls.Config{
Certificates: []tls.Certificate{serverCert},
ClientAuth: tls.NoClientCert,
NextProtos: []string{
"h2", "http/1.1", // enable HTTP/2
},
}
return config, nil
}
func unaryInterceptor(
ctx context.Context,
req interface{},
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (interface{}, error) {
reqID := uuid.New().String()
//nolint
ctx = context.WithValue(ctx, hook.ExecutionContextKey, hook.GRPCSource)
//nolint
ctx = context.WithValue(ctx, nbContext.RequestIDKey, reqID)
return handler(ctx, req)
}
func streamInterceptor(
srv interface{},
ss grpc.ServerStream,
info *grpc.StreamServerInfo,
handler grpc.StreamHandler,
) error {
reqID := uuid.New().String()
wrapped := grpcMiddleware.WrapServerStream(ss)
//nolint
ctx := context.WithValue(ss.Context(), hook.ExecutionContextKey, hook.GRPCSource)
//nolint
wrapped.WrappedContext = context.WithValue(ctx, nbContext.RequestIDKey, reqID)
return handler(srv, wrapped)
}

View File

@@ -1,10 +1,11 @@
package types package config
import ( import (
"net/netip" "net/netip"
"github.com/netbirdio/netbird/shared/management/client/common"
"github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/client/common"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@@ -166,7 +167,7 @@ type ProviderConfig struct {
// StoreConfig contains Store configuration // StoreConfig contains Store configuration
type StoreConfig struct { type StoreConfig struct {
Engine Engine Engine types.Engine
} }
// ReverseProxy contains reverse proxy configuration in front of management. // ReverseProxy contains reverse proxy configuration in front of management.

View File

@@ -0,0 +1,55 @@
package server
import "fmt"
// Create a dependency and add it to the BaseServer's container. A string key identifier will be based on its type definition.
func Create[T any](s Server, createFunc func() T) T {
result, _ := maybeCreate(s, createFunc)
return result
}
// CreateNamed is the same as Create but will suffix the dependency string key identifier with a custom name.
// Useful if you want to have multiple named instances of the same object type.
func CreateNamed[T any](s Server, name string, createFunc func() T) T {
result, _ := maybeCreateNamed(s, name, createFunc)
return result
}
// Inject lets you override a specific service from outside the BaseServer itself.
// This is useful for tests
func Inject[T any](c Server, thing T) {
_, _ = maybeCreate(c, func() T {
return thing
})
}
// InjectNamed is like Inject() but with a custom name.
func InjectNamed[T any](c Server, name string, thing T) {
_, _ = maybeCreateKeyed(c, name, func() T {
return thing
})
}
func maybeCreate[T any](s Server, createFunc func() T) (result T, isNew bool) {
key := fmt.Sprintf("%T", (*T)(nil))[1:]
return maybeCreateKeyed(s, key, createFunc)
}
func maybeCreateNamed[T any](s Server, name string, createFunc func() T) (result T, isNew bool) {
key := fmt.Sprintf("%T:%s", (*T)(nil), name)[1:]
return maybeCreateKeyed(s, key, createFunc)
}
func maybeCreateKeyed[T any](s Server, key string, createFunc func() T) (result T, isNew bool) {
if t, ok := s.GetContainer(key); ok {
return t.(T), false
}
t := createFunc()
s.SetContainer(key, t)
return t, true
}

View File

@@ -0,0 +1,59 @@
package server
import (
"context"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/auth"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
)
func (s *BaseServer) PeersUpdateManager() *server.PeersUpdateManager {
return Create(s, func() *server.PeersUpdateManager {
return server.NewPeersUpdateManager(s.Metrics())
})
}
func (s *BaseServer) IntegratedValidator() integrated_validator.IntegratedValidator {
return Create(s, func() integrated_validator.IntegratedValidator {
integratedPeerValidator, err := integrations.NewIntegratedValidator(context.Background(), s.EventStore())
if err != nil {
log.Errorf("failed to create integrated peer validator: %v", err)
}
return integratedPeerValidator
})
}
func (s *BaseServer) ProxyController() port_forwarding.Controller {
return Create(s, func() port_forwarding.Controller {
return integrations.NewController(s.Store())
})
}
func (s *BaseServer) SecretsManager() *server.TimeBasedAuthSecretsManager {
return Create(s, func() *server.TimeBasedAuthSecretsManager {
return server.NewTimeBasedAuthSecretsManager(s.PeersUpdateManager(), s.config.TURNConfig, s.config.Relay, s.SettingsManager(), s.GroupsManager())
})
}
func (s *BaseServer) AuthManager() auth.Manager {
return Create(s, func() auth.Manager {
return auth.NewManager(s.Store(),
s.config.HttpConfig.AuthIssuer,
s.config.HttpConfig.AuthAudience,
s.config.HttpConfig.AuthKeysLocation,
s.config.HttpConfig.AuthUserIDClaim,
s.config.GetAuthAudiences(),
s.config.HttpConfig.IdpSignKeyRefreshEnabled)
})
}
func (s *BaseServer) EphemeralManager() *server.EphemeralManager {
return Create(s, func() *server.EphemeralManager {
return server.NewEphemeralManager(s.Store(), s.AccountManager())
})
}

View File

@@ -0,0 +1,108 @@
package server
import (
"context"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/users"
)
func (s *BaseServer) GeoLocationManager() geolocation.Geolocation {
return Create(s, func() geolocation.Geolocation {
geo, err := geolocation.NewGeolocation(context.Background(), s.config.Datadir, !s.disableGeoliteUpdate)
if err != nil {
log.Warnf("could not initialize geolocation service. proceeding without geolocation support: %v", err)
} else {
log.Infof("geolocation service has been initialized from %s", s.config.Datadir)
}
return geo
})
}
func (s *BaseServer) PermissionsManager() permissions.Manager {
return Create(s, func() permissions.Manager {
return integrations.InitPermissionsManager(s.Store())
})
}
func (s *BaseServer) UsersManager() users.Manager {
return Create(s, func() users.Manager {
return users.NewManager(s.Store())
})
}
func (s *BaseServer) SettingsManager() settings.Manager {
return Create(s, func() settings.Manager {
extraSettingsManager := integrations.NewManager(s.EventStore())
return settings.NewManager(s.Store(), s.UsersManager(), extraSettingsManager, s.PermissionsManager())
})
}
func (s *BaseServer) PeersManager() peers.Manager {
return Create(s, func() peers.Manager {
return peers.NewManager(s.Store(), s.PermissionsManager())
})
}
func (s *BaseServer) AccountManager() account.Manager {
return Create(s, func() account.Manager {
accountManager, err := server.BuildManager(context.Background(), s.Store(), s.PeersUpdateManager(), s.IdpManager(), s.mgmtSingleAccModeDomain,
s.dnsDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.config.DisableDefaultPolicy)
if err != nil {
log.Fatalf("failed to create account manager: %v", err)
}
return accountManager
})
}
func (s *BaseServer) IdpManager() idp.Manager {
return Create(s, func() idp.Manager {
var idpManager idp.Manager
var err error
if s.config.IdpManagerConfig != nil {
idpManager, err = idp.NewManager(context.Background(), *s.config.IdpManagerConfig, s.Metrics())
if err != nil {
log.Fatalf("failed to create IDP manager: %v", err)
}
}
return idpManager
})
}
func (s *BaseServer) GroupsManager() groups.Manager {
return Create(s, func() groups.Manager {
return groups.NewManager(s.Store(), s.PermissionsManager(), s.AccountManager())
})
}
func (s *BaseServer) ResourcesManager() resources.Manager {
return Create(s, func() resources.Manager {
return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager())
})
}
func (s *BaseServer) RoutesManager() routers.Manager {
return Create(s, func() routers.Manager {
return routers.NewManager(s.Store(), s.PermissionsManager(), s.AccountManager())
})
}
func (s *BaseServer) NetworksManager() networks.Manager {
return Create(s, func() networks.Manager {
return networks.NewManager(s.Store(), s.PermissionsManager(), s.ResourcesManager(), s.RoutesManager(), s.AccountManager())
})
}

View File

@@ -0,0 +1,340 @@
package server
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"strings"
"sync"
"time"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/acme/autocert"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/encryption"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/metrics"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version"
)
// ManagementLegacyPort is the port that was used before by the Management gRPC server.
// It is used for backward compatibility now.
const ManagementLegacyPort = 33073
type Server interface {
Start(ctx context.Context) error
Stop() error
Errors() <-chan error
GetContainer(key string) (any, bool)
SetContainer(key string, container any)
}
// Server holds the HTTP BaseServer instance.
// Add any additional fields you need, such as database connections, config, etc.
type BaseServer struct {
// config holds the server configuration
config *nbconfig.Config
// container of dependencies, each dependency is identified by a unique string.
container map[string]any
// AfterInit is a function that will be called after the server is initialized
afterInit []func(s *BaseServer)
disableMetrics bool
dnsDomain string
disableGeoliteUpdate bool
userDeleteFromIDPEnabled bool
mgmtSingleAccModeDomain string
mgmtMetricsPort int
mgmtPort int
listener net.Listener
certManager *autocert.Manager
update *version.Update
errCh chan error
wg sync.WaitGroup
cancel context.CancelFunc
}
// NewServer initializes and configures a new Server instance
func NewServer(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain string, mgmtPort, mgmtMetricsPort int, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled bool) *BaseServer {
return &BaseServer{
config: config,
container: make(map[string]any),
dnsDomain: dnsDomain,
mgmtSingleAccModeDomain: mgmtSingleAccModeDomain,
disableMetrics: disableMetrics,
disableGeoliteUpdate: disableGeoliteUpdate,
userDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
mgmtPort: mgmtPort,
mgmtMetricsPort: mgmtMetricsPort,
}
}
func (s *BaseServer) AfterInit(fn func(s *BaseServer)) {
s.afterInit = append(s.afterInit, fn)
}
// Start begins listening for HTTP requests on the configured address
func (s *BaseServer) Start(ctx context.Context) error {
srvCtx, cancel := context.WithCancel(ctx)
s.cancel = cancel
s.errCh = make(chan error, 4)
s.PeersManager()
for _, fn := range s.afterInit {
if fn != nil {
fn(s)
}
}
err := s.Metrics().Expose(srvCtx, s.mgmtMetricsPort, "/metrics")
if err != nil {
return fmt.Errorf("failed to expose metrics: %v", err)
}
s.EphemeralManager().LoadInitialPeers(srvCtx)
var tlsConfig *tls.Config
tlsEnabled := false
if s.config.HttpConfig.LetsEncryptDomain != "" {
s.certManager, err = encryption.CreateCertManager(s.config.Datadir, s.config.HttpConfig.LetsEncryptDomain)
if err != nil {
return fmt.Errorf("failed creating LetsEncrypt cert manager: %v", err)
}
tlsEnabled = true
} else if s.config.HttpConfig.CertFile != "" && s.config.HttpConfig.CertKey != "" {
tlsConfig, err = loadTLSConfig(s.config.HttpConfig.CertFile, s.config.HttpConfig.CertKey)
if err != nil {
log.WithContext(srvCtx).Errorf("cannot load TLS credentials: %v", err)
return err
}
tlsEnabled = true
}
installationID, err := getInstallationID(srvCtx, s.Store())
if err != nil {
log.WithContext(srvCtx).Errorf("cannot load TLS credentials: %v", err)
return err
}
if !s.disableMetrics {
idpManager := "disabled"
if s.config.IdpManagerConfig != nil && s.config.IdpManagerConfig.ManagerType != "" {
idpManager = s.config.IdpManagerConfig.ManagerType
}
metricsWorker := metrics.NewWorker(srvCtx, installationID, s.Store(), s.PeersUpdateManager(), idpManager)
go metricsWorker.Run(srvCtx)
}
var compatListener net.Listener
if s.mgmtPort != ManagementLegacyPort {
// The Management gRPC server was running on port 33073 previously. Old agents that are already connected to it
// are using port 33073. For compatibility purposes we keep running a 2nd gRPC server on port 33073.
compatListener, err = s.serveGRPC(srvCtx, s.GRPCServer(), ManagementLegacyPort)
if err != nil {
return err
}
log.WithContext(srvCtx).Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String())
}
rootHandler := handlerFunc(s.GRPCServer(), s.APIHandler())
switch {
case s.certManager != nil:
// a call to certManager.Listener() always creates a new listener so we do it once
cml := s.certManager.Listener()
if s.mgmtPort == 443 {
// CertManager, HTTP and gRPC API all on the same port
rootHandler = s.certManager.HTTPHandler(rootHandler)
s.listener = cml
} else {
s.listener, err = tls.Listen("tcp", fmt.Sprintf(":%d", s.mgmtPort), s.certManager.TLSConfig())
if err != nil {
return fmt.Errorf("failed creating TLS listener on port %d: %v", s.mgmtPort, err)
}
log.WithContext(ctx).Infof("running HTTP server (LetsEncrypt challenge handler): %s", cml.Addr().String())
s.serveHTTP(ctx, cml, s.certManager.HTTPHandler(nil))
}
case tlsConfig != nil:
s.listener, err = tls.Listen("tcp", fmt.Sprintf(":%d", s.mgmtPort), tlsConfig)
if err != nil {
return fmt.Errorf("failed creating TLS listener on port %d: %v", s.mgmtPort, err)
}
default:
s.listener, err = net.Listen("tcp", fmt.Sprintf(":%d", s.mgmtPort))
if err != nil {
return fmt.Errorf("failed creating TCP listener on port %d: %v", s.mgmtPort, err)
}
}
log.WithContext(ctx).Infof("management server version %s", version.NetbirdVersion())
log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", s.listener.Addr().String())
s.serveGRPCWithHTTP(ctx, s.listener, rootHandler, tlsEnabled)
s.update = version.NewUpdate("nb/management")
s.update.SetDaemonVersion(version.NetbirdVersion())
s.update.SetOnUpdateListener(func() {
log.WithContext(ctx).Infof("your management version, \"%s\", is outdated, a new management version is available. Learn more here: https://github.com/netbirdio/netbird/releases", version.NetbirdVersion())
})
return nil
}
// Stop attempts a graceful shutdown, waiting up to 5 seconds for active connections to finish
func (s *BaseServer) Stop() error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s.IntegratedValidator().Stop(ctx)
if s.GeoLocationManager() != nil {
_ = s.GeoLocationManager().Stop()
}
s.EphemeralManager().Stop()
_ = s.Metrics().Close()
if s.listener != nil {
_ = s.listener.Close()
}
if s.certManager != nil {
_ = s.certManager.Listener().Close()
}
s.GRPCServer().Stop()
_ = s.Store().Close(ctx)
_ = s.EventStore().Close(ctx)
if s.update != nil {
s.update.StopWatch()
}
select {
case <-s.Errors():
log.WithContext(ctx).Infof("stopped Management Service")
return nil
case <-ctx.Done():
return ctx.Err()
}
}
// Done returns a channel that is closed when the server stops
func (s *BaseServer) Errors() <-chan error {
return s.errCh
}
// GetContainer retrieves a dependency from the BaseServer's container by its key
func (s *BaseServer) GetContainer(key string) (any, bool) {
container, exists := s.container[key]
return container, exists
}
// SetContainer stores a dependency in the BaseServer's container with the specified key
func (s *BaseServer) SetContainer(key string, container any) {
if _, exists := s.container[key]; exists {
log.Tracef("container with key %s already exists", key)
return
}
s.container[key] = container
log.Tracef("container with key %s set successfully", key)
}
func updateMgmtConfig(ctx context.Context, path string, config *nbconfig.Config) error {
return util.DirectWriteJson(ctx, path, config)
}
func handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
grpcHeader := strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc") ||
strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc+proto")
if request.ProtoMajor == 2 && grpcHeader {
gRPCHandler.ServeHTTP(writer, request)
} else {
httpHandler.ServeHTTP(writer, request)
}
})
}
func (s *BaseServer) serveGRPC(ctx context.Context, grpcServer *grpc.Server, port int) (net.Listener, error) {
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err != nil {
return nil, err
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
err := grpcServer.Serve(listener)
if ctx.Err() != nil {
return
}
select {
case s.errCh <- err:
default:
}
}()
return listener, nil
}
func (s *BaseServer) serveHTTP(ctx context.Context, httpListener net.Listener, handler http.Handler) {
s.wg.Add(1)
go func() {
defer s.wg.Done()
err := http.Serve(httpListener, handler)
if ctx.Err() != nil {
return
}
select {
case s.errCh <- err:
default:
}
}()
}
func (s *BaseServer) serveGRPCWithHTTP(ctx context.Context, listener net.Listener, handler http.Handler, tlsEnabled bool) {
s.wg.Add(1)
go func() {
defer s.wg.Done()
var err error
if tlsEnabled {
err = http.Serve(listener, handler)
} else {
// the following magic is needed to support HTTP2 without TLS
// and still share a single port between gRPC and HTTP APIs
h1s := &http.Server{
Handler: h2c.NewHandler(handler, &http2.Server{}),
}
err = h1s.Serve(listener)
}
if ctx.Err() != nil {
return
}
select {
case s.errCh <- err:
default:
}
}()
}
func getInstallationID(ctx context.Context, store store.Store) (string, error) {
installationID := store.GetInstallationID()
if installationID != "" {
return installationID, nil
}
installationID = strings.ToUpper(uuid.New().String())
err := store.SaveInstallationID(ctx, installationID)
if err != nil {
return "", err
}
return installationID, nil
}

View File

@@ -19,6 +19,7 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator" "github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
@@ -44,7 +45,7 @@ type GRPCServer struct {
wgKey wgtypes.Key wgKey wgtypes.Key
proto.UnimplementedManagementServiceServer proto.UnimplementedManagementServiceServer
peersUpdateManager *PeersUpdateManager peersUpdateManager *PeersUpdateManager
config *types.Config config *nbconfig.Config
secretsManager SecretsManager secretsManager SecretsManager
appMetrics telemetry.AppMetrics appMetrics telemetry.AppMetrics
ephemeralManager *EphemeralManager ephemeralManager *EphemeralManager
@@ -56,7 +57,7 @@ type GRPCServer struct {
// NewServer creates a new Management server // NewServer creates a new Management server
func NewServer( func NewServer(
ctx context.Context, ctx context.Context,
config *types.Config, config *nbconfig.Config,
accountManager account.Manager, accountManager account.Manager,
settingsManager settings.Manager, settingsManager settings.Manager,
peersUpdateManager *PeersUpdateManager, peersUpdateManager *PeersUpdateManager,
@@ -567,24 +568,24 @@ func (s *GRPCServer) processJwtToken(ctx context.Context, loginReq *proto.LoginR
return userID, nil return userID, nil
} }
func ToResponseProto(configProto types.Protocol) proto.HostConfig_Protocol { func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol {
switch configProto { switch configProto {
case types.UDP: case nbconfig.UDP:
return proto.HostConfig_UDP return proto.HostConfig_UDP
case types.DTLS: case nbconfig.DTLS:
return proto.HostConfig_DTLS return proto.HostConfig_DTLS
case types.HTTP: case nbconfig.HTTP:
return proto.HostConfig_HTTP return proto.HostConfig_HTTP
case types.HTTPS: case nbconfig.HTTPS:
return proto.HostConfig_HTTPS return proto.HostConfig_HTTPS
case types.TCP: case nbconfig.TCP:
return proto.HostConfig_TCP return proto.HostConfig_TCP
default: default:
panic(fmt.Errorf("unexpected config protocol type %v", configProto)) panic(fmt.Errorf("unexpected config protocol type %v", configProto))
} }
} }
func toNetbirdConfig(config *types.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig { func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
if config == nil { if config == nil {
return nil return nil
} }
@@ -662,7 +663,7 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set
} }
} }
func toSyncResponse(ctx context.Context, config *types.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string) *proto.SyncResponse { func toSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string) *proto.SyncResponse {
response := &proto.SyncResponse{ response := &proto.SyncResponse{
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings), PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings),
NetworkMap: &proto.NetworkMap{ NetworkMap: &proto.NetworkMap{
@@ -799,7 +800,7 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.
return nil, status.Error(codes.InvalidArgument, errMSG) return nil, status.Error(codes.InvalidArgument, errMSG)
} }
if s.config.DeviceAuthorizationFlow == nil || s.config.DeviceAuthorizationFlow.Provider == string(types.NONE) { if s.config.DeviceAuthorizationFlow == nil || s.config.DeviceAuthorizationFlow.Provider == string(nbconfig.NONE) {
return nil, status.Error(codes.NotFound, "no device authorization flow information available") return nil, status.Error(codes.NotFound, "no device authorization flow information available")
} }

View File

@@ -22,6 +22,7 @@ import (
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/formatter/hook" "github.com/netbirdio/netbird/formatter/hook"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
@@ -95,21 +96,21 @@ func getServerKey(client mgmtProto.ManagementServiceClient) (*wgtypes.Key, error
func Test_SyncProtocol(t *testing.T) { func Test_SyncProtocol(t *testing.T) {
dir := t.TempDir() dir := t.TempDir()
mgmtServer, _, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &types.Config{ mgmtServer, _, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &config.Config{
Stuns: []*types.Host{{ Stuns: []*config.Host{{
Proto: "udp", Proto: "udp",
URI: "stun:stun.netbird.io:3468", URI: "stun:stun.netbird.io:3468",
}}, }},
TURNConfig: &types.TURNConfig{ TURNConfig: &config.TURNConfig{
TimeBasedCredentials: false, TimeBasedCredentials: false,
CredentialsTTL: util.Duration{}, CredentialsTTL: util.Duration{},
Secret: "whatever", Secret: "whatever",
Turns: []*types.Host{{ Turns: []*config.Host{{
Proto: "udp", Proto: "udp",
URI: "turn:stun.netbird.io:3468", URI: "turn:stun.netbird.io:3468",
}}, }},
}, },
Signal: &types.Host{ Signal: &config.Host{
Proto: "http", Proto: "http",
URI: "signal.netbird.io:10000", URI: "signal.netbird.io:10000",
}, },
@@ -332,7 +333,7 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
inputFlow *types.DeviceAuthorizationFlow inputFlow *config.DeviceAuthorizationFlow
expectedFlow *mgmtProto.DeviceAuthorizationFlow expectedFlow *mgmtProto.DeviceAuthorizationFlow
expectedErrFunc require.ErrorAssertionFunc expectedErrFunc require.ErrorAssertionFunc
expectedErrMSG string expectedErrMSG string
@@ -347,9 +348,9 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
}, },
{ {
name: "Testing Invalid Device Flow Provider Config", name: "Testing Invalid Device Flow Provider Config",
inputFlow: &types.DeviceAuthorizationFlow{ inputFlow: &config.DeviceAuthorizationFlow{
Provider: "NoNe", Provider: "NoNe",
ProviderConfig: types.ProviderConfig{ ProviderConfig: config.ProviderConfig{
ClientID: "test", ClientID: "test",
}, },
}, },
@@ -358,9 +359,9 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
}, },
{ {
name: "Testing Full Device Flow Config", name: "Testing Full Device Flow Config",
inputFlow: &types.DeviceAuthorizationFlow{ inputFlow: &config.DeviceAuthorizationFlow{
Provider: "hosted", Provider: "hosted",
ProviderConfig: types.ProviderConfig{ ProviderConfig: config.ProviderConfig{
ClientID: "test", ClientID: "test",
}, },
}, },
@@ -381,7 +382,7 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
mgmtServer := &GRPCServer{ mgmtServer := &GRPCServer{
wgKey: testingServerKey, wgKey: testingServerKey,
config: &types.Config{ config: &config.Config{
DeviceAuthorizationFlow: testCase.inputFlow, DeviceAuthorizationFlow: testCase.inputFlow,
}, },
} }
@@ -412,7 +413,7 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
} }
} }
func startManagementForTest(t *testing.T, testFile string, config *types.Config) (*grpc.Server, *DefaultAccountManager, string, func(), error) { func startManagementForTest(t *testing.T, testFile string, config *config.Config) (*grpc.Server, *DefaultAccountManager, string, func(), error) {
t.Helper() t.Helper()
lis, err := net.Listen("tcp", "localhost:0") lis, err := net.Listen("tcp", "localhost:0")
if err != nil { if err != nil {
@@ -515,21 +516,21 @@ func testSyncStatusRace(t *testing.T) {
t.Skip() t.Skip()
dir := t.TempDir() dir := t.TempDir()
mgmtServer, am, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &types.Config{ mgmtServer, am, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &config.Config{
Stuns: []*types.Host{{ Stuns: []*config.Host{{
Proto: "udp", Proto: "udp",
URI: "stun:stun.netbird.io:3468", URI: "stun:stun.netbird.io:3468",
}}, }},
TURNConfig: &types.TURNConfig{ TURNConfig: &config.TURNConfig{
TimeBasedCredentials: false, TimeBasedCredentials: false,
CredentialsTTL: util.Duration{}, CredentialsTTL: util.Duration{},
Secret: "whatever", Secret: "whatever",
Turns: []*types.Host{{ Turns: []*config.Host{{
Proto: "udp", Proto: "udp",
URI: "turn:stun.netbird.io:3468", URI: "turn:stun.netbird.io:3468",
}}, }},
}, },
Signal: &types.Host{ Signal: &config.Host{
Proto: "http", Proto: "http",
URI: "signal.netbird.io:10000", URI: "signal.netbird.io:10000",
}, },
@@ -687,21 +688,21 @@ func Test_LoginPerformance(t *testing.T) {
t.Helper() t.Helper()
dir := t.TempDir() dir := t.TempDir()
mgmtServer, am, _, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &types.Config{ mgmtServer, am, _, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &config.Config{
Stuns: []*types.Host{{ Stuns: []*config.Host{{
Proto: "udp", Proto: "udp",
URI: "stun:stun.netbird.io:3468", URI: "stun:stun.netbird.io:3468",
}}, }},
TURNConfig: &types.TURNConfig{ TURNConfig: &config.TURNConfig{
TimeBasedCredentials: false, TimeBasedCredentials: false,
CredentialsTTL: util.Duration{}, CredentialsTTL: util.Duration{},
Secret: "whatever", Secret: "whatever",
Turns: []*types.Host{{ Turns: []*config.Host{{
Proto: "udp", Proto: "udp",
URI: "turn:stun.netbird.io:3468", URI: "turn:stun.netbird.io:3468",
}}, }},
}, },
Signal: &types.Host{ Signal: &config.Host{
Proto: "http", Proto: "http",
URI: "signal.netbird.io:10000", URI: "signal.netbird.io:10000",
}, },

View File

@@ -20,7 +20,7 @@ import (
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
@@ -30,6 +30,7 @@ import (
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@@ -60,7 +61,7 @@ func setupTest(t *testing.T) *testSuite {
t.Fatalf("failed to create temp directory: %v", err) t.Fatalf("failed to create temp directory: %v", err)
} }
config := &types.Config{} config := &config.Config{}
_, err = util.ReadJson("testdata/management.json", config) _, err = util.ReadJson("testdata/management.json", config)
if err != nil { if err != nil {
t.Fatalf("failed to read management.json: %v", err) t.Fatalf("failed to read management.json: %v", err)
@@ -158,7 +159,7 @@ func createRawClient(t *testing.T, addr string) (mgmtProto.ManagementServiceClie
func startServer( func startServer(
t *testing.T, t *testing.T,
config *types.Config, config *config.Config,
dataDir string, dataDir string,
testFile string, testFile string,
) (*grpc.Server, net.Listener) { ) (*grpc.Server, net.Listener) {

View File

@@ -25,6 +25,7 @@ import (
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
@@ -1063,16 +1064,16 @@ func TestToSyncResponse(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
config := &types.Config{ config := &config.Config{
Signal: &types.Host{ Signal: &config.Host{
Proto: "https", Proto: "https",
URI: "signal.uri", URI: "signal.uri",
Username: "", Username: "",
Password: "", Password: "",
}, },
Stuns: []*types.Host{{URI: "stun.uri", Proto: types.UDP}}, Stuns: []*config.Host{{URI: "stun.uri", Proto: config.UDP}},
TURNConfig: &types.TURNConfig{ TURNConfig: &config.TURNConfig{
Turns: []*types.Host{{URI: "turn.uri", Proto: types.UDP, Username: "turn-user", Password: "turn-pass"}}, Turns: []*config.Host{{URI: "turn.uri", Proto: config.UDP, Username: "turn-user", Password: "turn-pass"}},
}, },
} }
peer := &nbpeer.Peer{ peer := &nbpeer.Peer{

View File

@@ -12,9 +12,9 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/shared/management/proto"
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac" auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
authv2 "github.com/netbirdio/netbird/shared/relay/auth/hmac/v2" authv2 "github.com/netbirdio/netbird/shared/relay/auth/hmac/v2"
@@ -33,8 +33,8 @@ type SecretsManager interface {
// TimeBasedAuthSecretsManager generates credentials with TTL and using pre-shared secret known to TURN server // TimeBasedAuthSecretsManager generates credentials with TTL and using pre-shared secret known to TURN server
type TimeBasedAuthSecretsManager struct { type TimeBasedAuthSecretsManager struct {
mux sync.Mutex mux sync.Mutex
turnCfg *types.TURNConfig turnCfg *nbconfig.TURNConfig
relayCfg *types.Relay relayCfg *nbconfig.Relay
turnHmacToken *auth.TimedHMAC turnHmacToken *auth.TimedHMAC
relayHmacToken *authv2.Generator relayHmacToken *authv2.Generator
updateManager *PeersUpdateManager updateManager *PeersUpdateManager
@@ -46,7 +46,7 @@ type TimeBasedAuthSecretsManager struct {
type Token auth.Token type Token auth.Token
func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *types.TURNConfig, relayCfg *types.Relay, settingsManager settings.Manager, groupsManager groups.Manager) *TimeBasedAuthSecretsManager { func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *nbconfig.TURNConfig, relayCfg *nbconfig.Relay, settingsManager settings.Manager, groupsManager groups.Manager) *TimeBasedAuthSecretsManager {
mgr := &TimeBasedAuthSecretsManager{ mgr := &TimeBasedAuthSecretsManager{
updateManager: updateManager, updateManager: updateManager,
turnCfg: turnCfg, turnCfg: turnCfg,

View File

@@ -13,6 +13,7 @@ import (
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
@@ -20,8 +21,8 @@ import (
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
var TurnTestHost = &types.Host{ var TurnTestHost = &config.Host{
Proto: types.UDP, Proto: config.UDP,
URI: "turn:turn.netbird.io:77777", URI: "turn:turn.netbird.io:77777",
Username: "username", Username: "username",
Password: "", Password: "",
@@ -32,7 +33,7 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
secret := "some_secret" secret := "some_secret"
peersManager := NewPeersUpdateManager(nil) peersManager := NewPeersUpdateManager(nil)
rc := &types.Relay{ rc := &config.Relay{
Addresses: []string{"localhost:0"}, Addresses: []string{"localhost:0"},
CredentialsTTL: ttl, CredentialsTTL: ttl,
Secret: secret, Secret: secret,
@@ -43,10 +44,10 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
settingsMockManager := settings.NewMockManager(ctrl) settingsMockManager := settings.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{ tested := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
CredentialsTTL: ttl, CredentialsTTL: ttl,
Secret: secret, Secret: secret,
Turns: []*types.Host{TurnTestHost}, Turns: []*config.Host{TurnTestHost},
TimeBasedCredentials: true, TimeBasedCredentials: true,
}, rc, settingsMockManager, groupsManager) }, rc, settingsMockManager, groupsManager)
@@ -83,7 +84,7 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
peer := "some_peer" peer := "some_peer"
updateChannel := peersManager.CreateChannel(context.Background(), peer) updateChannel := peersManager.CreateChannel(context.Background(), peer)
rc := &types.Relay{ rc := &config.Relay{
Addresses: []string{"localhost:0"}, Addresses: []string{"localhost:0"},
CredentialsTTL: ttl, CredentialsTTL: ttl,
Secret: secret, Secret: secret,
@@ -95,10 +96,10 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), "someAccountID").Return(&types.ExtraSettings{}, nil).AnyTimes() settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), "someAccountID").Return(&types.ExtraSettings{}, nil).AnyTimes()
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{ tested := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
CredentialsTTL: ttl, CredentialsTTL: ttl,
Secret: secret, Secret: secret,
Turns: []*types.Host{TurnTestHost}, Turns: []*config.Host{TurnTestHost},
TimeBasedCredentials: true, TimeBasedCredentials: true,
}, rc, settingsMockManager, groupsManager) }, rc, settingsMockManager, groupsManager)
@@ -187,7 +188,7 @@ func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
peersManager := NewPeersUpdateManager(nil) peersManager := NewPeersUpdateManager(nil)
peer := "some_peer" peer := "some_peer"
rc := &types.Relay{ rc := &config.Relay{
Addresses: []string{"localhost:0"}, Addresses: []string{"localhost:0"},
CredentialsTTL: ttl, CredentialsTTL: ttl,
Secret: secret, Secret: secret,
@@ -198,10 +199,10 @@ func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
settingsMockManager := settings.NewMockManager(ctrl) settingsMockManager := settings.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{ tested := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
CredentialsTTL: ttl, CredentialsTTL: ttl,
Secret: secret, Secret: secret,
Turns: []*types.Host{TurnTestHost}, Turns: []*config.Host{TurnTestHost},
TimeBasedCredentials: true, TimeBasedCredentials: true,
}, rc, settingsMockManager, groupsManager) }, rc, settingsMockManager, groupsManager)

View File

@@ -12,6 +12,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
@@ -27,9 +28,9 @@ import (
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
mgmt "github.com/netbirdio/netbird/management/server" mgmt "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc" "google.golang.org/grpc"
@@ -52,7 +53,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
level, _ := log.ParseLevel("debug") level, _ := log.ParseLevel("debug")
log.SetLevel(level) log.SetLevel(level)
config := &types.Config{} config := &config.Config{}
_, err := util.ReadJson("../../../management/server/testdata/management.json", config) _, err := util.ReadJson("../../../management/server/testdata/management.json", config)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)