diff --git a/client/internal/profilemanager/config.go b/client/internal/profilemanager/config.go new file mode 100644 index 000000000..d84791ffd --- /dev/null +++ b/client/internal/profilemanager/config.go @@ -0,0 +1,523 @@ +package profilemanager + +import ( + "crypto/tls" + "fmt" + "net/url" + "os" + "path/filepath" + "reflect" + "runtime" + "slices" + "strings" + "time" + + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/internal/routemanager/dynamic" + "github.com/netbirdio/netbird/client/ssh" + "github.com/netbirdio/netbird/management/domain" + "github.com/netbirdio/netbird/util" +) + +const ( + // managementLegacyPortString is the port that was used before by the Management gRPC server. + // It is used for backward compatibility now. + // NB: hardcoded from github.com/netbirdio/netbird/management/cmd to avoid import + managementLegacyPortString = "33073" + // DefaultManagementURL points to the NetBird's cloud management endpoint + DefaultManagementURL = "https://api.netbird.io:443" + // oldDefaultManagementURL points to the NetBird's old cloud management endpoint + oldDefaultManagementURL = "https://api.wiretrustee.com:443" + // DefaultAdminURL points to NetBird's cloud management console + DefaultAdminURL = "https://app.netbird.io:443" +) + +var defaultInterfaceBlacklist = []string{ + iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts", + "Tailscale", "tailscale", "docker", "veth", "br-", "lo", +} + +// ConfigInput carries configuration changes to the client +type ConfigInput struct { + ManagementURL string + AdminURL string + ConfigPath string + StateFilePath string + PreSharedKey *string + ServerSSHAllowed *bool + NATExternalIPs []string + CustomDNSAddress []byte + RosenpassEnabled *bool + RosenpassPermissive *bool + InterfaceName *string + WireguardPort *int + NetworkMonitor *bool + DisableAutoConnect *bool + ExtraIFaceBlackList []string + DNSRouteInterval *time.Duration + ClientCertPath string + ClientCertKeyPath string + + DisableClientRoutes *bool + DisableServerRoutes *bool + DisableDNS *bool + DisableFirewall *bool + BlockLANAccess *bool + BlockInbound *bool + + DisableNotifications *bool + + DNSLabels domain.List + + LazyConnectionEnabled *bool +} + +// Config Configuration type +type Config struct { + // Wireguard private key of local peer + PrivateKey string + PreSharedKey string + ManagementURL *url.URL + AdminURL *url.URL + WgIface string + WgPort int + NetworkMonitor *bool + IFaceBlackList []string + DisableIPv6Discovery bool + RosenpassEnabled bool + RosenpassPermissive bool + ServerSSHAllowed *bool + + DisableClientRoutes bool + DisableServerRoutes bool + DisableDNS bool + DisableFirewall bool + BlockLANAccess bool + BlockInbound bool + + DisableNotifications *bool + + DNSLabels domain.List + + // SSHKey is a private SSH key in a PEM format + SSHKey string + + // ExternalIP mappings, if different from the host interface IP + // + // External IP must not be behind a CGNAT and port-forwarding for incoming UDP packets from WgPort on ExternalIP + // to WgPort on host interface IP must be present. This can take form of single port-forwarding rule, 1:1 DNAT + // mapping ExternalIP to host interface IP, or a NAT DMZ to host interface IP. + // + // A single mapping will take the form of: external[/internal] + // external (required): either the external IP address or "stun" to use STUN to determine the external IP address + // internal (optional): either the internal/interface IP address or an interface name + // + // examples: + // "12.34.56.78" => all interfaces IPs will be mapped to external IP of 12.34.56.78 + // "12.34.56.78/eth0" => IPv4 assigned to interface eth0 will be mapped to external IP of 12.34.56.78 + // "12.34.56.78/10.1.2.3" => interface IP 10.1.2.3 will be mapped to external IP of 12.34.56.78 + + NATExternalIPs []string + // CustomDNSAddress sets the DNS resolver listening address in format ip:port + CustomDNSAddress string + + // DisableAutoConnect determines whether the client should not start with the service + // it's set to false by default due to backwards compatibility + DisableAutoConnect bool + + // DNSRouteInterval is the interval in which the DNS routes are updated + DNSRouteInterval time.Duration + // Path to a certificate used for mTLS authentication + ClientCertPath string + + // Path to corresponding private key of ClientCertPath + ClientCertKeyPath string + + ClientCertKeyPair *tls.Certificate `json:"-"` + + LazyConnectionEnabled bool +} + +func getConfigDir() (string, error) { + configDir, err := os.UserConfigDir() + if err != nil { + return "", err + } + + configDir = filepath.Join(configDir, "netbird") + if _, err := os.Stat(configDir); os.IsNotExist(err) { + if err := os.MkdirAll(configDir, 0755); err != nil { + return "", err + } + } + + return configDir, nil +} + +func fileExists(path string) bool { + _, err := os.Stat(path) + return !os.IsNotExist(err) +} + +// createNewConfig creates a new config generating a new Wireguard key and saving to file +func createNewConfig(input ConfigInput) (*Config, error) { + config := &Config{ + // defaults to false only for new (post 0.26) configurations + ServerSSHAllowed: util.False(), + } + + if _, err := config.apply(input); err != nil { + return nil, err + } + + return config, nil +} + +func (config *Config) apply(input ConfigInput) (updated bool, err error) { + if config.ManagementURL == nil { + log.Infof("using default Management URL %s", DefaultManagementURL) + config.ManagementURL, err = parseURL("Management URL", DefaultManagementURL) + if err != nil { + return false, err + } + } + if input.ManagementURL != "" && input.ManagementURL != config.ManagementURL.String() { + log.Infof("new Management URL provided, updated to %#v (old value %#v)", + input.ManagementURL, config.ManagementURL.String()) + URL, err := parseURL("Management URL", input.ManagementURL) + if err != nil { + return false, err + } + config.ManagementURL = URL + updated = true + } else if config.ManagementURL == nil { + log.Infof("using default Management URL %s", DefaultManagementURL) + config.ManagementURL, err = parseURL("Management URL", DefaultManagementURL) + if err != nil { + return false, err + } + } + + if config.AdminURL == nil { + log.Infof("using default Admin URL %s", DefaultManagementURL) + config.AdminURL, err = parseURL("Admin URL", DefaultAdminURL) + if err != nil { + return false, err + } + } + if input.AdminURL != "" && input.AdminURL != config.AdminURL.String() { + log.Infof("new Admin Panel URL provided, updated to %#v (old value %#v)", + input.AdminURL, config.AdminURL.String()) + newURL, err := parseURL("Admin Panel URL", input.AdminURL) + if err != nil { + return updated, err + } + config.AdminURL = newURL + updated = true + } + + if config.PrivateKey == "" { + log.Infof("generated new Wireguard key") + config.PrivateKey = generateKey() + updated = true + } + + if config.SSHKey == "" { + log.Infof("generated new SSH key") + pem, err := ssh.GeneratePrivateKey(ssh.ED25519) + if err != nil { + return false, err + } + config.SSHKey = string(pem) + updated = true + } + + if input.WireguardPort != nil && *input.WireguardPort != config.WgPort { + log.Infof("updating Wireguard port %d (old value %d)", + *input.WireguardPort, config.WgPort) + config.WgPort = *input.WireguardPort + updated = true + } else if config.WgPort == 0 { + config.WgPort = iface.DefaultWgPort + log.Infof("using default Wireguard port %d", config.WgPort) + updated = true + } + + if input.InterfaceName != nil && *input.InterfaceName != config.WgIface { + log.Infof("updating Wireguard interface %#v (old value %#v)", + *input.InterfaceName, config.WgIface) + config.WgIface = *input.InterfaceName + updated = true + } else if config.WgIface == "" { + config.WgIface = iface.WgInterfaceDefault + log.Infof("using default Wireguard interface %s", config.WgIface) + updated = true + } + + if input.NATExternalIPs != nil && !reflect.DeepEqual(config.NATExternalIPs, input.NATExternalIPs) { + log.Infof("updating NAT External IP [ %s ] (old value: [ %s ])", + strings.Join(input.NATExternalIPs, " "), + strings.Join(config.NATExternalIPs, " ")) + config.NATExternalIPs = input.NATExternalIPs + updated = true + } + + if input.PreSharedKey != nil && *input.PreSharedKey != config.PreSharedKey { + log.Infof("new pre-shared key provided, replacing old key") + config.PreSharedKey = *input.PreSharedKey + updated = true + } + + if input.RosenpassEnabled != nil && *input.RosenpassEnabled != config.RosenpassEnabled { + log.Infof("switching Rosenpass to %t", *input.RosenpassEnabled) + config.RosenpassEnabled = *input.RosenpassEnabled + updated = true + } + + if input.RosenpassPermissive != nil && *input.RosenpassPermissive != config.RosenpassPermissive { + log.Infof("switching Rosenpass permissive to %t", *input.RosenpassPermissive) + config.RosenpassPermissive = *input.RosenpassPermissive + updated = true + } + + if input.NetworkMonitor != nil && input.NetworkMonitor != config.NetworkMonitor { + log.Infof("switching Network Monitor to %t", *input.NetworkMonitor) + config.NetworkMonitor = input.NetworkMonitor + updated = true + } + + if config.NetworkMonitor == nil { + // enable network monitoring by default on windows and darwin clients + if runtime.GOOS == "windows" || runtime.GOOS == "darwin" { + enabled := true + config.NetworkMonitor = &enabled + updated = true + } + } + + if input.CustomDNSAddress != nil && string(input.CustomDNSAddress) != config.CustomDNSAddress { + log.Infof("updating custom DNS address %#v (old value %#v)", + string(input.CustomDNSAddress), config.CustomDNSAddress) + config.CustomDNSAddress = string(input.CustomDNSAddress) + updated = true + } + + if len(config.IFaceBlackList) == 0 { + log.Infof("filling in interface blacklist with defaults: [ %s ]", + strings.Join(defaultInterfaceBlacklist, " ")) + config.IFaceBlackList = append(config.IFaceBlackList, defaultInterfaceBlacklist...) + updated = true + } + + if len(input.ExtraIFaceBlackList) > 0 { + for _, iFace := range util.SliceDiff(input.ExtraIFaceBlackList, config.IFaceBlackList) { + log.Infof("adding new entry to interface blacklist: %s", iFace) + config.IFaceBlackList = append(config.IFaceBlackList, iFace) + updated = true + } + } + + if input.DisableAutoConnect != nil && *input.DisableAutoConnect != config.DisableAutoConnect { + if *input.DisableAutoConnect { + log.Infof("turning off automatic connection on startup") + } else { + log.Infof("enabling automatic connection on startup") + } + config.DisableAutoConnect = *input.DisableAutoConnect + updated = true + } + + if input.ServerSSHAllowed != nil && *input.ServerSSHAllowed != *config.ServerSSHAllowed { + if *input.ServerSSHAllowed { + log.Infof("enabling SSH server") + } else { + log.Infof("disabling SSH server") + } + config.ServerSSHAllowed = input.ServerSSHAllowed + updated = true + } else if config.ServerSSHAllowed == nil { + if runtime.GOOS == "android" { + // default to disabled SSH on Android for security + log.Infof("setting SSH server to false by default on Android") + config.ServerSSHAllowed = util.False() + } else { + // enables SSH for configs from old versions to preserve backwards compatibility + log.Infof("falling back to enabled SSH server for pre-existing configuration") + config.ServerSSHAllowed = util.True() + } + updated = true + } + + if input.DNSRouteInterval != nil && *input.DNSRouteInterval != config.DNSRouteInterval { + log.Infof("updating DNS route interval to %s (old value %s)", + input.DNSRouteInterval.String(), config.DNSRouteInterval.String()) + config.DNSRouteInterval = *input.DNSRouteInterval + updated = true + } else if config.DNSRouteInterval == 0 { + config.DNSRouteInterval = dynamic.DefaultInterval + log.Infof("using default DNS route interval %s", config.DNSRouteInterval) + updated = true + } + + if input.DisableClientRoutes != nil && *input.DisableClientRoutes != config.DisableClientRoutes { + if *input.DisableClientRoutes { + log.Infof("disabling client routes") + } else { + log.Infof("enabling client routes") + } + config.DisableClientRoutes = *input.DisableClientRoutes + updated = true + } + + if input.DisableServerRoutes != nil && *input.DisableServerRoutes != config.DisableServerRoutes { + if *input.DisableServerRoutes { + log.Infof("disabling server routes") + } else { + log.Infof("enabling server routes") + } + config.DisableServerRoutes = *input.DisableServerRoutes + updated = true + } + + if input.DisableDNS != nil && *input.DisableDNS != config.DisableDNS { + if *input.DisableDNS { + log.Infof("disabling DNS configuration") + } else { + log.Infof("enabling DNS configuration") + } + config.DisableDNS = *input.DisableDNS + updated = true + } + + if input.DisableFirewall != nil && *input.DisableFirewall != config.DisableFirewall { + if *input.DisableFirewall { + log.Infof("disabling firewall configuration") + } else { + log.Infof("enabling firewall configuration") + } + config.DisableFirewall = *input.DisableFirewall + updated = true + } + + if input.BlockLANAccess != nil && *input.BlockLANAccess != config.BlockLANAccess { + if *input.BlockLANAccess { + log.Infof("blocking LAN access") + } else { + log.Infof("allowing LAN access") + } + config.BlockLANAccess = *input.BlockLANAccess + updated = true + } + + if input.BlockInbound != nil && *input.BlockInbound != config.BlockInbound { + if *input.BlockInbound { + log.Infof("blocking inbound connections") + } else { + log.Infof("allowing inbound connections") + } + config.BlockInbound = *input.BlockInbound + updated = true + } + + if input.DisableNotifications != nil && input.DisableNotifications != config.DisableNotifications { + if *input.DisableNotifications { + log.Infof("disabling notifications") + } else { + log.Infof("enabling notifications") + } + config.DisableNotifications = input.DisableNotifications + updated = true + } + + if config.DisableNotifications == nil { + disabled := true + config.DisableNotifications = &disabled + log.Infof("setting notifications to disabled by default") + updated = true + } + + if input.ClientCertKeyPath != "" { + config.ClientCertKeyPath = input.ClientCertKeyPath + updated = true + } + + if input.ClientCertPath != "" { + config.ClientCertPath = input.ClientCertPath + updated = true + } + + if config.ClientCertPath != "" && config.ClientCertKeyPath != "" { + cert, err := tls.LoadX509KeyPair(config.ClientCertPath, config.ClientCertKeyPath) + if err != nil { + log.Error("Failed to load mTLS cert/key pair: ", err) + } else { + config.ClientCertKeyPair = &cert + log.Info("Loaded client mTLS cert/key pair") + } + } + + if input.DNSLabels != nil && !slices.Equal(config.DNSLabels, input.DNSLabels) { + log.Infof("updating DNS labels [ %s ] (old value: [ %s ])", + input.DNSLabels.SafeString(), + config.DNSLabels.SafeString()) + config.DNSLabels = input.DNSLabels + updated = true + } + + if input.LazyConnectionEnabled != nil && *input.LazyConnectionEnabled != config.LazyConnectionEnabled { + log.Infof("switching lazy connection to %t", *input.LazyConnectionEnabled) + config.LazyConnectionEnabled = *input.LazyConnectionEnabled + updated = true + } + + return updated, nil +} + +// parseURL parses and validates a service URL +func parseURL(serviceName, serviceURL string) (*url.URL, error) { + parsedMgmtURL, err := url.ParseRequestURI(serviceURL) + if err != nil { + log.Errorf("failed parsing %s URL %s: [%s]", serviceName, serviceURL, err.Error()) + return nil, err + } + + if parsedMgmtURL.Scheme != "https" && parsedMgmtURL.Scheme != "http" { + return nil, fmt.Errorf( + "invalid %s URL provided %s. Supported format [http|https]://[host]:[port]", + serviceName, serviceURL) + } + + if parsedMgmtURL.Port() == "" { + switch parsedMgmtURL.Scheme { + case "https": + parsedMgmtURL.Host += ":443" + case "http": + parsedMgmtURL.Host += ":80" + default: + log.Infof("unable to determine a default port for schema %s in URL %s", parsedMgmtURL.Scheme, serviceURL) + } + } + + return parsedMgmtURL, err +} + +// generateKey generates a new Wireguard private key +func generateKey() string { + key, err := wgtypes.GeneratePrivateKey() + if err != nil { + panic(err) + } + return key.String() +} + +// don't overwrite pre-shared key if we receive asterisks from UI +func isPreSharedKeyHidden(preSharedKey *string) bool { + if preSharedKey != nil && *preSharedKey == "**********" { + return true + } + return false +} diff --git a/client/internal/profilemanager/error.go b/client/internal/profilemanager/error.go index cfc1332e5..d83fe5c1c 100644 --- a/client/internal/profilemanager/error.go +++ b/client/internal/profilemanager/error.go @@ -5,4 +5,5 @@ import "errors" var ( ErrProfileNotFound = errors.New("profile not found") ErrProfileAlreadyExists = errors.New("profile already exists") + ErrNoActiveProfile = errors.New("no active profile set") ) diff --git a/client/internal/profilemanager/profilemanager.go b/client/internal/profilemanager/profilemanager.go index 754981507..b0c56f312 100644 --- a/client/internal/profilemanager/profilemanager.go +++ b/client/internal/profilemanager/profilemanager.go @@ -1,9 +1,12 @@ package profilemanager import ( + "context" + "errors" "fmt" - "os" "path/filepath" + "strings" + "sync" "github.com/netbirdio/netbird/util" ) @@ -14,6 +17,8 @@ type Profile struct { } type ProfileManager struct { + mu sync.Mutex + activeProfile *Profile } func NewProfileManager() *ProfileManager { @@ -31,39 +36,80 @@ func (pm *ProfileManager) AddProfile(profile Profile) error { return ErrProfileAlreadyExists } -} - -func getConfigDir() (string, error) { - configDir, err := os.UserConfigDir() + cfg, err := createNewConfig(ConfigInput{ConfigPath: profPath}) if err != nil { - return "", err + return fmt.Errorf("failed to create new config: %w", err) } - configDir = filepath.Join(configDir, "netbird") - if _, err := os.Stat(configDir); os.IsNotExist(err) { - if err := os.MkdirAll(configDir, 0755); err != nil { - return "", err - } + err = util.WriteJsonWithRestrictedPermission(context.Background(), profPath, cfg) + if err != nil { + return fmt.Errorf("failed to write profile config: %w", err) } - return configDir, nil + return nil } -func fileExists(path string) bool { - _, err := os.Stat(path) - return !os.IsNotExist(err) -} - -// createNewConfig creates a new config generating a new Wireguard key and saving to file -func createNewConfig(input ConfigInput) (*Config, error) { - config := &Config{ - // defaults to false only for new (post 0.26) configurations - ServerSSHAllowed: util.False(), +func (pm *ProfileManager) RemoveProfile(profileName string) error { + configDir, err := getConfigDir() + if err != nil { + return fmt.Errorf("failed to get config directory: %w", err) } - if _, err := config.apply(input); err != nil { - return nil, err + profPath := filepath.Join(configDir, profileName+".json") + if !fileExists(profPath) { + return ErrProfileNotFound } - return config, nil + activeProf, err := pm.GetActiveProfile() + if err != nil && !errors.Is(err, ErrNoActiveProfile) { + return fmt.Errorf("failed to get active profile: %w", err) + } + + if activeProf.Name == profileName { + return fmt.Errorf("cannot remove active profile: %s", profileName) + } + + err = util.RemoveJson(profPath) + if err != nil { + return fmt.Errorf("failed to remove profile config: %w", err) + } + return nil +} + +func (pm *ProfileManager) GetActiveProfile() (*Profile, error) { + pm.mu.Lock() + defer pm.mu.Unlock() + + if pm.activeProfile == nil { + return nil, ErrNoActiveProfile + } + + return pm.activeProfile, nil +} + +func (pm *ProfileManager) SetActiveProfile(profileName string) { + pm.mu.Lock() + defer pm.mu.Unlock() + + pm.activeProfile = &Profile{Name: profileName} +} + +func (pm *ProfileManager) ListProfiles() ([]Profile, error) { + configDir, err := getConfigDir() + if err != nil { + return nil, fmt.Errorf("failed to get config directory: %w", err) + } + + files, err := util.ListFiles(configDir, "*.json") + if err != nil { + return nil, fmt.Errorf("failed to list profile files: %w", err) + } + + var profiles []Profile + for _, file := range files { + profileName := strings.TrimSuffix(filepath.Base(file), ".json") + profiles = append(profiles, Profile{Name: profileName}) + } + + return profiles, nil } diff --git a/util/file.go b/util/file.go index f7de7ede2..73ad05b18 100644 --- a/util/file.go +++ b/util/file.go @@ -9,6 +9,7 @@ import ( "io" "os" "path/filepath" + "sort" "strings" "text/template" @@ -200,6 +201,36 @@ func ReadJson(file string, res interface{}) (interface{}, error) { return res, nil } +// RemoveJson removes the specified JSON file if it exists +func RemoveJson(file string) error { + // Check if the file exists + if _, err := os.Stat(file); errors.Is(err, os.ErrNotExist) { + return nil // File does not exist, nothing to remove + } + + // Attempt to remove the file + if err := os.Remove(file); err != nil { + return fmt.Errorf("failed to remove JSON file %s: %w", file, err) + } + + return nil +} + +// ListFiles returns the full paths of all files in dir that match pattern. +// Pattern uses shell-style globbing (e.g. "*.json"). +func ListFiles(dir, pattern string) ([]string, error) { + // glob pattern like "/path/to/dir/*.json" + globPattern := filepath.Join(dir, pattern) + + matches, err := filepath.Glob(globPattern) + if err != nil { + return nil, err + } + + sort.Strings(matches) + return matches, nil +} + // ReadJsonWithEnvSub reads JSON config file and maps to a provided interface with environment variable substitution func ReadJsonWithEnvSub(file string, res interface{}) (interface{}, error) { envVars := getEnvMap()