From a0436a201dbea322d4bd0d24c694da4d957674c7 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 13 Jun 2025 17:20:28 +0200 Subject: [PATCH] Add flag for service env vars --- client/cmd/root.go | 9 +- client/cmd/service.go | 69 +++++++-- client/cmd/service_controller.go | 128 ++++++---------- client/cmd/service_installer.go | 244 ++++++++++++++++++++++--------- client/cmd/service_test.go | 221 ++++++++++++++++++++++++++++ 5 files changed, 504 insertions(+), 167 deletions(-) create mode 100644 client/cmd/service_test.go diff --git a/client/cmd/root.go b/client/cmd/root.go index 16e445f4d..bf6bf683e 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -71,7 +71,6 @@ var ( interfaceName string wireguardPort uint16 networkMonitor bool - serviceName string autoConnectDisabled bool extraIFaceBlackList []string anonymizeFlag bool @@ -153,9 +152,6 @@ func init() { rootCmd.AddCommand(forwardingRulesCmd) rootCmd.AddCommand(debugCmd) - serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service - serviceCmd.AddCommand(installCmd, uninstallCmd) // service installer commands are subcommands of service - networksCMD.AddCommand(routesListCmd) networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd) @@ -196,14 +192,13 @@ func SetupCloseHandler(ctx context.Context, cancel context.CancelFunc) { termCh := make(chan os.Signal, 1) signal.Notify(termCh, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) go func() { - done := ctx.Done() + defer cancel() select { - case <-done: + case <-ctx.Done(): case <-termCh: } log.Info("shutdown signal received") - cancel() }() } diff --git a/client/cmd/service.go b/client/cmd/service.go index 156e67d6d..6b1f3766a 100644 --- a/client/cmd/service.go +++ b/client/cmd/service.go @@ -2,11 +2,12 @@ package cmd import ( "context" + "fmt" "runtime" + "strings" "sync" "github.com/kardianos/service" - log "github.com/sirupsen/logrus" "github.com/spf13/cobra" "google.golang.org/grpc" @@ -14,6 +15,16 @@ import ( "github.com/netbirdio/netbird/client/server" ) +var serviceCmd = &cobra.Command{ + Use: "service", + Short: "manages Netbird service", +} + +var ( + serviceName string + serviceEnvVars []string +) + type program struct { ctx context.Context cancel context.CancelFunc @@ -22,12 +33,22 @@ type program struct { serverInstanceMu sync.Mutex } +func init() { + serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, installCmd, uninstallCmd, reconfigureCmd) + + serviceCmd.PersistentFlags().StringSliceVar(&serviceEnvVars, "service-env", nil, + `Sets extra environment variables for the service. `+ + `You can specify a comma-separated list of KEY=VALUE pairs. `+ + `E.g. --service-env LOG_LEVEL=debug,CUSTOM_VAR=value`, + ) +} + func newProgram(ctx context.Context, cancel context.CancelFunc) *program { ctx = internal.CtxInitState(ctx) return &program{ctx: ctx, cancel: cancel} } -func newSVCConfig() *service.Config { +func newSVCConfig() (*service.Config, error) { config := &service.Config{ Name: serviceName, DisplayName: "Netbird", @@ -36,23 +57,47 @@ func newSVCConfig() *service.Config { EnvVars: make(map[string]string), } + if len(serviceEnvVars) > 0 { + extraEnvs, err := parseServiceEnvVars(serviceEnvVars) + if err != nil { + return nil, fmt.Errorf("parse service environment variables: %w", err) + } + config.EnvVars = extraEnvs + } + if runtime.GOOS == "linux" { config.EnvVars["SYSTEMD_UNIT"] = serviceName } - return config + return config, nil } func newSVC(prg *program, conf *service.Config) (service.Service, error) { - s, err := service.New(prg, conf) - if err != nil { - log.Fatal(err) - return nil, err - } - return s, nil + return service.New(prg, conf) } -var serviceCmd = &cobra.Command{ - Use: "service", - Short: "manages Netbird service", +func parseServiceEnvVars(envVars []string) (map[string]string, error) { + envMap := make(map[string]string) + + for _, env := range envVars { + if env == "" { + continue + } + + parts := strings.SplitN(env, "=", 2) + if len(parts) != 2 { + return nil, fmt.Errorf("invalid environment variable format: %s (expected KEY=VALUE)", env) + } + + key := strings.TrimSpace(parts[0]) + value := strings.TrimSpace(parts[1]) + + if key == "" { + return nil, fmt.Errorf("empty environment variable key in: %s", env) + } + + envMap[key] = value + } + + return envMap, nil } diff --git a/client/cmd/service_controller.go b/client/cmd/service_controller.go index 5e3c63e57..5ba49096e 100644 --- a/client/cmd/service_controller.go +++ b/client/cmd/service_controller.go @@ -47,14 +47,13 @@ func (p *program) Start(svc service.Service) error { listen, err := net.Listen(split[0], split[1]) if err != nil { - return fmt.Errorf("failed to listen daemon interface: %w", err) + return fmt.Errorf("listen daemon interface: %w", err) } go func() { defer listen.Close() if split[0] == "unix" { - err = os.Chmod(split[1], 0666) - if err != nil { + if err := os.Chmod(split[1], 0666); err != nil { log.Errorf("failed setting daemon permissions: %v", split[1]) return } @@ -100,37 +99,49 @@ func (p *program) Stop(srv service.Service) error { return nil } +// Common setup for service control commands +func setupServiceControlCommand(cmd *cobra.Command, ctx context.Context, cancel context.CancelFunc) (service.Service, error) { + SetFlagsFromEnvVars(rootCmd) + SetFlagsFromEnvVars(serviceCmd) + + cmd.SetOut(cmd.OutOrStdout()) + + if err := handleRebrand(cmd); err != nil { + return nil, err + } + + if err := util.InitLog(logLevel, logFile); err != nil { + return nil, fmt.Errorf("init log: %w", err) + } + + cfg, err := newSVCConfig() + if err != nil { + return nil, fmt.Errorf("create service config: %w", err) + } + + s, err := newSVC(newProgram(ctx, cancel), cfg) + if err != nil { + return nil, err + } + + return s, nil +} + var runCmd = &cobra.Command{ Use: "run", Short: "runs Netbird as service", RunE: func(cmd *cobra.Command, args []string) error { - SetFlagsFromEnvVars(rootCmd) - - cmd.SetOut(cmd.OutOrStdout()) - - err := handleRebrand(cmd) - if err != nil { - return err - } - - err = util.InitLog(logLevel, logFile) - if err != nil { - return fmt.Errorf("failed initializing log %v", err) - } - ctx, cancel := context.WithCancel(cmd.Context()) + SetupCloseHandler(ctx, cancel) SetupDebugHandler(ctx, nil, nil, nil, logFile) - s, err := newSVC(newProgram(ctx, cancel), newSVCConfig()) + s, err := setupServiceControlCommand(cmd, ctx, cancel) if err != nil { return err } - err = s.Run() - if err != nil { - return err - } - return nil + + return s.Run() }, } @@ -138,31 +149,14 @@ var startCmd = &cobra.Command{ Use: "start", Short: "starts Netbird service", RunE: func(cmd *cobra.Command, args []string) error { - SetFlagsFromEnvVars(rootCmd) - - cmd.SetOut(cmd.OutOrStdout()) - - err := handleRebrand(cmd) - if err != nil { - return err - } - - err = util.InitLog(logLevel, logFile) - if err != nil { - return err - } - ctx, cancel := context.WithCancel(cmd.Context()) - - s, err := newSVC(newProgram(ctx, cancel), newSVCConfig()) + s, err := setupServiceControlCommand(cmd, ctx, cancel) if err != nil { - cmd.PrintErrln(err) return err } - err = s.Start() - if err != nil { - cmd.PrintErrln(err) - return err + + if err := s.Start(); err != nil { + return fmt.Errorf("start service: %w", err) } cmd.Println("Netbird service has been started") return nil @@ -173,29 +167,14 @@ var stopCmd = &cobra.Command{ Use: "stop", Short: "stops Netbird service", RunE: func(cmd *cobra.Command, args []string) error { - SetFlagsFromEnvVars(rootCmd) - - cmd.SetOut(cmd.OutOrStdout()) - - err := handleRebrand(cmd) - if err != nil { - return err - } - - err = util.InitLog(logLevel, logFile) - if err != nil { - return fmt.Errorf("failed initializing log %v", err) - } - ctx, cancel := context.WithCancel(cmd.Context()) - - s, err := newSVC(newProgram(ctx, cancel), newSVCConfig()) + s, err := setupServiceControlCommand(cmd, ctx, cancel) if err != nil { return err } - err = s.Stop() - if err != nil { - return err + + if err := s.Stop(); err != nil { + return fmt.Errorf("stop service: %w", err) } cmd.Println("Netbird service has been stopped") return nil @@ -206,29 +185,14 @@ var restartCmd = &cobra.Command{ Use: "restart", Short: "restarts Netbird service", RunE: func(cmd *cobra.Command, args []string) error { - SetFlagsFromEnvVars(rootCmd) - - cmd.SetOut(cmd.OutOrStdout()) - - err := handleRebrand(cmd) - if err != nil { - return err - } - - err = util.InitLog(logLevel, logFile) - if err != nil { - return fmt.Errorf("failed initializing log %v", err) - } - ctx, cancel := context.WithCancel(cmd.Context()) - - s, err := newSVC(newProgram(ctx, cancel), newSVCConfig()) + s, err := setupServiceControlCommand(cmd, ctx, cancel) if err != nil { return err } - err = s.Restart() - if err != nil { - return err + + if err := s.Restart(); err != nil { + return fmt.Errorf("restart service: %w", err) } cmd.Println("Netbird service has been restarted") return nil diff --git a/client/cmd/service_installer.go b/client/cmd/service_installer.go index c1d6308c6..648f53fc4 100644 --- a/client/cmd/service_installer.go +++ b/client/cmd/service_installer.go @@ -1,87 +1,117 @@ +//go:build !ios && !android package cmd import ( "context" + "fmt" "os" "path/filepath" "runtime" + "github.com/kardianos/service" "github.com/spf13/cobra" ) +// Common service command setup +func setupServiceCommand(cmd *cobra.Command) error { + SetFlagsFromEnvVars(rootCmd) + SetFlagsFromEnvVars(serviceCmd) + cmd.SetOut(cmd.OutOrStdout()) + return handleRebrand(cmd) +} + +// Build service arguments for install/reconfigure +func buildServiceArguments() []string { + args := []string{ + "service", + "run", + "--config", + configPath, + "--log-level", + logLevel, + "--daemon-addr", + daemonAddr, + } + + if managementURL != "" { + args = append(args, "--management-url", managementURL) + } + + if logFile != "" { + args = append(args, "--log-file", logFile) + } + + return args +} + +// Configure platform-specific service settings +func configurePlatformSpecificSettings(svcConfig *service.Config) error { + if runtime.GOOS == "linux" { + // Respected only by systemd systems + svcConfig.Dependencies = []string{"After=network.target syslog.target"} + + if logFile != "console" { + setStdLogPath := true + dir := filepath.Dir(logFile) + + if _, err := os.Stat(dir); err != nil { + if err = os.MkdirAll(dir, 0750); err != nil { + setStdLogPath = false + } + } + + if setStdLogPath { + svcConfig.Option["LogOutput"] = true + svcConfig.Option["LogDirectory"] = dir + } + } + } + + if runtime.GOOS == "windows" { + svcConfig.Option["OnFailure"] = "restart" + } + + return nil +} + +// Create fully configured service config for install/reconfigure +func createServiceConfigForInstall() (*service.Config, error) { + svcConfig, err := newSVCConfig() + if err != nil { + return nil, fmt.Errorf("create service config: %w", err) + } + + svcConfig.Arguments = buildServiceArguments() + if err = configurePlatformSpecificSettings(svcConfig); err != nil { + return nil, fmt.Errorf("configure platform-specific settings: %w", err) + } + + return svcConfig, nil +} + var installCmd = &cobra.Command{ Use: "install", Short: "installs Netbird service", RunE: func(cmd *cobra.Command, args []string) error { - SetFlagsFromEnvVars(rootCmd) - - cmd.SetOut(cmd.OutOrStdout()) - - err := handleRebrand(cmd) - if err != nil { + if err := setupServiceCommand(cmd); err != nil { return err } - svcConfig := newSVCConfig() - - svcConfig.Arguments = []string{ - "service", - "run", - "--config", - configPath, - "--log-level", - logLevel, - "--daemon-addr", - daemonAddr, - } - - if managementURL != "" { - svcConfig.Arguments = append(svcConfig.Arguments, "--management-url", managementURL) - } - - if logFile != "" { - svcConfig.Arguments = append(svcConfig.Arguments, "--log-file", logFile) - } - - if runtime.GOOS == "linux" { - // Respected only by systemd systems - svcConfig.Dependencies = []string{"After=network.target syslog.target"} - - if logFile != "console" { - setStdLogPath := true - dir := filepath.Dir(logFile) - - _, err := os.Stat(dir) - if err != nil { - err = os.MkdirAll(dir, 0750) - if err != nil { - setStdLogPath = false - } - } - - if setStdLogPath { - svcConfig.Option["LogOutput"] = true - svcConfig.Option["LogDirectory"] = dir - } - } - } - - if runtime.GOOS == "windows" { - svcConfig.Option["OnFailure"] = "restart" + svcConfig, err := createServiceConfigForInstall() + if err != nil { + return err } ctx, cancel := context.WithCancel(cmd.Context()) + defer cancel() s, err := newSVC(newProgram(ctx, cancel), svcConfig) if err != nil { - cmd.PrintErrln(err) return err } - err = s.Install() - if err != nil { - cmd.PrintErrln(err) - return err + if err := s.Install(); err != nil { + return fmt.Errorf("install service: %w", err) } cmd.Println("Netbird service has been installed") @@ -93,27 +123,109 @@ var uninstallCmd = &cobra.Command{ Use: "uninstall", Short: "uninstalls Netbird service from system", RunE: func(cmd *cobra.Command, args []string) error { - SetFlagsFromEnvVars(rootCmd) + if err := setupServiceCommand(cmd); err != nil { + return err + } - cmd.SetOut(cmd.OutOrStdout()) + cfg, err := newSVCConfig() + if err != nil { + return fmt.Errorf("create service config: %w", err) + } - err := handleRebrand(cmd) + ctx, cancel := context.WithCancel(cmd.Context()) + defer cancel() + + s, err := newSVC(newProgram(ctx, cancel), cfg) + if err != nil { + return err + } + + if err := s.Uninstall(); err != nil { + return fmt.Errorf("uninstall service: %w", err) + } + + cmd.Println("Netbird service has been uninstalled") + return nil + }, +} + +var reconfigureCmd = &cobra.Command{ + Use: "reconfigure", + Short: "reconfigures Netbird service with new settings", + Long: `Reconfigures the Netbird service with new settings without manual uninstall/install. +This command will temporarily stop the service, update its configuration, and restart it if it was running.`, + RunE: func(cmd *cobra.Command, args []string) error { + if err := setupServiceCommand(cmd); err != nil { + return err + } + + wasRunning, err := isServiceRunning() + if err != nil { + return fmt.Errorf("check service status: %w", err) + } + + svcConfig, err := createServiceConfigForInstall() if err != nil { return err } ctx, cancel := context.WithCancel(cmd.Context()) + defer cancel() - s, err := newSVC(newProgram(ctx, cancel), newSVCConfig()) + s, err := newSVC(newProgram(ctx, cancel), svcConfig) if err != nil { - return err + return fmt.Errorf("create service: %w", err) } - err = s.Uninstall() - if err != nil { - return err + if wasRunning { + cmd.Println("Stopping Netbird service...") + if err := s.Stop(); err != nil { + cmd.Printf("Warning: failed to stop service: %v\n", err) + } } - cmd.Println("Netbird service has been uninstalled") + + cmd.Println("Removing existing service configuration...") + if err := s.Uninstall(); err != nil { + return fmt.Errorf("uninstall existing service: %w", err) + } + + cmd.Println("Installing service with new configuration...") + if err := s.Install(); err != nil { + return fmt.Errorf("install service with new config: %w", err) + } + + if wasRunning { + cmd.Println("Starting Netbird service...") + if err := s.Start(); err != nil { + return fmt.Errorf("start service after reconfigure: %w", err) + } + cmd.Println("Netbird service has been reconfigured and started") + } else { + cmd.Println("Netbird service has been reconfigured") + } + return nil }, } + +func isServiceRunning() (bool, error) { + cfg, err := newSVCConfig() + if err != nil { + return false, err + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s, err := newSVC(newProgram(ctx, cancel), cfg) + if err != nil { + return false, err + } + + status, err := s.Status() + if err != nil { + return false, nil + } + + return status == service.StatusRunning, nil +} diff --git a/client/cmd/service_test.go b/client/cmd/service_test.go new file mode 100644 index 000000000..308a6abd3 --- /dev/null +++ b/client/cmd/service_test.go @@ -0,0 +1,221 @@ +package cmd + +import ( + "context" + "fmt" + "runtime" + "testing" + "time" + + "github.com/kardianos/service" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + serviceStartTimeout = 5 * time.Second + serviceStopTimeout = 5 * time.Second +) + +// TestServiceLifecycle tests the complete service lifecycle +func TestServiceLifecycle(t *testing.T) { + originalServiceName := serviceName + serviceName = "netbird-test-" + fmt.Sprintf("%d", time.Now().Unix()) + defer func() { + serviceName = originalServiceName + }() + + configPath = "/tmp/netbird-test-config.json" + logLevel = "info" + daemonAddr = "unix:///tmp/netbird-test.sock" + + ctx := context.Background() + + t.Run("Install", func(t *testing.T) { + installCmd.SetContext(ctx) + err := installCmd.RunE(installCmd, []string{}) + require.NoError(t, err) + + cfg, err := newSVCConfig() + require.NoError(t, err) + + ctxSvc, cancel := context.WithCancel(context.Background()) + defer cancel() + + s, err := newSVC(newProgram(ctxSvc, cancel), cfg) + require.NoError(t, err) + + status, err := s.Status() + assert.NoError(t, err) + assert.NotEqual(t, service.StatusUnknown, status) + }) + + t.Run("Start", func(t *testing.T) { + startCmd.SetContext(ctx) + err := startCmd.RunE(startCmd, []string{}) + require.NoError(t, err) + + time.Sleep(serviceStartTimeout) + + running, err := isServiceRunning() + require.NoError(t, err) + assert.True(t, running) + }) + + t.Run("Restart", func(t *testing.T) { + restartCmd.SetContext(ctx) + err := restartCmd.RunE(restartCmd, []string{}) + require.NoError(t, err) + + time.Sleep(serviceStartTimeout) + + running, err := isServiceRunning() + require.NoError(t, err) + assert.True(t, running) + }) + + t.Run("Reconfigure", func(t *testing.T) { + originalLogLevel := logLevel + logLevel = "debug" + defer func() { + logLevel = originalLogLevel + }() + + reconfigureCmd.SetContext(ctx) + err := reconfigureCmd.RunE(reconfigureCmd, []string{}) + require.NoError(t, err) + + time.Sleep(serviceStartTimeout) + + running, err := isServiceRunning() + require.NoError(t, err) + assert.True(t, running) + }) + + t.Run("Stop", func(t *testing.T) { + stopCmd.SetContext(ctx) + err := stopCmd.RunE(stopCmd, []string{}) + require.NoError(t, err) + + time.Sleep(serviceStopTimeout) + + running, err := isServiceRunning() + require.NoError(t, err) + assert.False(t, running) + }) + + t.Run("Uninstall", func(t *testing.T) { + uninstallCmd.SetContext(ctx) + err := uninstallCmd.RunE(uninstallCmd, []string{}) + require.NoError(t, err) + + cfg, err := newSVCConfig() + require.NoError(t, err) + + ctxSvc, cancel := context.WithCancel(context.Background()) + defer cancel() + + s, err := newSVC(newProgram(ctxSvc, cancel), cfg) + require.NoError(t, err) + + _, err = s.Status() + assert.Error(t, err) + }) +} + +// TestServiceEnvVars tests environment variable parsing +func TestServiceEnvVars(t *testing.T) { + tests := []struct { + name string + envVars []string + expected map[string]string + expectErr bool + }{ + { + name: "Valid single env var", + envVars: []string{"LOG_LEVEL=debug"}, + expected: map[string]string{ + "LOG_LEVEL": "debug", + }, + }, + { + name: "Valid multiple env vars", + envVars: []string{"LOG_LEVEL=debug", "CUSTOM_VAR=value"}, + expected: map[string]string{ + "LOG_LEVEL": "debug", + "CUSTOM_VAR": "value", + }, + }, + { + name: "Env var with spaces", + envVars: []string{" KEY = value "}, + expected: map[string]string{ + "KEY": "value", + }, + }, + { + name: "Invalid format - no equals", + envVars: []string{"INVALID"}, + expectErr: true, + }, + { + name: "Invalid format - empty key", + envVars: []string{"=value"}, + expectErr: true, + }, + { + name: "Empty value is valid", + envVars: []string{"KEY="}, + expected: map[string]string{ + "KEY": "", + }, + }, + { + name: "Empty slice", + envVars: []string{}, + expected: map[string]string{}, + }, + { + name: "Empty string in slice", + envVars: []string{"", "KEY=value", ""}, + expected: map[string]string{"KEY": "value"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := parseServiceEnvVars(tt.envVars) + + if tt.expectErr { + assert.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} + +// TestServiceConfigWithEnvVars tests service config creation with env vars +func TestServiceConfigWithEnvVars(t *testing.T) { + originalServiceName := serviceName + originalServiceEnvVars := serviceEnvVars + defer func() { + serviceName = originalServiceName + serviceEnvVars = originalServiceEnvVars + }() + + serviceName = "test-service" + serviceEnvVars = []string{"TEST_VAR=test_value", "ANOTHER_VAR=another_value"} + + cfg, err := newSVCConfig() + require.NoError(t, err) + + assert.Equal(t, "test-service", cfg.Name) + assert.Equal(t, "test_value", cfg.EnvVars["TEST_VAR"]) + assert.Equal(t, "another_value", cfg.EnvVars["ANOTHER_VAR"]) + + if runtime.GOOS == "linux" { + assert.Equal(t, "test-service", cfg.EnvVars["SYSTEMD_UNIT"]) + } +}