diff --git a/client/internal/config.go b/client/internal/config.go index 07f3f38e2..92a008bb2 100644 --- a/client/internal/config.go +++ b/client/internal/config.go @@ -242,13 +242,12 @@ func GetConfig(input ConfigInput) (*Config, error) { if _, err := os.Stat(input.ConfigPath); os.IsNotExist(err) { log.Infof("generating new config %s", input.ConfigPath) return createNewConfig(input) - } else { - // don't overwrite pre-shared key if we receive asterisks from UI - if *input.PreSharedKey == "**********" { - input.PreSharedKey = nil - } - return ReadConfig(input) } + + if isPreSharedKeyHidden(input.PreSharedKey) { + input.PreSharedKey = nil + } + return ReadConfig(input) } // generateKey generates a new Wireguard private key @@ -364,3 +363,11 @@ func isProviderConfigValid(config ProviderConfig) error { } return nil } + +// don't overwrite pre-shared key if we receive asterisks from UI +func isPreSharedKeyHidden(preSharedKey *string) bool { + if preSharedKey != nil && *preSharedKey == "**********" { + return true + } + return false +} diff --git a/client/internal/config_test.go b/client/internal/config_test.go index 3ca8a5213..5c2be50a0 100644 --- a/client/internal/config_test.go +++ b/client/internal/config_test.go @@ -85,3 +85,40 @@ func TestGetConfig(t *testing.T) { } assert.Equal(t, readConf.(*Config).ManagementURL.String(), newManagementURL) } + +func TestHiddenPreSharedKey(t *testing.T) { + hidden := "**********" + samplePreSharedKey := "mysecretpresharedkey" + tests := []struct { + name string + preSharedKey *string + want string + }{ + {"nil", nil, ""}, + {"hidden", &hidden, ""}, + {"filled", &samplePreSharedKey, samplePreSharedKey}, + } + + // generate default cfg + cfgFile := filepath.Join(t.TempDir(), "config.json") + _, _ = GetConfig(ConfigInput{ + ConfigPath: cfgFile, + }) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg, err := GetConfig(ConfigInput{ + ConfigPath: cfgFile, + PreSharedKey: tt.preSharedKey, + }) + + if err != nil { + t.Fatalf("failed to get cfg: %s", err) + } + + if cfg.PreSharedKey != tt.want { + t.Fatalf("invalid preshared key: '%s', expected: '%s' ", cfg.PreSharedKey, tt.want) + } + }) + } +}