mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-04 22:10:56 +01:00
Merge branch 'main' into feature/relay-integration
This commit is contained in:
commit
4802b83ef9
51
.github/workflows/release.yml
vendored
51
.github/workflows/release.yml
vendored
@ -10,8 +10,10 @@ on:
|
||||
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.0.11"
|
||||
SIGN_PIPE_VER: "v0.0.12"
|
||||
GORELEASER_VER: "v1.14.1"
|
||||
PRODUCT_NAME: "NetBird"
|
||||
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
@ -23,6 +25,13 @@ jobs:
|
||||
env:
|
||||
flags: ""
|
||||
steps:
|
||||
- name: Parse semver string
|
||||
id: semver_parser
|
||||
uses: booxmedialtd/ws-action-parse-semver@v1
|
||||
with:
|
||||
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
|
||||
version_extractor_regex: '\/v(.*)$'
|
||||
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
-
|
||||
@ -68,18 +77,18 @@ jobs:
|
||||
- name: Install OS build dependencies
|
||||
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
|
||||
|
||||
- name: Install rsrc
|
||||
run: go install github.com/akavel/rsrc@v0.10.2
|
||||
- name: Generate windows rsrc amd64
|
||||
run: rsrc -arch amd64 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_amd64.syso
|
||||
- name: Generate windows rsrc arm64
|
||||
run: rsrc -arch arm64 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_arm64.syso
|
||||
- name: Generate windows rsrc arm
|
||||
run: rsrc -arch arm -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_arm.syso
|
||||
- name: Generate windows rsrc 386
|
||||
run: rsrc -arch 386 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_386.syso
|
||||
-
|
||||
name: Run GoReleaser
|
||||
- name: Install goversioninfo
|
||||
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
|
||||
- name: Generate windows syso 386
|
||||
run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-build ${{ github.run_id }} -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -product-version ${{ steps.semver_parser.outputs.fullversion }}.${{ github.run_id }} -o client/resources_windows_386.syso
|
||||
- name: Generate windows syso arm
|
||||
run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-build ${{ github.run_id }} -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -product-version ${{ steps.semver_parser.outputs.fullversion }}.${{ github.run_id }} -o client/resources_windows_arm.syso
|
||||
- name: Generate windows syso arm64
|
||||
run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-build ${{ github.run_id }} -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -product-version ${{ steps.semver_parser.outputs.fullversion }}.${{ github.run_id }} -o client/resources_windows_arm64.syso
|
||||
- name: Generate windows syso amd64
|
||||
run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-build ${{ github.run_id }} -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -product-version ${{ steps.semver_parser.outputs.fullversion }}.${{ github.run_id }} -o client/resources_windows_amd64.syso
|
||||
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
@ -121,6 +130,13 @@ jobs:
|
||||
release_ui:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Parse semver string
|
||||
id: semver_parser
|
||||
uses: booxmedialtd/ws-action-parse-semver@v1
|
||||
with:
|
||||
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
|
||||
version_extractor_regex: '\/v(.*)$'
|
||||
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
- name: Checkout
|
||||
@ -151,10 +167,11 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
|
||||
- name: Install rsrc
|
||||
run: go install github.com/akavel/rsrc@v0.10.2
|
||||
- name: Generate windows rsrc
|
||||
run: rsrc -arch amd64 -ico client/ui/netbird.ico -manifest client/ui/manifest.xml -o client/ui/resources_windows_amd64.syso
|
||||
- name: Install goversioninfo
|
||||
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
|
||||
- name: Generate windows syso amd64
|
||||
run: goversioninfo -64 -icon client/ui/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-build ${{ github.run_id }} -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -product-version ${{ steps.semver_parser.outputs.fullversion }}.${{ github.run_id }} -o client/ui/resources_windows_amd64.syso
|
||||
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
with:
|
||||
|
@ -26,7 +26,7 @@ var downCmd = &cobra.Command{
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*3)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*7)
|
||||
defer cancel()
|
||||
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
|
@ -121,7 +121,7 @@ func init() {
|
||||
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
|
||||
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Netbird config file location")
|
||||
rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level")
|
||||
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout")
|
||||
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout. If syslog is specified the log will be sent to syslog daemon.")
|
||||
rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)")
|
||||
rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.")
|
||||
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
|
||||
|
@ -337,7 +337,6 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decode
|
||||
if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) {
|
||||
return rule.drop, true
|
||||
}
|
||||
return rule.drop, true
|
||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||
return rule.drop, true
|
||||
}
|
||||
|
@ -15,6 +15,12 @@ type hostManager interface {
|
||||
restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error
|
||||
}
|
||||
|
||||
type SystemDNSSettings struct {
|
||||
Domains []string
|
||||
ServerIP string
|
||||
ServerPort int
|
||||
}
|
||||
|
||||
type HostDNSConfig struct {
|
||||
Domains []DomainConfig `json:"domains"`
|
||||
RouteAll bool `json:"routeAll"`
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
@ -18,7 +19,7 @@ import (
|
||||
const (
|
||||
netbirdDNSStateKeyFormat = "State:/Network/Service/NetBird-%s/DNS"
|
||||
globalIPv4State = "State:/Network/Global/IPv4"
|
||||
primaryServiceSetupKeyFormat = "Setup:/Network/Service/%s/DNS"
|
||||
primaryServiceStateKeyFormat = "State:/Network/Service/%s/DNS"
|
||||
keySupplementalMatchDomains = "SupplementalMatchDomains"
|
||||
keySupplementalMatchDomainsNoSearch = "SupplementalMatchDomainsNoSearch"
|
||||
keyServerAddresses = "ServerAddresses"
|
||||
@ -28,12 +29,12 @@ const (
|
||||
scutilPath = "/usr/sbin/scutil"
|
||||
searchSuffix = "Search"
|
||||
matchSuffix = "Match"
|
||||
localSuffix = "Local"
|
||||
)
|
||||
|
||||
type systemConfigurator struct {
|
||||
// primaryServiceID primary interface in the system. AKA the interface with the default route
|
||||
primaryServiceID string
|
||||
createdKeys map[string]struct{}
|
||||
createdKeys map[string]struct{}
|
||||
systemDNSSettings SystemDNSSettings
|
||||
}
|
||||
|
||||
func newHostManager() (hostManager, error) {
|
||||
@ -49,20 +50,6 @@ func (s *systemConfigurator) supportCustomPort() bool {
|
||||
func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
||||
var err error
|
||||
|
||||
if config.RouteAll {
|
||||
err = s.addDNSSetupForAll(config.ServerIP, config.ServerPort)
|
||||
if err != nil {
|
||||
return fmt.Errorf("add dns setup for all: %w", err)
|
||||
}
|
||||
} else if s.primaryServiceID != "" {
|
||||
err = s.removeKeyFromSystemConfig(getKeyWithInput(primaryServiceSetupKeyFormat, s.primaryServiceID))
|
||||
if err != nil {
|
||||
return fmt.Errorf("remote key from system config: %w", err)
|
||||
}
|
||||
s.primaryServiceID = ""
|
||||
log.Infof("removed %s:%d as main DNS resolver for this peer", config.ServerIP, config.ServerPort)
|
||||
}
|
||||
|
||||
// create a file for unclean shutdown detection
|
||||
if err := createUncleanShutdownIndicator(); err != nil {
|
||||
log.Errorf("failed to create unclean shutdown file: %s", err)
|
||||
@ -73,6 +60,19 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
||||
matchDomains []string
|
||||
)
|
||||
|
||||
err = s.recordSystemDNSSettings(true)
|
||||
if err != nil {
|
||||
log.Errorf("unable to update record of System's DNS config: %s", err.Error())
|
||||
}
|
||||
|
||||
if config.RouteAll {
|
||||
searchDomains = append(searchDomains, "\"\"")
|
||||
err = s.addLocalDNS()
|
||||
if err != nil {
|
||||
log.Infof("failed to enable split DNS")
|
||||
}
|
||||
}
|
||||
|
||||
for _, dConf := range config.Domains {
|
||||
if dConf.Disabled {
|
||||
continue
|
||||
@ -110,23 +110,17 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) restoreHostDNS() error {
|
||||
lines := ""
|
||||
for key := range s.createdKeys {
|
||||
lines += buildRemoveKeyOperation(key)
|
||||
keys := s.getRemovableKeysWithDefaults()
|
||||
for _, key := range keys {
|
||||
keyType := "search"
|
||||
if strings.Contains(key, matchSuffix) {
|
||||
keyType = "match"
|
||||
}
|
||||
log.Infof("removing %s domains from system", keyType)
|
||||
}
|
||||
if s.primaryServiceID != "" {
|
||||
lines += buildRemoveKeyOperation(getKeyWithInput(primaryServiceSetupKeyFormat, s.primaryServiceID))
|
||||
log.Infof("restoring DNS resolver configuration for system")
|
||||
}
|
||||
_, err := runSystemConfigCommand(wrapCommand(lines))
|
||||
if err != nil {
|
||||
log.Errorf("got an error while cleaning the system configuration: %s", err)
|
||||
return fmt.Errorf("clean system: %w", err)
|
||||
err := s.removeKeyFromSystemConfig(key)
|
||||
if err != nil {
|
||||
log.Errorf("failed to remove %s domains from system: %s", keyType, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := removeUncleanShutdownIndicator(); err != nil {
|
||||
@ -136,6 +130,19 @@ func (s *systemConfigurator) restoreHostDNS() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) getRemovableKeysWithDefaults() []string {
|
||||
if len(s.createdKeys) == 0 {
|
||||
// return defaults for startup calls
|
||||
return []string{getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix), getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)}
|
||||
}
|
||||
|
||||
keys := make([]string, 0, len(s.createdKeys))
|
||||
for key := range s.createdKeys {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
|
||||
line := buildRemoveKeyOperation(key)
|
||||
_, err := runSystemConfigCommand(wrapCommand(line))
|
||||
@ -148,6 +155,97 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) addLocalDNS() error {
|
||||
if s.systemDNSSettings.ServerIP == "" || len(s.systemDNSSettings.Domains) == 0 {
|
||||
err := s.recordSystemDNSSettings(true)
|
||||
log.Errorf("Unable to get system DNS configuration")
|
||||
return err
|
||||
}
|
||||
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
|
||||
if s.systemDNSSettings.ServerIP != "" && len(s.systemDNSSettings.Domains) != 0 {
|
||||
err := s.addSearchDomains(localKey, strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort)
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't add local network DNS conf: %w", err)
|
||||
}
|
||||
} else {
|
||||
log.Info("Not enabling local DNS server")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) recordSystemDNSSettings(force bool) error {
|
||||
if s.systemDNSSettings.ServerIP != "" && len(s.systemDNSSettings.Domains) != 0 && !force {
|
||||
return nil
|
||||
}
|
||||
|
||||
systemDNSSettings, err := s.getSystemDNSSettings()
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't get current DNS config: %w", err)
|
||||
}
|
||||
s.systemDNSSettings = systemDNSSettings
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
|
||||
primaryServiceKey, _, err := s.getPrimaryService()
|
||||
if err != nil || primaryServiceKey == "" {
|
||||
return SystemDNSSettings{}, fmt.Errorf("couldn't find the primary service key: %w", err)
|
||||
}
|
||||
dnsServiceKey := getKeyWithInput(primaryServiceStateKeyFormat, primaryServiceKey)
|
||||
line := buildCommandLine("show", dnsServiceKey, "")
|
||||
stdinCommands := wrapCommand(line)
|
||||
|
||||
b, err := runSystemConfigCommand(stdinCommands)
|
||||
if err != nil {
|
||||
return SystemDNSSettings{}, fmt.Errorf("sending the command: %w", err)
|
||||
}
|
||||
|
||||
var dnsSettings SystemDNSSettings
|
||||
inSearchDomainsArray := false
|
||||
inServerAddressesArray := false
|
||||
|
||||
scanner := bufio.NewScanner(bytes.NewReader(b))
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
switch {
|
||||
case strings.HasPrefix(line, "DomainName :"):
|
||||
domainName := strings.TrimSpace(strings.Split(line, ":")[1])
|
||||
dnsSettings.Domains = append(dnsSettings.Domains, domainName)
|
||||
case line == "SearchDomains : <array> {":
|
||||
inSearchDomainsArray = true
|
||||
continue
|
||||
case line == "ServerAddresses : <array> {":
|
||||
inServerAddressesArray = true
|
||||
continue
|
||||
case line == "}":
|
||||
inSearchDomainsArray = false
|
||||
inServerAddressesArray = false
|
||||
}
|
||||
|
||||
if inSearchDomainsArray {
|
||||
searchDomain := strings.Split(line, " : ")[1]
|
||||
dnsSettings.Domains = append(dnsSettings.Domains, searchDomain)
|
||||
} else if inServerAddressesArray {
|
||||
address := strings.Split(line, " : ")[1]
|
||||
if ip := net.ParseIP(address); ip != nil && ip.To4() != nil {
|
||||
dnsSettings.ServerIP = address
|
||||
inServerAddressesArray = false // Stop reading after finding the first IPv4 address
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return dnsSettings, err
|
||||
}
|
||||
|
||||
// default to 53 port
|
||||
dnsSettings.ServerPort = 53
|
||||
|
||||
return dnsSettings, nil
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) addSearchDomains(key, domains string, ip string, port int) error {
|
||||
err := s.addDNSState(key, domains, ip, port, true)
|
||||
if err != nil {
|
||||
@ -194,23 +292,6 @@ func (s *systemConfigurator) addDNSState(state, domains, dnsServer string, port
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) addDNSSetupForAll(dnsServer string, port int) error {
|
||||
primaryServiceKey, existingNameserver, err := s.getPrimaryService()
|
||||
if err != nil || primaryServiceKey == "" {
|
||||
return fmt.Errorf("couldn't find the primary service key: %w", err)
|
||||
}
|
||||
|
||||
err = s.addDNSSetup(getKeyWithInput(primaryServiceSetupKeyFormat, primaryServiceKey), dnsServer, port, existingNameserver)
|
||||
if err != nil {
|
||||
return fmt.Errorf("add dns setup: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("configured %s:%d as main DNS resolver for this peer", dnsServer, port)
|
||||
s.primaryServiceID = primaryServiceKey
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) getPrimaryService() (string, string, error) {
|
||||
line := buildCommandLine("show", globalIPv4State, "")
|
||||
stdinCommands := wrapCommand(line)
|
||||
@ -239,19 +320,6 @@ func (s *systemConfigurator) getPrimaryService() (string, string, error) {
|
||||
return primaryService, router, nil
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) addDNSSetup(setupKey, dnsServer string, port int, existingDNSServer string) error {
|
||||
lines := buildAddCommandLine(keySupplementalMatchDomainsNoSearch, digitSymbol+strconv.Itoa(0))
|
||||
lines += buildAddCommandLine(keyServerAddresses, arraySymbol+dnsServer+" "+existingDNSServer)
|
||||
lines += buildAddCommandLine(keyServerPort, digitSymbol+strconv.Itoa(port))
|
||||
addDomainCommand := buildCreateStateWithOperation(setupKey, lines)
|
||||
stdinCommands := wrapCommand(addDomainCommand)
|
||||
_, err := runSystemConfigCommand(stdinCommands)
|
||||
if err != nil {
|
||||
return fmt.Errorf("applying dns setup, error: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error {
|
||||
if err := s.restoreHostDNS(); err != nil {
|
||||
return fmt.Errorf("restoring dns via scutil: %w", err)
|
||||
|
@ -24,7 +24,7 @@ const (
|
||||
probeTimeout = 2 * time.Second
|
||||
)
|
||||
|
||||
const testRecord = "."
|
||||
const testRecord = "com."
|
||||
|
||||
type upstreamClient interface {
|
||||
exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
|
||||
@ -42,6 +42,7 @@ type upstreamResolverBase struct {
|
||||
upstreamServers []string
|
||||
disabled bool
|
||||
failsCount atomic.Int32
|
||||
successCount atomic.Int32
|
||||
failsTillDeact int32
|
||||
mutex sync.Mutex
|
||||
reactivatePeriod time.Duration
|
||||
@ -124,6 +125,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
return
|
||||
}
|
||||
|
||||
u.successCount.Add(1)
|
||||
log.Tracef("took %s to query the upstream %s", t, upstream)
|
||||
|
||||
err = w.WriteMsg(rm)
|
||||
@ -172,6 +174,11 @@ func (u *upstreamResolverBase) probeAvailability() {
|
||||
default:
|
||||
}
|
||||
|
||||
// avoid probe if upstreams could resolve at least one query and fails count is less than failsTillDeact
|
||||
if u.successCount.Load() > 0 && u.failsCount.Load() < u.failsTillDeact {
|
||||
return
|
||||
}
|
||||
|
||||
var success bool
|
||||
var mu sync.Mutex
|
||||
var wg sync.WaitGroup
|
||||
@ -183,7 +190,7 @@ func (u *upstreamResolverBase) probeAvailability() {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err := u.testNameserver(upstream)
|
||||
err := u.testNameserver(upstream, 500*time.Millisecond)
|
||||
if err != nil {
|
||||
errors = multierror.Append(errors, err)
|
||||
log.Warnf("probing upstream nameserver %s: %s", upstream, err)
|
||||
@ -224,7 +231,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
|
||||
}
|
||||
|
||||
for _, upstream := range u.upstreamServers {
|
||||
if err := u.testNameserver(upstream); err != nil {
|
||||
if err := u.testNameserver(upstream, probeTimeout); err != nil {
|
||||
log.Tracef("upstream check for %s: %s", upstream, err)
|
||||
} else {
|
||||
// at least one upstream server is available, stop probing
|
||||
@ -244,6 +251,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
|
||||
|
||||
log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServers)
|
||||
u.failsCount.Store(0)
|
||||
u.successCount.Add(1)
|
||||
u.reactivate()
|
||||
u.disabled = false
|
||||
}
|
||||
@ -265,13 +273,14 @@ func (u *upstreamResolverBase) disable(err error) {
|
||||
}
|
||||
|
||||
log.Warnf("Upstream resolving is Disabled for %v", reactivatePeriod)
|
||||
u.successCount.Store(0)
|
||||
u.deactivate(err)
|
||||
u.disabled = true
|
||||
go u.waitUntilResponse()
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) testNameserver(server string) error {
|
||||
ctx, cancel := context.WithTimeout(u.ctx, probeTimeout)
|
||||
func (u *upstreamResolverBase) testNameserver(server string, timeout time.Duration) error {
|
||||
ctx, cancel := context.WithTimeout(u.ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA)
|
||||
|
@ -4,6 +4,7 @@ package networkmonitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
@ -21,11 +22,20 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
|
||||
return fmt.Errorf("failed to open routing socket: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := unix.Close(fd); err != nil {
|
||||
err := unix.Close(fd)
|
||||
if err != nil && !errors.Is(err, unix.EBADF) {
|
||||
log.Errorf("Network monitor: failed to close routing socket: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
err := unix.Close(fd)
|
||||
if err != nil && !errors.Is(err, unix.EBADF) {
|
||||
log.Debugf("Network monitor: closed routing socket")
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@ -34,7 +44,9 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
|
||||
buf := make([]byte, 2048)
|
||||
n, err := unix.Read(fd, buf)
|
||||
if err != nil {
|
||||
log.Errorf("Network monitor: failed to read from routing socket: %v", err)
|
||||
if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) {
|
||||
log.Errorf("Network monitor: failed to read from routing socket: %v", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if n < unix.SizeofRtMsghdr {
|
||||
|
@ -99,6 +99,11 @@ func routeChanged(nexthop systemops.Nexthop, intf *net.Interface, routes []syste
|
||||
return false
|
||||
}
|
||||
|
||||
if isSoftInterface(nexthop.Intf.Name) {
|
||||
log.Tracef("network monitor: ignoring default route change for soft interface %s", nexthop.Intf.Name)
|
||||
return false
|
||||
}
|
||||
|
||||
unspec := getUnspecifiedPrefix(nexthop.IP)
|
||||
defaultRoutes, foundMatchingRoute := processRoutes(nexthop, intf, routes, unspec)
|
||||
|
||||
@ -119,7 +124,7 @@ func getUnspecifiedPrefix(ip netip.Addr) netip.Prefix {
|
||||
return netip.PrefixFrom(netip.IPv4Unspecified(), 0)
|
||||
}
|
||||
|
||||
func processRoutes(nexthop systemops.Nexthop, intf *net.Interface, routes []systemops.Route, unspec netip.Prefix) ([]string, bool) {
|
||||
func processRoutes(nexthop systemops.Nexthop, nexthopIntf *net.Interface, routes []systemops.Route, unspec netip.Prefix) ([]string, bool) {
|
||||
var defaultRoutes []string
|
||||
foundMatchingRoute := false
|
||||
|
||||
@ -128,7 +133,7 @@ func processRoutes(nexthop systemops.Nexthop, intf *net.Interface, routes []syst
|
||||
routeInfo := formatRouteInfo(r)
|
||||
defaultRoutes = append(defaultRoutes, routeInfo)
|
||||
|
||||
if r.Nexthop == nexthop.IP && compareIntf(r.Interface, intf) == 0 {
|
||||
if r.Nexthop == nexthop.IP && compareIntf(r.Interface, nexthopIntf) == 0 {
|
||||
foundMatchingRoute = true
|
||||
log.Debugf("network monitor: found matching default route: %s", routeInfo)
|
||||
}
|
||||
@ -232,14 +237,18 @@ func stateFromInt(state uint8) string {
|
||||
}
|
||||
|
||||
func compareIntf(a, b *net.Interface) int {
|
||||
if a == nil && b == nil {
|
||||
switch {
|
||||
case a == nil && b == nil:
|
||||
return 0
|
||||
}
|
||||
if a == nil {
|
||||
case a == nil:
|
||||
return -1
|
||||
}
|
||||
if b == nil {
|
||||
case b == nil:
|
||||
return 1
|
||||
default:
|
||||
return a.Index - b.Index
|
||||
}
|
||||
return a.Index - b.Index
|
||||
}
|
||||
|
||||
func isSoftInterface(name string) bool {
|
||||
return strings.Contains(strings.ToLower(name), "isatap") || strings.Contains(strings.ToLower(name), "teredo")
|
||||
}
|
||||
|
@ -3,6 +3,7 @@ package routemanager
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
@ -303,22 +304,33 @@ func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) {
|
||||
}()
|
||||
}
|
||||
|
||||
func (c *clientNetwork) handleUpdate(update routesUpdate) {
|
||||
func (c *clientNetwork) handleUpdate(update routesUpdate) bool {
|
||||
isUpdateMapDifferent := false
|
||||
updateMap := make(map[route.ID]*route.Route)
|
||||
|
||||
for _, r := range update.routes {
|
||||
updateMap[r.ID] = r
|
||||
}
|
||||
|
||||
if len(c.routes) != len(updateMap) {
|
||||
isUpdateMapDifferent = true
|
||||
}
|
||||
|
||||
for id, r := range c.routes {
|
||||
_, found := updateMap[id]
|
||||
if !found {
|
||||
close(c.routePeersNotifiers[r.Peer])
|
||||
delete(c.routePeersNotifiers, r.Peer)
|
||||
isUpdateMapDifferent = true
|
||||
continue
|
||||
}
|
||||
if !reflect.DeepEqual(c.routes[id], updateMap[id]) {
|
||||
isUpdateMapDifferent = true
|
||||
}
|
||||
}
|
||||
|
||||
c.routes = updateMap
|
||||
return isUpdateMapDifferent
|
||||
}
|
||||
|
||||
// peersStateAndUpdateWatcher is the main point of reacting on client network routing events.
|
||||
@ -345,13 +357,19 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
|
||||
|
||||
log.Debugf("Received a new client network route update for [%v]", c.handler)
|
||||
|
||||
c.handleUpdate(update)
|
||||
// hash update somehow
|
||||
isTrueRouteUpdate := c.handleUpdate(update)
|
||||
|
||||
c.updateSerial = update.updateSerial
|
||||
|
||||
err := c.recalculateRouteAndUpdatePeerAndSystem()
|
||||
if err != nil {
|
||||
log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err)
|
||||
if isTrueRouteUpdate {
|
||||
log.Debug("Client network update contains different routes, recalculating routes")
|
||||
err := c.recalculateRouteAndUpdatePeerAndSystem()
|
||||
if err != nil {
|
||||
log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err)
|
||||
}
|
||||
} else {
|
||||
log.Debug("Route update is not different, skipping route recalculation")
|
||||
}
|
||||
|
||||
c.startPeersStatusChangeWatcher()
|
||||
|
@ -16,6 +16,7 @@ import (
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
@ -50,7 +51,7 @@ type DefaultManager struct {
|
||||
statusRecorder *peer.Status
|
||||
wgInterface iface.IWGIface
|
||||
pubKey string
|
||||
notifier *notifier
|
||||
notifier *notifier.Notifier
|
||||
routeRefCounter *refcounter.RouteRefCounter
|
||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter
|
||||
dnsRouteInterval time.Duration
|
||||
@ -65,7 +66,8 @@ func NewManager(
|
||||
initialRoutes []*route.Route,
|
||||
) *DefaultManager {
|
||||
mCTX, cancel := context.WithCancel(ctx)
|
||||
sysOps := systemops.NewSysOps(wgInterface)
|
||||
notifier := notifier.NewNotifier()
|
||||
sysOps := systemops.NewSysOps(wgInterface, notifier)
|
||||
|
||||
dm := &DefaultManager{
|
||||
ctx: mCTX,
|
||||
@ -77,7 +79,7 @@ func NewManager(
|
||||
statusRecorder: statusRecorder,
|
||||
wgInterface: wgInterface,
|
||||
pubKey: pubKey,
|
||||
notifier: newNotifier(),
|
||||
notifier: notifier,
|
||||
}
|
||||
|
||||
dm.routeRefCounter = refcounter.New(
|
||||
@ -107,7 +109,7 @@ func NewManager(
|
||||
|
||||
if runtime.GOOS == "android" {
|
||||
cr := dm.clientRoutes(initialRoutes)
|
||||
dm.notifier.setInitialClientRoutes(cr)
|
||||
dm.notifier.SetInitialClientRoutes(cr)
|
||||
}
|
||||
return dm
|
||||
}
|
||||
@ -186,7 +188,7 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
|
||||
|
||||
filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap)
|
||||
m.updateClientNetworks(updateSerial, filteredClientRoutes)
|
||||
m.notifier.onNewRoutes(filteredClientRoutes)
|
||||
m.notifier.OnNewRoutes(filteredClientRoutes)
|
||||
|
||||
if m.serverRouter != nil {
|
||||
err := m.serverRouter.updateRoutes(newServerRoutesMap)
|
||||
@ -199,14 +201,14 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
|
||||
}
|
||||
}
|
||||
|
||||
// SetRouteChangeListener set RouteListener for route change notifier
|
||||
// SetRouteChangeListener set RouteListener for route change Notifier
|
||||
func (m *DefaultManager) SetRouteChangeListener(listener listener.NetworkChangeListener) {
|
||||
m.notifier.setListener(listener)
|
||||
m.notifier.SetListener(listener)
|
||||
}
|
||||
|
||||
// InitialRouteRange return the list of initial routes. It used by mobile systems
|
||||
func (m *DefaultManager) InitialRouteRange() []string {
|
||||
return m.notifier.getInitialRouteRanges()
|
||||
return m.notifier.GetInitialRouteRanges()
|
||||
}
|
||||
|
||||
// GetRouteSelector returns the route selector
|
||||
@ -226,7 +228,7 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
|
||||
|
||||
networks = m.routeSelector.FilterSelected(networks)
|
||||
|
||||
m.notifier.onNewRoutes(networks)
|
||||
m.notifier.OnNewRoutes(networks)
|
||||
|
||||
m.stopObsoleteClients(networks)
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
package routemanager
|
||||
package notifier
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
@ -10,7 +11,7 @@ import (
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type notifier struct {
|
||||
type Notifier struct {
|
||||
initialRouteRanges []string
|
||||
routeRanges []string
|
||||
|
||||
@ -18,17 +19,17 @@ type notifier struct {
|
||||
listenerMux sync.Mutex
|
||||
}
|
||||
|
||||
func newNotifier() *notifier {
|
||||
return ¬ifier{}
|
||||
func NewNotifier() *Notifier {
|
||||
return &Notifier{}
|
||||
}
|
||||
|
||||
func (n *notifier) setListener(listener listener.NetworkChangeListener) {
|
||||
func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
|
||||
n.listenerMux.Lock()
|
||||
defer n.listenerMux.Unlock()
|
||||
n.listener = listener
|
||||
}
|
||||
|
||||
func (n *notifier) setInitialClientRoutes(clientRoutes []*route.Route) {
|
||||
func (n *Notifier) SetInitialClientRoutes(clientRoutes []*route.Route) {
|
||||
nets := make([]string, 0)
|
||||
for _, r := range clientRoutes {
|
||||
nets = append(nets, r.Network.String())
|
||||
@ -37,7 +38,10 @@ func (n *notifier) setInitialClientRoutes(clientRoutes []*route.Route) {
|
||||
n.initialRouteRanges = nets
|
||||
}
|
||||
|
||||
func (n *notifier) onNewRoutes(idMap route.HAMap) {
|
||||
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
|
||||
if runtime.GOOS != "android" {
|
||||
return
|
||||
}
|
||||
newNets := make([]string, 0)
|
||||
for _, routes := range idMap {
|
||||
for _, r := range routes {
|
||||
@ -62,7 +66,30 @@ func (n *notifier) onNewRoutes(idMap route.HAMap) {
|
||||
n.notify()
|
||||
}
|
||||
|
||||
func (n *notifier) notify() {
|
||||
func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
|
||||
newNets := make([]string, 0)
|
||||
for _, prefix := range prefixes {
|
||||
newNets = append(newNets, prefix.String())
|
||||
}
|
||||
|
||||
sort.Strings(newNets)
|
||||
switch runtime.GOOS {
|
||||
case "android":
|
||||
if !n.hasDiff(n.initialRouteRanges, newNets) {
|
||||
return
|
||||
}
|
||||
default:
|
||||
if !n.hasDiff(n.routeRanges, newNets) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
n.routeRanges = newNets
|
||||
|
||||
n.notify()
|
||||
}
|
||||
|
||||
func (n *Notifier) notify() {
|
||||
n.listenerMux.Lock()
|
||||
defer n.listenerMux.Unlock()
|
||||
if n.listener == nil {
|
||||
@ -74,7 +101,7 @@ func (n *notifier) notify() {
|
||||
}(n.listener)
|
||||
}
|
||||
|
||||
func (n *notifier) hasDiff(a []string, b []string) bool {
|
||||
func (n *Notifier) hasDiff(a []string, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return true
|
||||
}
|
||||
@ -86,7 +113,7 @@ func (n *notifier) hasDiff(a []string, b []string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (n *notifier) getInitialRouteRanges() []string {
|
||||
func (n *Notifier) GetInitialRouteRanges() []string {
|
||||
return addIPv6RangeIfNeeded(n.initialRouteRanges)
|
||||
}
|
||||
|
@ -3,7 +3,9 @@ package systemops
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
@ -18,10 +20,19 @@ type ExclusionCounter = refcounter.Counter[any, Nexthop]
|
||||
type SysOps struct {
|
||||
refCounter *ExclusionCounter
|
||||
wgInterface iface.IWGIface
|
||||
// prefixes is tracking all the current added prefixes im memory
|
||||
// (this is used in iOS as all route updates require a full table update)
|
||||
//nolint
|
||||
prefixes map[netip.Prefix]struct{}
|
||||
//nolint
|
||||
mu sync.Mutex
|
||||
// notifier is used to notify the system of route changes (also used on mobile)
|
||||
notifier *notifier.Notifier
|
||||
}
|
||||
|
||||
func NewSysOps(wgInterface iface.IWGIface) *SysOps {
|
||||
func NewSysOps(wgInterface iface.IWGIface, notifier *notifier.Notifier) *SysOps {
|
||||
return &SysOps{
|
||||
wgInterface: wgInterface,
|
||||
notifier: notifier,
|
||||
}
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
//go:build ios || android
|
||||
//go:build android
|
||||
|
||||
package systemops
|
||||
|
@ -36,7 +36,7 @@ func TestConcurrentRoutes(t *testing.T) {
|
||||
baseIP := netip.MustParseAddr("192.0.2.0")
|
||||
intf := &net.Interface{Name: "lo0"}
|
||||
|
||||
r := NewSysOps(nil)
|
||||
r := NewSysOps(nil, nil)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 1024; i++ {
|
||||
|
@ -50,7 +50,7 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbn
|
||||
nexthop, err := r.addRouteToNonVPNIntf(prefix, r.wgInterface, initialNexthop)
|
||||
if errors.Is(err, vars.ErrRouteNotAllowed) || errors.Is(err, vars.ErrRouteNotFound) {
|
||||
log.Tracef("Adding for prefix %s: %v", prefix, err)
|
||||
// These errors are not critical but also we should not track and try to remove the routes either.
|
||||
// These errors are not critical, but also we should not track and try to remove the routes either.
|
||||
return nexthop, refcounter.ErrIgnore
|
||||
}
|
||||
return nexthop, err
|
||||
@ -135,6 +135,11 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf iface.IWGIfac
|
||||
return Nexthop{}, vars.ErrRouteNotAllowed
|
||||
}
|
||||
|
||||
// Check if the prefix is part of any local subnets
|
||||
if isLocal, subnet := r.isPrefixInLocalSubnets(prefix); isLocal {
|
||||
return Nexthop{}, fmt.Errorf("prefix %s is part of local subnet %s: %w", prefix, subnet, vars.ErrRouteNotAllowed)
|
||||
}
|
||||
|
||||
// Determine the exit interface and next hop for the prefix, so we can add a specific route
|
||||
nexthop, err := GetNextHop(addr)
|
||||
if err != nil {
|
||||
@ -167,6 +172,36 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf iface.IWGIfac
|
||||
return exitNextHop, nil
|
||||
}
|
||||
|
||||
func (r *SysOps) isPrefixInLocalSubnets(prefix netip.Prefix) (bool, *net.IPNet) {
|
||||
localInterfaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get local interfaces: %v", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
for _, intf := range localInterfaces {
|
||||
addrs, err := intf.Addrs()
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get addresses for interface %s: %v", intf.Name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, addr := range addrs {
|
||||
ipnet, ok := addr.(*net.IPNet)
|
||||
if !ok {
|
||||
log.Errorf("Failed to convert address to IPNet: %v", addr)
|
||||
continue
|
||||
}
|
||||
|
||||
if ipnet.Contains(prefix.Addr().AsSlice()) {
|
||||
return true, ipnet
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
|
||||
// in two /1 prefixes to avoid replacing the existing default route
|
||||
func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
|
@ -68,7 +68,7 @@ func TestAddRemoveRoutes(t *testing.T) {
|
||||
err = wgInterface.Create()
|
||||
require.NoError(t, err, "should create testing wireguard interface")
|
||||
|
||||
r := NewSysOps(wgInterface)
|
||||
r := NewSysOps(wgInterface, nil)
|
||||
|
||||
_, _, err = r.SetupRouting(nil)
|
||||
require.NoError(t, err)
|
||||
@ -224,7 +224,7 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
|
||||
require.NoError(t, err, "InterfaceByName should not return err")
|
||||
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
|
||||
|
||||
r := NewSysOps(wgInterface)
|
||||
r := NewSysOps(wgInterface, nil)
|
||||
|
||||
// Prepare the environment
|
||||
if testCase.preExistingPrefix.IsValid() {
|
||||
@ -379,7 +379,7 @@ func setupTestEnv(t *testing.T) {
|
||||
assert.NoError(t, wgInterface.Close())
|
||||
})
|
||||
|
||||
r := NewSysOps(wgInterface)
|
||||
r := NewSysOps(wgInterface, nil)
|
||||
_, _, err := r.SetupRouting(nil)
|
||||
require.NoError(t, err, "setupRouting should not return err")
|
||||
t.Cleanup(func() {
|
||||
|
64
client/internal/routemanager/systemops/systemops_ios.go
Normal file
64
client/internal/routemanager/systemops/systemops_ios.go
Normal file
@ -0,0 +1,64 @@
|
||||
//go:build ios
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
func (r *SysOps) SetupRouting([]net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.prefixes = make(map[netip.Prefix]struct{})
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
func (r *SysOps) CleanupRouting() error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
r.prefixes = make(map[netip.Prefix]struct{})
|
||||
r.notify()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, _ *net.Interface) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
r.prefixes[prefix] = struct{}{}
|
||||
r.notify()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, _ *net.Interface) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
delete(r.prefixes, prefix)
|
||||
r.notify()
|
||||
return nil
|
||||
}
|
||||
|
||||
func EnableIPForwarding() error {
|
||||
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
|
||||
return nil
|
||||
}
|
||||
|
||||
func IsAddrRouted(netip.Addr, []netip.Prefix) (bool, netip.Prefix) {
|
||||
return false, netip.Prefix{}
|
||||
}
|
||||
|
||||
func (r *SysOps) notify() {
|
||||
prefixes := make([]netip.Prefix, 0, len(r.prefixes))
|
||||
for prefix := range r.prefixes {
|
||||
prefixes = append(prefixes, prefix)
|
||||
}
|
||||
r.notifier.OnNewPrefixes(prefixes)
|
||||
}
|
@ -73,7 +73,7 @@ var testCases = []testCase{
|
||||
{
|
||||
name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence
|
||||
destination: "10.0.0.2:53",
|
||||
expectedSourceIP: "10.0.0.1",
|
||||
expectedSourceIP: "127.0.0.1",
|
||||
expectedDestPrefix: "10.0.0.0/8",
|
||||
expectedNextHop: "0.0.0.0",
|
||||
expectedInterface: "Loopback Pseudo-Interface 1",
|
||||
@ -110,7 +110,7 @@ var testCases = []testCase{
|
||||
{
|
||||
name: "To more specific route (local) without custom dialer via physical interface",
|
||||
destination: "127.0.10.2:53",
|
||||
expectedSourceIP: "10.0.0.1",
|
||||
expectedSourceIP: "127.0.0.1",
|
||||
expectedDestPrefix: "127.0.0.0/8",
|
||||
expectedNextHop: "0.0.0.0",
|
||||
expectedInterface: "Loopback Pseudo-Interface 1",
|
||||
@ -181,31 +181,6 @@ func testRoute(t *testing.T, destination string, dialer dialer) *FindNetRouteOut
|
||||
return combinedOutput
|
||||
}
|
||||
|
||||
func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string {
|
||||
t.Helper()
|
||||
|
||||
ip, ipNet, err := net.ParseCIDR(ipAddressCIDR)
|
||||
require.NoError(t, err)
|
||||
subnetMaskSize, _ := ipNet.Mask.Size()
|
||||
script := fmt.Sprintf(`New-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -PrefixLength %d -PolicyStore ActiveStore -Confirm:$False`, interfaceName, ip.String(), subnetMaskSize)
|
||||
_, err = exec.Command("powershell", "-Command", script).CombinedOutput()
|
||||
require.NoError(t, err, "Failed to assign IP address to loopback adapter")
|
||||
|
||||
// Wait for the IP address to be applied
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
err = waitForIPAddress(ctx, interfaceName, ip.String())
|
||||
require.NoError(t, err, "IP address not applied within timeout")
|
||||
|
||||
t.Cleanup(func() {
|
||||
script = fmt.Sprintf(`Remove-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -Confirm:$False`, interfaceName, ip.String())
|
||||
_, err = exec.Command("powershell", "-Command", script).CombinedOutput()
|
||||
require.NoError(t, err, "Failed to remove IP address from loopback adapter")
|
||||
})
|
||||
|
||||
return interfaceName
|
||||
}
|
||||
|
||||
func fetchOriginalGateway() (*RouteInfo, error) {
|
||||
cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object Nexthop, RouteMetric, InterfaceAlias | ConvertTo-Json")
|
||||
output, err := cmd.CombinedOutput()
|
||||
@ -231,30 +206,6 @@ func verifyOutput(t *testing.T, output *FindNetRouteOutput, sourceIP, destPrefix
|
||||
assert.Equal(t, intf, output.InterfaceAlias, "Interface mismatch")
|
||||
}
|
||||
|
||||
func waitForIPAddress(ctx context.Context, interfaceAlias, expectedIPAddress string) error {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
out, err := exec.Command("powershell", "-Command", fmt.Sprintf(`Get-NetIPAddress -InterfaceAlias "%s" | Select-Object -ExpandProperty IPAddress`, interfaceAlias)).CombinedOutput()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ipAddresses := strings.Split(strings.TrimSpace(string(out)), "\n")
|
||||
for _, ip := range ipAddresses {
|
||||
if strings.TrimSpace(ip) == expectedIPAddress {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func combineOutputs(outputs []FindNetRouteOutput) *FindNetRouteOutput {
|
||||
var combined FindNetRouteOutput
|
||||
|
||||
@ -285,5 +236,25 @@ func combineOutputs(outputs []FindNetRouteOutput) *FindNetRouteOutput {
|
||||
func setupDummyInterfacesAndRoutes(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
createAndSetupDummyInterface(t, "Loopback Pseudo-Interface 1", "10.0.0.1/8")
|
||||
addDummyRoute(t, "10.0.0.0/8")
|
||||
}
|
||||
|
||||
func addDummyRoute(t *testing.T, dstCIDR string) {
|
||||
t.Helper()
|
||||
|
||||
script := fmt.Sprintf(`New-NetRoute -DestinationPrefix "%s" -InterfaceIndex 1 -PolicyStore ActiveStore`, dstCIDR)
|
||||
|
||||
output, err := exec.Command("powershell", "-Command", script).CombinedOutput()
|
||||
if err != nil {
|
||||
t.Logf("Failed to add dummy route: %v\nOutput: %s", err, output)
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
script = fmt.Sprintf(`Remove-NetRoute -DestinationPrefix "%s" -InterfaceIndex 1 -Confirm:$false`, dstCIDR)
|
||||
output, err := exec.Command("powershell", "-Command", script).CombinedOutput()
|
||||
if err != nil {
|
||||
t.Logf("Failed to remove dummy route: %v\nOutput: %s", err, output)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -19,6 +19,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
@ -47,6 +48,7 @@ type CustomLogger interface {
|
||||
type selectRoute struct {
|
||||
NetID string
|
||||
Network netip.Prefix
|
||||
Domains domain.List
|
||||
Selected bool
|
||||
}
|
||||
|
||||
@ -279,6 +281,7 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
|
||||
route := &selectRoute{
|
||||
NetID: string(id),
|
||||
Network: rt[0].Network,
|
||||
Domains: rt[0].Domains,
|
||||
Selected: routeSelector.IsSelected(id),
|
||||
}
|
||||
routes = append(routes, route)
|
||||
@ -299,17 +302,40 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
|
||||
return iPrefix < jPrefix
|
||||
})
|
||||
|
||||
resolvedDomains := c.recorder.GetResolvedDomainsStates()
|
||||
|
||||
return prepareRouteSelectionDetails(routes, resolvedDomains), nil
|
||||
|
||||
}
|
||||
|
||||
func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[domain.Domain][]netip.Prefix) *RoutesSelectionDetails {
|
||||
var routeSelection []RoutesSelectionInfo
|
||||
for _, r := range routes {
|
||||
domainList := make([]DomainInfo, 0)
|
||||
for _, d := range r.Domains {
|
||||
domainResp := DomainInfo{
|
||||
Domain: d.SafeString(),
|
||||
}
|
||||
if prefixes, exists := resolvedDomains[d]; exists {
|
||||
var ipStrings []string
|
||||
for _, prefix := range prefixes {
|
||||
ipStrings = append(ipStrings, prefix.Addr().String())
|
||||
}
|
||||
domainResp.ResolvedIPs = strings.Join(ipStrings, ", ")
|
||||
}
|
||||
domainList = append(domainList, domainResp)
|
||||
}
|
||||
domainDetails := DomainDetails{items: domainList}
|
||||
routeSelection = append(routeSelection, RoutesSelectionInfo{
|
||||
ID: r.NetID,
|
||||
Network: r.Network.String(),
|
||||
Domains: &domainDetails,
|
||||
Selected: r.Selected,
|
||||
})
|
||||
}
|
||||
|
||||
routeSelectionDetails := RoutesSelectionDetails{items: routeSelection}
|
||||
return &routeSelectionDetails, nil
|
||||
return &routeSelectionDetails
|
||||
}
|
||||
|
||||
func (c *Client) SelectRoute(id string) error {
|
||||
|
@ -16,9 +16,25 @@ type RoutesSelectionDetails struct {
|
||||
type RoutesSelectionInfo struct {
|
||||
ID string
|
||||
Network string
|
||||
Domains *DomainDetails
|
||||
Selected bool
|
||||
}
|
||||
|
||||
type DomainCollection interface {
|
||||
Add(s DomainInfo) DomainCollection
|
||||
Get(i int) *DomainInfo
|
||||
Size() int
|
||||
}
|
||||
|
||||
type DomainDetails struct {
|
||||
items []DomainInfo
|
||||
}
|
||||
|
||||
type DomainInfo struct {
|
||||
Domain string
|
||||
ResolvedIPs string
|
||||
}
|
||||
|
||||
// Add new PeerInfo to the collection
|
||||
func (array RoutesSelectionDetails) Add(s RoutesSelectionInfo) RoutesSelectionDetails {
|
||||
array.items = append(array.items, s)
|
||||
@ -34,3 +50,16 @@ func (array RoutesSelectionDetails) Get(i int) *RoutesSelectionInfo {
|
||||
func (array RoutesSelectionDetails) Size() int {
|
||||
return len(array.items)
|
||||
}
|
||||
|
||||
func (array DomainDetails) Add(s DomainInfo) DomainCollection {
|
||||
array.items = append(array.items, s)
|
||||
return array
|
||||
}
|
||||
|
||||
func (array DomainDetails) Get(i int) *DomainInfo {
|
||||
return &array.items[i]
|
||||
}
|
||||
|
||||
func (array DomainDetails) Size() int {
|
||||
return len(array.items)
|
||||
}
|
||||
|
@ -582,7 +582,7 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
|
||||
}
|
||||
|
||||
// Down engine work in the daemon.
|
||||
func (s *Server) Down(_ context.Context, _ *proto.DownRequest) (*proto.DownResponse, error) {
|
||||
func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownResponse, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
@ -593,7 +593,25 @@ func (s *Server) Down(_ context.Context, _ *proto.DownRequest) (*proto.DownRespo
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
state.Set(internal.StatusIdle)
|
||||
|
||||
return &proto.DownResponse{}, nil
|
||||
maxWaitTime := 5 * time.Second
|
||||
timeout := time.After(maxWaitTime)
|
||||
|
||||
engine := s.connectClient.Engine()
|
||||
|
||||
for {
|
||||
if !engine.IsWGIfaceUp() {
|
||||
return &proto.DownResponse{}, nil
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return &proto.DownResponse{}, nil
|
||||
case <-timeout:
|
||||
return nil, fmt.Errorf("failed to shut down properly")
|
||||
default:
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Status returns the daemon status
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"context"
|
||||
"os"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
@ -89,9 +90,17 @@ func _getInfo() string {
|
||||
func sysInfo() (serialNumber string, productName string, manufacturer string) {
|
||||
var si sysinfo.SysInfo
|
||||
si.GetSysInfo()
|
||||
isascii := regexp.MustCompile("^[[:ascii:]]+$")
|
||||
serial := si.Chassis.Serial
|
||||
if (serial == "Default string" || serial == "") && si.Product.Serial != "" {
|
||||
serial = si.Product.Serial
|
||||
}
|
||||
return serial, si.Product.Name, si.Product.Vendor
|
||||
if (!isascii.MatchString(serial)) && si.Board.Serial != "" {
|
||||
serial = si.Board.Serial
|
||||
}
|
||||
name := si.Product.Name
|
||||
if (!isascii.MatchString(name)) && si.Board.Name != "" {
|
||||
name = si.Board.Name
|
||||
}
|
||||
return serial, name, si.Product.Vendor
|
||||
}
|
||||
|
@ -15,7 +15,6 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
"unicode"
|
||||
|
||||
@ -34,6 +33,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
@ -62,8 +62,25 @@ func main() {
|
||||
var errorMSG string
|
||||
flag.StringVar(&errorMSG, "error-msg", "", "displays a error message window")
|
||||
|
||||
tmpDir := "/tmp"
|
||||
if runtime.GOOS == "windows" {
|
||||
tmpDir = os.TempDir()
|
||||
}
|
||||
|
||||
var saveLogsInFile bool
|
||||
flag.BoolVar(&saveLogsInFile, "use-log-file", false, fmt.Sprintf("save logs in a file: %s/netbird-ui-PID.log", tmpDir))
|
||||
|
||||
flag.Parse()
|
||||
|
||||
if saveLogsInFile {
|
||||
logFile := path.Join(tmpDir, fmt.Sprintf("netbird-ui-%d.log", os.Getpid()))
|
||||
err := util.InitLog("trace", logFile)
|
||||
if err != nil {
|
||||
log.Errorf("error while initializing log: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
a := app.NewWithID("NetBird")
|
||||
a.SetIcon(fyne.NewStaticResource("netbird", iconDisconnectedPNG))
|
||||
|
||||
@ -76,8 +93,12 @@ func main() {
|
||||
if showSettings || showRoutes {
|
||||
a.Run()
|
||||
} else {
|
||||
if err := checkPIDFile(); err != nil {
|
||||
log.Errorf("check PID file: %v", err)
|
||||
running, err := isAnotherProcessRunning()
|
||||
if err != nil {
|
||||
log.Errorf("error while checking process: %v", err)
|
||||
}
|
||||
if running {
|
||||
log.Warn("another process is running")
|
||||
return
|
||||
}
|
||||
client.setDefaultFonts()
|
||||
@ -861,104 +882,3 @@ func openURL(url string) error {
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// checkPIDFile exists and return error, or write new.
|
||||
func checkPIDFile() error {
|
||||
pidFile := path.Join(os.TempDir(), "wiretrustee-ui.pid")
|
||||
if piddata, err := os.ReadFile(pidFile); err == nil {
|
||||
if pid, err := strconv.Atoi(string(piddata)); err == nil {
|
||||
if process, err := os.FindProcess(pid); err == nil {
|
||||
if err := process.Signal(syscall.Signal(0)); err == nil {
|
||||
return fmt.Errorf("process already exists: %d", pid)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return os.WriteFile(pidFile, []byte(fmt.Sprintf("%d", os.Getpid())), 0o664) //nolint:gosec
|
||||
}
|
||||
|
||||
func (s *serviceClient) setDefaultFonts() {
|
||||
var (
|
||||
defaultFontPath string
|
||||
)
|
||||
|
||||
//TODO: Linux Multiple Language Support
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
defaultFontPath = "/Library/Fonts/Arial Unicode.ttf"
|
||||
case "windows":
|
||||
fontPath := s.getWindowsFontFilePath()
|
||||
defaultFontPath = fontPath
|
||||
}
|
||||
|
||||
_, err := os.Stat(defaultFontPath)
|
||||
|
||||
if err == nil {
|
||||
os.Setenv("FYNE_FONT", defaultFontPath)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *serviceClient) getWindowsFontFilePath() (fontPath string) {
|
||||
/*
|
||||
https://learn.microsoft.com/en-us/windows/apps/design/globalizing/loc-international-fonts
|
||||
https://learn.microsoft.com/en-us/typography/fonts/windows_11_font_list
|
||||
*/
|
||||
|
||||
var (
|
||||
fontFolder string = "C:/Windows/Fonts"
|
||||
fontMapping = map[string]string{
|
||||
"default": "Segoeui.ttf",
|
||||
"zh-CN": "Msyh.ttc",
|
||||
"am-ET": "Ebrima.ttf",
|
||||
"nirmala": "Nirmala.ttf",
|
||||
"chr-CHER-US": "Gadugi.ttf",
|
||||
"zh-HK": "Msjh.ttc",
|
||||
"zh-TW": "Msjh.ttc",
|
||||
"ja-JP": "Yugothm.ttc",
|
||||
"km-KH": "Leelawui.ttf",
|
||||
"ko-KR": "Malgun.ttf",
|
||||
"th-TH": "Leelawui.ttf",
|
||||
"ti-ET": "Ebrima.ttf",
|
||||
}
|
||||
nirMalaLang = []string{
|
||||
"as-IN",
|
||||
"bn-BD",
|
||||
"bn-IN",
|
||||
"gu-IN",
|
||||
"hi-IN",
|
||||
"kn-IN",
|
||||
"kok-IN",
|
||||
"ml-IN",
|
||||
"mr-IN",
|
||||
"ne-NP",
|
||||
"or-IN",
|
||||
"pa-IN",
|
||||
"si-LK",
|
||||
"ta-IN",
|
||||
"te-IN",
|
||||
}
|
||||
)
|
||||
cmd := exec.Command("powershell", "-Command", "(Get-Culture).Name")
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get Windows default language setting: %v", err)
|
||||
fontPath = path.Join(fontFolder, fontMapping["default"])
|
||||
return
|
||||
}
|
||||
defaultLanguage := strings.TrimSpace(string(output))
|
||||
|
||||
for _, lang := range nirMalaLang {
|
||||
if defaultLanguage == lang {
|
||||
fontPath = path.Join(fontFolder, fontMapping["nirmala"])
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if font, ok := fontMapping[defaultLanguage]; ok {
|
||||
fontPath = path.Join(fontFolder, font)
|
||||
} else {
|
||||
fontPath = path.Join(fontFolder, fontMapping["default"])
|
||||
}
|
||||
return
|
||||
}
|
||||
|
26
client/ui/font_bsd.go
Normal file
26
client/ui/font_bsd.go
Normal file
@ -0,0 +1,26 @@
|
||||
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const defaultFontPath = "/Library/Fonts/Arial Unicode.ttf"
|
||||
|
||||
func (s *serviceClient) setDefaultFonts() {
|
||||
// TODO: add other bsd paths
|
||||
if runtime.GOOS != "darwin" {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := os.Stat(defaultFontPath); err != nil {
|
||||
log.Errorf("Failed to find default font file: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
os.Setenv("FYNE_FONT", defaultFontPath)
|
||||
}
|
7
client/ui/font_linux.go
Normal file
7
client/ui/font_linux.go
Normal file
@ -0,0 +1,7 @@
|
||||
//go:build !386
|
||||
|
||||
package main
|
||||
|
||||
func (s *serviceClient) setDefaultFonts() {
|
||||
//TODO: Linux Multiple Language Support
|
||||
}
|
91
client/ui/font_windows.go
Normal file
91
client/ui/font_windows.go
Normal file
@ -0,0 +1,91 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func (s *serviceClient) setDefaultFonts() {
|
||||
defaultFontPath := s.getWindowsFontFilePath()
|
||||
|
||||
if _, err := os.Stat(defaultFontPath); err != nil {
|
||||
log.Errorf("Failed to find default font file: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
os.Setenv("FYNE_FONT", defaultFontPath)
|
||||
}
|
||||
|
||||
func (s *serviceClient) getWindowsFontFilePath() string {
|
||||
var (
|
||||
fontFolder = "C:/Windows/Fonts"
|
||||
fontMapping = map[string]string{
|
||||
"default": "Segoeui.ttf",
|
||||
"zh-CN": "Msyh.ttc",
|
||||
"am-ET": "Ebrima.ttf",
|
||||
"nirmala": "Nirmala.ttf",
|
||||
"chr-CHER-US": "Gadugi.ttf",
|
||||
"zh-HK": "Msjh.ttc",
|
||||
"zh-TW": "Msjh.ttc",
|
||||
"ja-JP": "Yugothm.ttc",
|
||||
"km-KH": "Leelawui.ttf",
|
||||
"ko-KR": "Malgun.ttf",
|
||||
"th-TH": "Leelawui.ttf",
|
||||
"ti-ET": "Ebrima.ttf",
|
||||
}
|
||||
nirMalaLang = []string{
|
||||
"as-IN",
|
||||
"bn-BD",
|
||||
"bn-IN",
|
||||
"gu-IN",
|
||||
"hi-IN",
|
||||
"kn-IN",
|
||||
"kok-IN",
|
||||
"ml-IN",
|
||||
"mr-IN",
|
||||
"ne-NP",
|
||||
"or-IN",
|
||||
"pa-IN",
|
||||
"si-LK",
|
||||
"ta-IN",
|
||||
"te-IN",
|
||||
}
|
||||
)
|
||||
|
||||
// getUserDefaultLocaleName.Call() panics if the func is not found
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Errorf("Recovered from panic: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
kernel32 := windows.NewLazySystemDLL("kernel32.dll")
|
||||
getUserDefaultLocaleName := kernel32.NewProc("GetUserDefaultLocaleName")
|
||||
|
||||
buf := make([]uint16, 85) // LOCALE_NAME_MAX_LENGTH is usually 85
|
||||
r, _, err := getUserDefaultLocaleName.Call(uintptr(unsafe.Pointer(&buf[0])), uintptr(len(buf)))
|
||||
// returns 0 on failure, err is always non-nil
|
||||
// https://learn.microsoft.com/en-us/windows/win32/api/winnls/nf-winnls-getuserdefaultlocalename
|
||||
if r == 0 {
|
||||
log.Errorf("GetUserDefaultLocaleName call failed: %v", err)
|
||||
return path.Join(fontFolder, fontMapping["default"])
|
||||
}
|
||||
|
||||
defaultLanguage := windows.UTF16ToString(buf)
|
||||
|
||||
for _, lang := range nirMalaLang {
|
||||
if defaultLanguage == lang {
|
||||
return path.Join(fontFolder, fontMapping["nirmala"])
|
||||
}
|
||||
}
|
||||
|
||||
if font, ok := fontMapping[defaultLanguage]; ok {
|
||||
return path.Join(fontFolder, font)
|
||||
}
|
||||
|
||||
return path.Join(fontFolder, fontMapping["default"])
|
||||
}
|
37
client/ui/process.go
Normal file
37
client/ui/process.go
Normal file
@ -0,0 +1,37 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/shirou/gopsutil/v3/process"
|
||||
)
|
||||
|
||||
func isAnotherProcessRunning() (bool, error) {
|
||||
processes, err := process.Processes()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
pid := os.Getpid()
|
||||
processName := strings.ToLower(filepath.Base(os.Args[0]))
|
||||
|
||||
for _, p := range processes {
|
||||
if int(p.Pid) == pid {
|
||||
continue
|
||||
}
|
||||
|
||||
runningProcessPath, err := p.Exe()
|
||||
// most errors are related to short-lived processes
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.Contains(strings.ToLower(runningProcessPath), processName) && isProcessOwnedByCurrentUser(p) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
26
client/ui/process_nonwindows.go
Normal file
26
client/ui/process_nonwindows.go
Normal file
@ -0,0 +1,26 @@
|
||||
//go:build !windows
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/shirou/gopsutil/v3/process"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func isProcessOwnedByCurrentUser(p *process.Process) bool {
|
||||
currentUserID := os.Getuid()
|
||||
uids, err := p.Uids()
|
||||
if err != nil {
|
||||
log.Errorf("get process uids: %v", err)
|
||||
return false
|
||||
}
|
||||
for _, id := range uids {
|
||||
log.Debugf("checking process uid: %d", id)
|
||||
if int(id) == currentUserID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
24
client/ui/process_windows.go
Normal file
24
client/ui/process_windows.go
Normal file
@ -0,0 +1,24 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"os/user"
|
||||
|
||||
"github.com/shirou/gopsutil/v3/process"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func isProcessOwnedByCurrentUser(p *process.Process) bool {
|
||||
processUsername, err := p.Username()
|
||||
if err != nil {
|
||||
log.Errorf("get process username error: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
currUser, err := user.Current()
|
||||
if err != nil {
|
||||
log.Errorf("get current user error: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
return processUsername == currUser.Username
|
||||
}
|
10
go.mod
10
go.mod
@ -19,12 +19,12 @@ require (
|
||||
github.com/spf13/cobra v1.7.0
|
||||
github.com/spf13/pflag v1.0.5
|
||||
github.com/vishvananda/netlink v1.2.1-beta.2
|
||||
golang.org/x/crypto v0.23.0
|
||||
golang.org/x/sys v0.20.0
|
||||
golang.org/x/crypto v0.24.0
|
||||
golang.org/x/sys v0.21.0
|
||||
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||
google.golang.org/grpc v1.64.0
|
||||
google.golang.org/grpc v1.64.1
|
||||
google.golang.org/protobuf v1.34.1
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.0.0
|
||||
)
|
||||
@ -85,10 +85,10 @@ require (
|
||||
goauthentik.io/api/v3 v3.2023051.3
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
|
||||
golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028
|
||||
golang.org/x/net v0.25.0
|
||||
golang.org/x/net v0.26.0
|
||||
golang.org/x/oauth2 v0.19.0
|
||||
golang.org/x/sync v0.7.0
|
||||
golang.org/x/term v0.20.0
|
||||
golang.org/x/term v0.21.0
|
||||
google.golang.org/api v0.177.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
gorm.io/driver/postgres v1.5.7
|
||||
|
20
go.sum
20
go.sum
@ -545,8 +545,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y
|
||||
golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE=
|
||||
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
|
||||
golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg=
|
||||
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
|
||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||
golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI=
|
||||
golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM=
|
||||
@ -592,8 +592,8 @@ golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
|
||||
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
||||
golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
|
||||
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
|
||||
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
|
||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||
golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ=
|
||||
golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE=
|
||||
golang.org/x/oauth2 v0.19.0 h1:9+E/EZBCbTLNrbN35fHv/a/d/mOBatymz1zbtQrXpIg=
|
||||
@ -650,8 +650,8 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
|
||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
|
||||
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||
@ -659,8 +659,8 @@ golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY=
|
||||
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
|
||||
golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU=
|
||||
golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
|
||||
golang.org/x/term v0.20.0 h1:VnkxpohqXaOBYJtBmEppKUG6mXpi+4O6purfc2+sMhw=
|
||||
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
|
||||
golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA=
|
||||
golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0=
|
||||
golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
@ -719,8 +719,8 @@ google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyac
|
||||
google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY=
|
||||
google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk=
|
||||
google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc=
|
||||
google.golang.org/grpc v1.64.0 h1:KH3VH9y/MgNQg1dE7b3XfVK0GsPSIzJwdF617gUSbvY=
|
||||
google.golang.org/grpc v1.64.0/go.mod h1:oxjF8E3FBnjp+/gVFYdWacaLDx9na1aqy9oovLpxQYg=
|
||||
google.golang.org/grpc v1.64.1 h1:LKtvyfbX3UGVPFcGqJ9ItpVWW6oN/2XqTxfAnwRRXiA=
|
||||
google.golang.org/grpc v1.64.1/go.mod h1:hiQF4LFZelK2WKaP6W0L92zGHtiQdZxk8CrSdvyjeP0=
|
||||
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
|
||||
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
|
||||
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=
|
||||
|
@ -64,7 +64,7 @@ func (t *wgTunDevice) Create(routes []string, dns string, searchDomains []string
|
||||
t.wrapper = newDeviceWrapper(tunDevice)
|
||||
|
||||
log.Debugf("attaching to interface %v", name)
|
||||
t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(device.LogLevelSilent, "[wiretrustee] "))
|
||||
t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] "))
|
||||
// without this property mobile devices can discover remote endpoints if the configured one was wrong.
|
||||
// this helps with support for the older NetBird clients that had a hardcoded direct mode
|
||||
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
|
||||
|
@ -49,7 +49,7 @@ func (t *tunDevice) Create() (wgConfigurer, error) {
|
||||
t.device = device.NewDevice(
|
||||
t.wrapper,
|
||||
t.iceBind,
|
||||
device.NewLogger(device.LogLevelSilent, "[netbird] "),
|
||||
device.NewLogger(wgLogLevel(), "[netbird] "),
|
||||
)
|
||||
|
||||
err = t.assignAddr()
|
||||
|
@ -64,7 +64,7 @@ func (t *tunDevice) Create() (wgConfigurer, error) {
|
||||
|
||||
t.wrapper = newDeviceWrapper(tunDevice)
|
||||
log.Debug("Attaching to interface")
|
||||
t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(device.LogLevelSilent, "[wiretrustee] "))
|
||||
t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] "))
|
||||
// without this property mobile devices can discover remote endpoints if the configured one was wrong.
|
||||
// this helps with support for the older NetBird clients that had a hardcoded direct mode
|
||||
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
|
||||
|
@ -54,7 +54,7 @@ func (t *tunNetstackDevice) Create() (wgConfigurer, error) {
|
||||
t.device = device.NewDevice(
|
||||
t.wrapper,
|
||||
t.iceBind,
|
||||
device.NewLogger(device.LogLevelSilent, "[netbird] "),
|
||||
device.NewLogger(wgLogLevel(), "[netbird] "),
|
||||
)
|
||||
|
||||
t.configurer = newWGUSPConfigurer(t.device, t.name)
|
||||
|
@ -57,7 +57,7 @@ func (t *tunUSPDevice) Create() (wgConfigurer, error) {
|
||||
t.device = device.NewDevice(
|
||||
t.wrapper,
|
||||
t.iceBind,
|
||||
device.NewLogger(device.LogLevelSilent, "[netbird] "),
|
||||
device.NewLogger(wgLogLevel(), "[netbird] "),
|
||||
)
|
||||
|
||||
err = t.assignAddr()
|
||||
|
@ -41,6 +41,7 @@ func newTunDevice(name string, address WGAddress, port int, key string, mtu int,
|
||||
}
|
||||
|
||||
func (t *tunDevice) Create() (wgConfigurer, error) {
|
||||
log.Info("create tun interface")
|
||||
tunDevice, err := tun.CreateTUN(t.name, t.mtu)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -52,7 +53,7 @@ func (t *tunDevice) Create() (wgConfigurer, error) {
|
||||
t.device = device.NewDevice(
|
||||
t.wrapper,
|
||||
t.iceBind,
|
||||
device.NewLogger(device.LogLevelSilent, "[netbird] "),
|
||||
device.NewLogger(wgLogLevel(), "[netbird] "),
|
||||
)
|
||||
|
||||
luid := winipcfg.LUID(t.nativeTunDevice.LUID())
|
||||
|
15
iface/wg_log.go
Normal file
15
iface/wg_log.go
Normal file
@ -0,0 +1,15 @@
|
||||
package iface
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
)
|
||||
|
||||
func wgLogLevel() int {
|
||||
if os.Getenv("NB_WG_DEBUG") == "true" {
|
||||
return device.LogLevelVerbose
|
||||
} else {
|
||||
return device.LogLevelSilent
|
||||
}
|
||||
}
|
@ -2,7 +2,6 @@ package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
@ -11,15 +10,11 @@ import (
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/connectivity"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
@ -51,26 +46,21 @@ type GrpcClient struct {
|
||||
|
||||
// NewClient creates a new client to Management service
|
||||
func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) {
|
||||
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
|
||||
var conn *grpc.ClientConn
|
||||
|
||||
if tlsEnabled {
|
||||
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{}))
|
||||
operation := func() error {
|
||||
var err error
|
||||
conn, err = nbgrpc.CreateConnection(addr, tlsEnabled)
|
||||
if err != nil {
|
||||
log.Printf("createConnection error: %v", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
mgmCtx, cancel := context.WithTimeout(ctx, ConnectTimeout)
|
||||
defer cancel()
|
||||
conn, err := grpc.DialContext(
|
||||
mgmCtx,
|
||||
addr,
|
||||
transportOption,
|
||||
nbgrpc.WithCustomDialer(),
|
||||
grpc.WithBlock(),
|
||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||
Time: 30 * time.Second,
|
||||
Timeout: 10 * time.Second,
|
||||
}))
|
||||
err := backoff.Retry(operation, nbgrpc.Backoff(ctx))
|
||||
if err != nil {
|
||||
log.Errorf("failed creating connection to Management Service %v", err)
|
||||
log.Errorf("failed creating connection to Management Service: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -326,25 +316,44 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro
|
||||
if !c.ready() {
|
||||
return nil, fmt.Errorf(errMsgNoMgmtConnection)
|
||||
}
|
||||
|
||||
loginReq, err := encryption.EncryptMessage(serverKey, c.key, req)
|
||||
if err != nil {
|
||||
log.Errorf("failed to encrypt message: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
mgmCtx, cancel := context.WithTimeout(c.ctx, ConnectTimeout)
|
||||
defer cancel()
|
||||
resp, err := c.realClient.Login(mgmCtx, &proto.EncryptedMessage{
|
||||
WgPubKey: c.key.PublicKey().String(),
|
||||
Body: loginReq,
|
||||
})
|
||||
|
||||
var resp *proto.EncryptedMessage
|
||||
operation := func() error {
|
||||
mgmCtx, cancel := context.WithTimeout(context.Background(), ConnectTimeout)
|
||||
defer cancel()
|
||||
|
||||
var err error
|
||||
resp, err = c.realClient.Login(mgmCtx, &proto.EncryptedMessage{
|
||||
WgPubKey: c.key.PublicKey().String(),
|
||||
Body: loginReq,
|
||||
})
|
||||
if err != nil {
|
||||
// retry only on context canceled
|
||||
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.Canceled {
|
||||
return err
|
||||
}
|
||||
return backoff.Permanent(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
err = backoff.Retry(operation, nbgrpc.Backoff(c.ctx))
|
||||
if err != nil {
|
||||
log.Errorf("failed to login to Management Service: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
loginResp := &proto.LoginResponse{}
|
||||
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, loginResp)
|
||||
if err != nil {
|
||||
log.Errorf("failed to decrypt registration message: %s", err)
|
||||
log.Errorf("failed to decrypt login response: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -69,6 +69,7 @@ type AccountManager interface {
|
||||
ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error)
|
||||
SaveUser(ctx context.Context, accountID, initiatorUserID string, update *User) (*UserInfo, error)
|
||||
SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error)
|
||||
SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error)
|
||||
GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error)
|
||||
GetAccountByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (*Account, error)
|
||||
GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
|
||||
@ -95,6 +96,7 @@ type AccountManager interface {
|
||||
GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error)
|
||||
GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error)
|
||||
SaveGroup(ctx context.Context, accountID, userID string, group *nbgroup.Group) error
|
||||
SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error
|
||||
DeleteGroup(ctx context.Context, accountId, userId, groupID string) error
|
||||
ListGroups(ctx context.Context, accountId string) ([]*nbgroup.Group, error)
|
||||
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
|
||||
|
@ -746,3 +746,11 @@ func (s *FileStore) Close(ctx context.Context) error {
|
||||
func (s *FileStore) GetStoreEngine() StoreEngine {
|
||||
return FileStoreEngine
|
||||
}
|
||||
|
||||
func (s *FileStore) SaveUsers(accountID string, users map[string]*User) error {
|
||||
return status.Errorf(status.Internal, "SaveUsers is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error {
|
||||
return status.Errorf(status.Internal, "SaveGroups is not implemented")
|
||||
}
|
||||
|
@ -112,61 +112,85 @@ func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName,
|
||||
func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *nbgroup.Group) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
defer unlock()
|
||||
return am.SaveGroups(ctx, accountID, userID, []*nbgroup.Group{newGroup})
|
||||
}
|
||||
|
||||
// SaveGroups adds new groups to the account.
|
||||
// Note: This function does not acquire the global lock.
|
||||
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
|
||||
func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error {
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
|
||||
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
|
||||
}
|
||||
var eventsToStore []func()
|
||||
|
||||
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
|
||||
for _, newGroup := range newGroups {
|
||||
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
|
||||
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
|
||||
}
|
||||
|
||||
existingGroup, err := account.FindGroupByName(newGroup.Name)
|
||||
if err != nil {
|
||||
s, ok := status.FromError(err)
|
||||
if !ok || s.ErrorType != status.NotFound {
|
||||
return err
|
||||
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
|
||||
existingGroup, err := account.FindGroupByName(newGroup.Name)
|
||||
if err != nil {
|
||||
s, ok := status.FromError(err)
|
||||
if !ok || s.ErrorType != status.NotFound {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Avoid duplicate groups only for the API issued groups.
|
||||
// Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of.
|
||||
if existingGroup != nil {
|
||||
return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
|
||||
}
|
||||
|
||||
newGroup.ID = xid.New().String()
|
||||
}
|
||||
|
||||
for _, peerID := range newGroup.Peers {
|
||||
if account.Peers[peerID] == nil {
|
||||
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
|
||||
}
|
||||
}
|
||||
|
||||
// avoid duplicate groups only for the API issued groups. Integration or JWT groups can be duplicated as they are
|
||||
// coming from the IdP that we don't have control of.
|
||||
if existingGroup != nil {
|
||||
return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
|
||||
}
|
||||
oldGroup := account.Groups[newGroup.ID]
|
||||
account.Groups[newGroup.ID] = newGroup
|
||||
|
||||
newGroup.ID = xid.New().String()
|
||||
events := am.prepareGroupEvents(ctx, userID, accountID, newGroup, oldGroup, account)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
}
|
||||
|
||||
for _, peerID := range newGroup.Peers {
|
||||
if account.Peers[peerID] == nil {
|
||||
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
|
||||
}
|
||||
}
|
||||
|
||||
oldGroup, exists := account.Groups[newGroup.ID]
|
||||
account.Groups[newGroup.ID] = newGroup
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
if err = am.Store.SaveGroups(account.Id, account.Groups); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
// the following snippet tracks the activity and stores the group events in the event store.
|
||||
// It has to happen after all the operations have been successfully performed.
|
||||
for _, storeEvent := range eventsToStore {
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// prepareGroupEvents prepares a list of event functions to be stored.
|
||||
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup, oldGroup *nbgroup.Group, account *Account) []func() {
|
||||
var eventsToStore []func()
|
||||
|
||||
addedPeers := make([]string, 0)
|
||||
removedPeers := make([]string, 0)
|
||||
if exists {
|
||||
|
||||
if oldGroup != nil {
|
||||
addedPeers = difference(newGroup.Peers, oldGroup.Peers)
|
||||
removedPeers = difference(oldGroup.Peers, newGroup.Peers)
|
||||
} else {
|
||||
addedPeers = append(addedPeers, newGroup.Peers...)
|
||||
am.StoreEvent(ctx, userID, newGroup.ID, accountID, activity.GroupCreated, newGroup.EventMeta())
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
am.StoreEvent(ctx, userID, newGroup.ID, accountID, activity.GroupCreated, newGroup.EventMeta())
|
||||
})
|
||||
}
|
||||
|
||||
for _, p := range addedPeers {
|
||||
@ -175,11 +199,14 @@ func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userI
|
||||
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
|
||||
continue
|
||||
}
|
||||
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupAddedToPeer,
|
||||
map[string]any{
|
||||
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peer.IP.String(),
|
||||
"peer_fqdn": peer.FQDN(am.GetDNSDomain()),
|
||||
})
|
||||
peerCopy := peer // copy to avoid closure issues
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupAddedToPeer,
|
||||
map[string]any{
|
||||
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(),
|
||||
"peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()),
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
for _, p := range removedPeers {
|
||||
@ -188,14 +215,17 @@ func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userI
|
||||
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
|
||||
continue
|
||||
}
|
||||
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupRemovedFromPeer,
|
||||
map[string]any{
|
||||
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peer.IP.String(),
|
||||
"peer_fqdn": peer.FQDN(am.GetDNSDomain()),
|
||||
})
|
||||
peerCopy := peer // copy to avoid closure issues
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupRemovedFromPeer,
|
||||
map[string]any{
|
||||
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(),
|
||||
"peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()),
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
return eventsToStore
|
||||
}
|
||||
|
||||
// difference returns the elements in `a` that aren't in `b`.
|
||||
|
@ -40,6 +40,7 @@ type MockAccountManager struct {
|
||||
GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*group.Group, error)
|
||||
GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*group.Group, error)
|
||||
SaveGroupFunc func(ctx context.Context, accountID, userID string, group *group.Group) error
|
||||
SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error
|
||||
DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error
|
||||
ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error)
|
||||
GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
||||
@ -64,6 +65,7 @@ type MockAccountManager struct {
|
||||
ListSetupKeysFunc func(ctx context.Context, accountID, userID string) ([]*server.SetupKey, error)
|
||||
SaveUserFunc func(ctx context.Context, accountID, userID string, user *server.User) (*server.UserInfo, error)
|
||||
SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error)
|
||||
SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*server.User, addIfNotExists bool) ([]*server.UserInfo, error)
|
||||
DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
|
||||
CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
|
||||
DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error
|
||||
@ -308,6 +310,14 @@ func (am *MockAccountManager) SaveGroup(ctx context.Context, accountID, userID s
|
||||
return status.Errorf(codes.Unimplemented, "method SaveGroup is not implemented")
|
||||
}
|
||||
|
||||
// SaveGroups mock implementation of SaveGroups from server.AccountManager interface
|
||||
func (am *MockAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*group.Group) error {
|
||||
if am.SaveGroupsFunc != nil {
|
||||
return am.SaveGroupsFunc(ctx, accountID, userID, groups)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method SaveGroups is not implemented")
|
||||
}
|
||||
|
||||
// DeleteGroup mock implementation of DeleteGroup from server.AccountManager interface
|
||||
func (am *MockAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error {
|
||||
if am.DeleteGroupFunc != nil {
|
||||
@ -502,6 +512,14 @@ func (am *MockAccountManager) SaveOrAddUser(ctx context.Context, accountID, user
|
||||
return nil, status.Errorf(codes.Unimplemented, "method SaveOrAddUser is not implemented")
|
||||
}
|
||||
|
||||
// SaveOrAddUsers mocks SaveOrAddUsers of the AccountManager interface
|
||||
func (am *MockAccountManager) SaveOrAddUsers(ctx context.Context, accountID, userID string, users []*server.User, addIfNotExists bool) ([]*server.UserInfo, error) {
|
||||
if am.SaveOrAddUsersFunc != nil {
|
||||
return am.SaveOrAddUsersFunc(ctx, accountID, userID, users, addIfNotExists)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method SaveOrAddUsers is not implemented")
|
||||
}
|
||||
|
||||
// DeleteUser mocks DeleteUser of the AccountManager interface
|
||||
func (am *MockAccountManager) DeleteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error {
|
||||
if am.DeleteUserFunc != nil {
|
||||
|
@ -274,10 +274,15 @@ func (s *SqlStore) GetInstallationID() string {
|
||||
func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
|
||||
var peerCopy nbpeer.Peer
|
||||
peerCopy.Status = &peerStatus
|
||||
result := s.db.Model(&nbpeer.Peer{}).
|
||||
Where("account_id = ? AND id = ?", accountID, peerID).
|
||||
Updates(peerCopy)
|
||||
|
||||
fieldsToUpdate := []string{
|
||||
"peer_status_last_seen", "peer_status_connected",
|
||||
"peer_status_login_expired", "peer_status_required_approval",
|
||||
}
|
||||
result := s.db.Model(&nbpeer.Peer{}).
|
||||
Select(fieldsToUpdate).
|
||||
Where("account_id = ? AND id = ?", accountID, peerID).
|
||||
Updates(&peerCopy)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
@ -311,6 +316,34 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveUsers saves the given list of users to the database.
|
||||
// It updates existing users if a conflict occurs.
|
||||
func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error {
|
||||
usersToSave := make([]User, 0, len(users))
|
||||
for _, user := range users {
|
||||
user.AccountID = accountID
|
||||
for id, pat := range user.PATs {
|
||||
pat.ID = id
|
||||
user.PATsG = append(user.PATsG, *pat)
|
||||
}
|
||||
usersToSave = append(usersToSave, *user)
|
||||
}
|
||||
return s.db.Session(&gorm.Session{FullSaveAssociations: true}).
|
||||
Clauses(clause.OnConflict{UpdateAll: true}).
|
||||
Create(&usersToSave).Error
|
||||
}
|
||||
|
||||
// SaveGroups saves the given list of groups to the database.
|
||||
// It updates existing groups if a conflict occurs.
|
||||
func (s *SqlStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error {
|
||||
groupsToSave := make([]nbgroup.Group, 0, len(groups))
|
||||
for _, group := range groups {
|
||||
group.AccountID = accountID
|
||||
groupsToSave = append(groupsToSave, *group)
|
||||
}
|
||||
return s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&groupsToSave).Error
|
||||
}
|
||||
|
||||
// DeleteHashedPAT2TokenIDIndex is noop in SqlStore
|
||||
func (s *SqlStore) DeleteHashedPAT2TokenIDIndex(hashedToken string) error {
|
||||
return nil
|
||||
@ -662,11 +695,7 @@ func NewSqliteStore(ctx context.Context, dataDir string, metrics telemetry.AppMe
|
||||
}
|
||||
|
||||
file := filepath.Join(dataDir, storeStr)
|
||||
db, err := gorm.Open(sqlite.Open(file), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
CreateBatchSize: 400,
|
||||
PrepareStmt: true,
|
||||
})
|
||||
db, err := gorm.Open(sqlite.Open(file), getGormConfig())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -676,10 +705,7 @@ func NewSqliteStore(ctx context.Context, dataDir string, metrics telemetry.AppMe
|
||||
|
||||
// NewPostgresqlStore creates a new Postgres store.
|
||||
func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) {
|
||||
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
PrepareStmt: true,
|
||||
})
|
||||
db, err := gorm.Open(postgres.Open(dsn), getGormConfig())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -687,6 +713,14 @@ func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMe
|
||||
return NewSqlStore(ctx, db, PostgresStoreEngine, metrics)
|
||||
}
|
||||
|
||||
func getGormConfig() *gorm.Config {
|
||||
return &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
CreateBatchSize: 400,
|
||||
PrepareStmt: true,
|
||||
}
|
||||
}
|
||||
|
||||
// newPostgresStore initializes a new Postgres store.
|
||||
func newPostgresStore(ctx context.Context, metrics telemetry.AppMetrics) (Store, error) {
|
||||
dsn, ok := os.LookupEnv(postgresDsnEnv)
|
||||
|
@ -41,11 +41,22 @@ func TestSqlite_NewStore(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSqlite_SaveAccount_Large(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
if runtime.GOOS != "linux" && os.Getenv("CI") == "true" || runtime.GOOS == "windows" {
|
||||
t.Skip("skip large test on non-linux OS due to environment restrictions")
|
||||
}
|
||||
t.Run("SQLite", func(t *testing.T) {
|
||||
store := newSqliteStore(t)
|
||||
runLargeTest(t, store)
|
||||
})
|
||||
// create store outside to have a better time counter for the test
|
||||
store := newPostgresqlStore(t)
|
||||
t.Run("PostgreSQL", func(t *testing.T) {
|
||||
runLargeTest(t, store)
|
||||
})
|
||||
}
|
||||
|
||||
store := newSqliteStore(t)
|
||||
func runLargeTest(t *testing.T, store Store) {
|
||||
t.Helper()
|
||||
|
||||
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
|
||||
groupALL, err := account.GetGroupAll()
|
||||
@ -54,7 +65,7 @@ func TestSqlite_SaveAccount_Large(t *testing.T) {
|
||||
}
|
||||
setupKey := GenerateDefaultSetupKey()
|
||||
account.SetupKeys[setupKey.Key] = setupKey
|
||||
const numPerAccount = 2000
|
||||
const numPerAccount = 6000
|
||||
for n := 0; n < numPerAccount; n++ {
|
||||
netIP := randomIPv4()
|
||||
peerID := fmt.Sprintf("%s-peer-%d", account.Id, n)
|
||||
@ -362,7 +373,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// save status of non-existing peer
|
||||
newStatus := nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}
|
||||
newStatus := nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()}
|
||||
err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus)
|
||||
assert.Error(t, err)
|
||||
parsedErr, ok := status.FromError(err)
|
||||
@ -377,7 +388,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) {
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
Name: "peer name",
|
||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()},
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
}
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
|
@ -12,6 +12,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
|
||||
@ -41,6 +42,8 @@ type Store interface {
|
||||
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
|
||||
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
SaveAccount(ctx context.Context, account *Account) error
|
||||
SaveUsers(accountID string, users map[string]*User) error
|
||||
SaveGroups(accountID string, groups map[string]*nbgroup.Group) error
|
||||
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
|
||||
DeleteTokenID2UserIDIndex(tokenID string) error
|
||||
GetInstallationID() string
|
||||
|
@ -740,7 +740,7 @@ func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID strin
|
||||
return pats, nil
|
||||
}
|
||||
|
||||
// SaveUser saves updates to the given user. If the user doesn't exit it will throw status.NotFound error.
|
||||
// SaveUser saves updates to the given user. If the user doesn't exist, it will throw status.NotFound error.
|
||||
func (am *DefaultAccountManager) SaveUser(ctx context.Context, accountID, initiatorUserID string, update *User) (*UserInfo, error) {
|
||||
return am.SaveOrAddUser(ctx, accountID, initiatorUserID, update, false) // false means do not create user and throw status.NotFound
|
||||
}
|
||||
@ -748,11 +748,31 @@ func (am *DefaultAccountManager) SaveUser(ctx context.Context, accountID, initia
|
||||
// SaveOrAddUser updates the given user. If addIfNotExists is set to true it will add user when no exist
|
||||
// Only User.AutoGroups, User.Role, and User.Blocked fields are allowed to be updated for now.
|
||||
func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) {
|
||||
if update == nil {
|
||||
return nil, status.Errorf(status.InvalidArgument, "provided user update is nil")
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
if update == nil {
|
||||
return nil, status.Errorf(status.InvalidArgument, "provided user update is nil")
|
||||
updatedUsers, err := am.SaveOrAddUsers(ctx, accountID, initiatorUserID, []*User{update}, addIfNotExists)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(updatedUsers) == 0 {
|
||||
return nil, status.Errorf(status.Internal, "user was not updated")
|
||||
}
|
||||
|
||||
return updatedUsers[0], nil
|
||||
}
|
||||
|
||||
// SaveOrAddUsers updates existing users or adds new users to the account.
|
||||
// Note: This function does not acquire the global lock.
|
||||
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
|
||||
func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error) {
|
||||
if len(updates) == 0 {
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
@ -769,144 +789,200 @@ func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, i
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are authorized to perform user update operations")
|
||||
}
|
||||
|
||||
oldUser := account.Users[update.Id]
|
||||
if oldUser == nil {
|
||||
if !addIfNotExists {
|
||||
return nil, status.Errorf(status.NotFound, "user to update doesn't exist")
|
||||
updatedUsers := make([]*UserInfo, 0, len(updates))
|
||||
var (
|
||||
expiredPeers []*nbpeer.Peer
|
||||
eventsToStore []func()
|
||||
)
|
||||
|
||||
for _, update := range updates {
|
||||
if update == nil {
|
||||
return nil, status.Errorf(status.InvalidArgument, "provided user update is nil")
|
||||
}
|
||||
// when addIfNotExists is set to true the newUser will use all fields from the update input
|
||||
oldUser = update
|
||||
}
|
||||
|
||||
if initiatorUser.HasAdminPower() && initiatorUserID == update.Id && oldUser.Blocked != update.Blocked {
|
||||
return nil, status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves")
|
||||
}
|
||||
|
||||
if initiatorUser.HasAdminPower() && initiatorUserID == update.Id && update.Role != initiatorUser.Role {
|
||||
return nil, status.Errorf(status.PermissionDenied, "admins can't change their role")
|
||||
}
|
||||
|
||||
if initiatorUser.Role == UserRoleAdmin && oldUser.Role == UserRoleOwner && update.Role != oldUser.Role {
|
||||
return nil, status.Errorf(status.PermissionDenied, "only owners can remove owner role from their user")
|
||||
}
|
||||
|
||||
if initiatorUser.Role == UserRoleAdmin && oldUser.Role == UserRoleOwner && update.IsBlocked() && !oldUser.IsBlocked() {
|
||||
return nil, status.Errorf(status.PermissionDenied, "unable to block owner user")
|
||||
}
|
||||
|
||||
if initiatorUser.Role == UserRoleAdmin && update.Role == UserRoleOwner && update.Role != oldUser.Role {
|
||||
return nil, status.Errorf(status.PermissionDenied, "only owners can add owner role to other users")
|
||||
}
|
||||
|
||||
if oldUser.IsServiceUser && update.Role == UserRoleOwner {
|
||||
return nil, status.Errorf(status.PermissionDenied, "can't update a service user with owner role")
|
||||
}
|
||||
|
||||
transferedOwnerRole := false
|
||||
if initiatorUser.Role == UserRoleOwner && initiatorUserID != update.Id && update.Role == UserRoleOwner {
|
||||
newInitiatorUser := initiatorUser.Copy()
|
||||
newInitiatorUser.Role = UserRoleAdmin
|
||||
account.Users[initiatorUserID] = newInitiatorUser
|
||||
transferedOwnerRole = true
|
||||
}
|
||||
|
||||
// only auto groups, revoked status, and integration reference can be updated for now
|
||||
newUser := oldUser.Copy()
|
||||
newUser.Role = update.Role
|
||||
newUser.Blocked = update.Blocked
|
||||
// these two fields can't be set via API, only via direct call to the method
|
||||
newUser.Issued = update.Issued
|
||||
newUser.IntegrationReference = update.IntegrationReference
|
||||
|
||||
for _, newGroupID := range update.AutoGroups {
|
||||
if _, ok := account.Groups[newGroupID]; !ok {
|
||||
return nil, status.Errorf(status.InvalidArgument, "provided group ID %s in the user %s update doesn't exist",
|
||||
newGroupID, update.Id)
|
||||
oldUser := account.Users[update.Id]
|
||||
if oldUser == nil {
|
||||
if !addIfNotExists {
|
||||
return nil, status.Errorf(status.NotFound, "user to update doesn't exist: %s", update.Id)
|
||||
}
|
||||
// when addIfNotExists is set to true, the newUser will use all fields from the update input
|
||||
oldUser = update
|
||||
}
|
||||
}
|
||||
newUser.AutoGroups = update.AutoGroups
|
||||
|
||||
account.Users[newUser.Id] = newUser
|
||||
|
||||
if !oldUser.IsBlocked() && update.IsBlocked() {
|
||||
// expire peers that belong to the user who's getting blocked
|
||||
blockedPeers, err := account.FindUserPeers(update.Id)
|
||||
if err != nil {
|
||||
if err := validateUserUpdate(account, initiatorUser, oldUser, update); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := am.expireAndUpdatePeers(ctx, account, blockedPeers); err != nil {
|
||||
// only auto groups, revoked status, and integration reference can be updated for now
|
||||
newUser := oldUser.Copy()
|
||||
newUser.Role = update.Role
|
||||
newUser.Blocked = update.Blocked
|
||||
newUser.AutoGroups = update.AutoGroups
|
||||
// these two fields can't be set via API, only via direct call to the method
|
||||
newUser.Issued = update.Issued
|
||||
newUser.IntegrationReference = update.IntegrationReference
|
||||
|
||||
transferredOwnerRole := handleOwnerRoleTransfer(account, initiatorUser, update)
|
||||
account.Users[newUser.Id] = newUser
|
||||
|
||||
if !oldUser.IsBlocked() && update.IsBlocked() {
|
||||
// expire peers that belong to the user who's getting blocked
|
||||
blockedPeers, err := account.FindUserPeers(update.Id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
expiredPeers = append(expiredPeers, blockedPeers...)
|
||||
}
|
||||
|
||||
if update.AutoGroups != nil && account.Settings.GroupsPropagationEnabled {
|
||||
removedGroups := difference(oldUser.AutoGroups, update.AutoGroups)
|
||||
// need force update all auto groups in any case they will not be duplicated
|
||||
account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...)
|
||||
account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...)
|
||||
}
|
||||
|
||||
events := am.prepareUserUpdateEvents(ctx, initiatorUser.Id, oldUser, newUser, account, transferredOwnerRole)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
|
||||
updatedUserInfo, err := getUserInfo(ctx, am, newUser, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
updatedUsers = append(updatedUsers, updatedUserInfo)
|
||||
}
|
||||
|
||||
if len(expiredPeers) > 0 {
|
||||
if err := am.expireAndUpdatePeers(ctx, account, expiredPeers); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed update expired peers: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if update.AutoGroups != nil && account.Settings.GroupsPropagationEnabled {
|
||||
removedGroups := difference(oldUser.AutoGroups, update.AutoGroups)
|
||||
// need force update all auto groups in any case they will not be duplicated
|
||||
account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...)
|
||||
account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...)
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveUsers(account.Id, account.Users); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if account.Settings.GroupsPropagationEnabled {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
} else {
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, storeEvent := range eventsToStore {
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
return updatedUsers, nil
|
||||
}
|
||||
|
||||
// prepareUserUpdateEvents prepares a list user update events based on the changes between the old and new user data.
|
||||
func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, transferredOwnerRole bool) []func() {
|
||||
var eventsToStore []func()
|
||||
|
||||
if oldUser.IsBlocked() != newUser.IsBlocked() {
|
||||
if newUser.IsBlocked() {
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.UserBlocked, nil)
|
||||
})
|
||||
} else {
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.UserUnblocked, nil)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if oldUser.IsBlocked() != update.IsBlocked() {
|
||||
if update.IsBlocked() {
|
||||
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.UserBlocked, nil)
|
||||
switch {
|
||||
case transferredOwnerRole:
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.TransferredOwnerRole, nil)
|
||||
})
|
||||
case oldUser.Role != newUser.Role:
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.UserRoleUpdated, map[string]any{"role": newUser.Role})
|
||||
})
|
||||
}
|
||||
|
||||
if newUser.AutoGroups != nil {
|
||||
removedGroups := difference(oldUser.AutoGroups, newUser.AutoGroups)
|
||||
addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups)
|
||||
for _, g := range removedGroups {
|
||||
group := account.GetGroup(g)
|
||||
if group != nil {
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupRemovedFromUser,
|
||||
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
|
||||
})
|
||||
|
||||
} else {
|
||||
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.UserUnblocked, nil)
|
||||
log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case transferedOwnerRole:
|
||||
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.TransferredOwnerRole, nil)
|
||||
case oldUser.Role != newUser.Role:
|
||||
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.UserRoleUpdated, map[string]any{"role": newUser.Role})
|
||||
default:
|
||||
}
|
||||
|
||||
if update.AutoGroups != nil {
|
||||
removedGroups := difference(oldUser.AutoGroups, update.AutoGroups)
|
||||
addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups)
|
||||
for _, g := range removedGroups {
|
||||
group := account.GetGroup(g)
|
||||
if group != nil {
|
||||
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.GroupRemovedFromUser,
|
||||
for _, g := range addedGroups {
|
||||
group := account.GetGroup(g)
|
||||
if group != nil {
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupAddedToUser,
|
||||
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
|
||||
} else {
|
||||
log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
|
||||
}
|
||||
}
|
||||
|
||||
for _, g := range addedGroups {
|
||||
group := account.GetGroup(g)
|
||||
if group != nil {
|
||||
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.GroupAddedToUser,
|
||||
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if !isNil(am.idpManager) && !newUser.IsServiceUser {
|
||||
userData, err := am.lookupUserInCache(ctx, newUser.Id, account)
|
||||
return eventsToStore
|
||||
}
|
||||
|
||||
func handleOwnerRoleTransfer(account *Account, initiatorUser, update *User) bool {
|
||||
if initiatorUser.Role == UserRoleOwner && initiatorUser.Id != update.Id && update.Role == UserRoleOwner {
|
||||
newInitiatorUser := initiatorUser.Copy()
|
||||
newInitiatorUser.Role = UserRoleAdmin
|
||||
account.Users[initiatorUser.Id] = newInitiatorUser
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// getUserInfo retrieves the UserInfo for a given User and Account.
|
||||
// If the AccountManager has a non-nil idpManager and the User is not a service user,
|
||||
// it will attempt to look up the UserData from the cache.
|
||||
func getUserInfo(ctx context.Context, am *DefaultAccountManager, user *User, account *Account) (*UserInfo, error) {
|
||||
if !isNil(am.idpManager) && !user.IsServiceUser {
|
||||
userData, err := am.lookupUserInCache(ctx, user.Id, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newUser.ToUserInfo(userData, account.Settings)
|
||||
return user.ToUserInfo(userData, account.Settings)
|
||||
}
|
||||
return newUser.ToUserInfo(nil, account.Settings)
|
||||
return user.ToUserInfo(nil, account.Settings)
|
||||
}
|
||||
|
||||
// validateUserUpdate validates the update operation for a user.
|
||||
func validateUserUpdate(account *Account, initiatorUser, oldUser, update *User) error {
|
||||
if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && oldUser.Blocked != update.Blocked {
|
||||
return status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves")
|
||||
}
|
||||
if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && update.Role != initiatorUser.Role {
|
||||
return status.Errorf(status.PermissionDenied, "admins can't change their role")
|
||||
}
|
||||
if initiatorUser.Role == UserRoleAdmin && oldUser.Role == UserRoleOwner && update.Role != oldUser.Role {
|
||||
return status.Errorf(status.PermissionDenied, "only owners can remove owner role from their user")
|
||||
}
|
||||
if initiatorUser.Role == UserRoleAdmin && oldUser.Role == UserRoleOwner && update.IsBlocked() && !oldUser.IsBlocked() {
|
||||
return status.Errorf(status.PermissionDenied, "unable to block owner user")
|
||||
}
|
||||
if initiatorUser.Role == UserRoleAdmin && update.Role == UserRoleOwner && update.Role != oldUser.Role {
|
||||
return status.Errorf(status.PermissionDenied, "only owners can add owner role to other users")
|
||||
}
|
||||
if oldUser.IsServiceUser && update.Role == UserRoleOwner {
|
||||
return status.Errorf(status.PermissionDenied, "can't update a service user with owner role")
|
||||
}
|
||||
|
||||
for _, newGroupID := range update.AutoGroups {
|
||||
if _, ok := account.Groups[newGroupID]; !ok {
|
||||
return status.Errorf(status.InvalidArgument, "provided group ID %s in the user %s update doesn't exist",
|
||||
newGroupID, update.Id)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist
|
||||
@ -937,7 +1013,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, u
|
||||
|
||||
userObj := account.Users[userID]
|
||||
|
||||
if account.Domain != lowerDomain && userObj.Role == UserRoleOwner {
|
||||
if lowerDomain != "" && account.Domain != lowerDomain && userObj.Role == UserRoleOwner {
|
||||
account.Domain = lowerDomain
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
if err != nil {
|
||||
|
@ -18,6 +18,8 @@ Flags:
|
||||
--letsencrypt-domain string a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS
|
||||
--port int Server port to listen on (e.g. 10000) (default 10000)
|
||||
--ssl-dir string server ssl directory location. *Required only for Let's Encrypt certificates. (default "/var/lib/netbird/")
|
||||
--cert-file string Location of your SSL certificate. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect
|
||||
--cert-key string Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect
|
||||
|
||||
Global Flags:
|
||||
--log-file string sets Netbird log path. If console is specified the the log will be output to stdout (default "/var/log/netbird/signal.log")
|
||||
@ -90,6 +92,9 @@ The Signal Server exposes the following metrics in Prometheus format:
|
||||
- **registration_delay_milliseconds**: A Histogram metric that measures the time
|
||||
it took to register a peer in
|
||||
milliseconds.
|
||||
- **get_registration_delay_milliseconds**: A Histogram metric that measures the time
|
||||
it took to get a peer registration in
|
||||
milliseconds.
|
||||
- **messages_forwarded_total**: A Counter metric that counts the total number of
|
||||
messages forwarded between peers.
|
||||
- **message_forward_failures_total**: A Counter metric that counts the total
|
||||
|
@ -2,7 +2,6 @@ package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
@ -14,9 +13,6 @@ import (
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/connectivity"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
@ -64,28 +60,21 @@ func (c *GrpcClient) Close() error {
|
||||
|
||||
// NewClient creates a new Signal client
|
||||
func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) {
|
||||
var conn *grpc.ClientConn
|
||||
|
||||
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
|
||||
|
||||
if tlsEnabled {
|
||||
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{}))
|
||||
operation := func() error {
|
||||
var err error
|
||||
conn, err = nbgrpc.CreateConnection(addr, tlsEnabled)
|
||||
if err != nil {
|
||||
log.Printf("createConnection error: %v", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
sigCtx, cancel := context.WithTimeout(ctx, client.ConnectTimeout)
|
||||
defer cancel()
|
||||
conn, err := grpc.DialContext(
|
||||
sigCtx,
|
||||
addr,
|
||||
transportOption,
|
||||
nbgrpc.WithCustomDialer(),
|
||||
grpc.WithBlock(),
|
||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||
Time: 30 * time.Second,
|
||||
Timeout: 10 * time.Second,
|
||||
}))
|
||||
|
||||
err := backoff.Retry(operation, nbgrpc.Backoff(ctx))
|
||||
if err != nil {
|
||||
log.Errorf("failed to connect to the signalling server %v", err)
|
||||
log.Errorf("failed to connect to the signalling server: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -408,7 +397,7 @@ func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient,
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("error while handling message of Peer [key: %s] error: [%s]", msg.Key, err.Error())
|
||||
//todo send something??
|
||||
// todo send something??
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2,15 +2,12 @@ package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@ -41,7 +38,8 @@ var (
|
||||
signalLetsencryptDomain string
|
||||
signalSSLDir string
|
||||
defaultSignalSSLDir string
|
||||
tlsEnabled bool
|
||||
signalCertFile string
|
||||
signalCertKey string
|
||||
|
||||
signalKaep = grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{
|
||||
MinTime: 5 * time.Second,
|
||||
@ -56,12 +54,22 @@ var (
|
||||
})
|
||||
|
||||
runCmd = &cobra.Command{
|
||||
Use: "run",
|
||||
Short: "start NetBird Signal Server daemon",
|
||||
Use: "run",
|
||||
Short: "start NetBird Signal Server daemon",
|
||||
SilenceUsage: true,
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
err := util.InitLog(logLevel, logFile)
|
||||
if err != nil {
|
||||
log.Fatalf("failed initializing log %v", err)
|
||||
}
|
||||
|
||||
flag.Parse()
|
||||
|
||||
// detect whether user specified a port
|
||||
userPort := cmd.Flag("port").Changed
|
||||
if signalLetsencryptDomain != "" {
|
||||
|
||||
tlsEnabled := false
|
||||
if signalLetsencryptDomain != "" || (signalCertFile != "" && signalCertKey != "") {
|
||||
tlsEnabled = true
|
||||
}
|
||||
|
||||
@ -77,33 +85,12 @@ var (
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
flag.Parse()
|
||||
|
||||
err := util.InitLog(logLevel, logFile)
|
||||
opts, certManager, err := getTLSConfigurations()
|
||||
if err != nil {
|
||||
log.Fatalf("failed initializing log %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if signalSSLDir == "" {
|
||||
oldPath := "/var/lib/wiretrustee"
|
||||
if migrateToNetbird(oldPath, defaultSignalSSLDir) {
|
||||
if err := cpDir(oldPath, defaultSignalSSLDir); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var opts []grpc.ServerOption
|
||||
var certManager *autocert.Manager
|
||||
if tlsEnabled {
|
||||
// Let's encrypt enabled -> generate certificate automatically
|
||||
certManager, err = encryption.CreateCertManager(signalSSLDir, signalLetsencryptDomain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
transportCredentials := credentials.NewTLS(certManager.TLSConfig())
|
||||
opts = append(opts, grpc.Creds(transportCredentials))
|
||||
}
|
||||
|
||||
metricsServer := metrics.NewServer(metricsPort, "")
|
||||
metricsServer, err := metrics.NewServer(metricsPort, "")
|
||||
if err != nil {
|
||||
return fmt.Errorf("setup metrics: %v", err)
|
||||
}
|
||||
@ -124,7 +111,25 @@ var (
|
||||
}
|
||||
proto.RegisterSignalExchangeServer(grpcServer, srv)
|
||||
|
||||
grpcRootHandler := grpcHandlerFunc(grpcServer)
|
||||
|
||||
if certManager != nil {
|
||||
startServerWithCertManager(certManager, grpcRootHandler)
|
||||
}
|
||||
|
||||
var compatListener net.Listener
|
||||
var grpcListener net.Listener
|
||||
var httpListener net.Listener
|
||||
|
||||
// If certManager is configured and signalPort == 443, then the gRPC server has already been started
|
||||
if certManager == nil || signalPort != 443 {
|
||||
grpcListener, err = serveGRPC(grpcServer, signalPort)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Infof("running gRPC server: %s", grpcListener.Addr().String())
|
||||
}
|
||||
|
||||
if signalPort != 10000 {
|
||||
// The Signal gRPC server was running on port 10000 previously. Old agents that are already connected to Signal
|
||||
// are using port 10000. For compatibility purposes we keep running a 2nd gRPC server on port 10000.
|
||||
@ -135,28 +140,6 @@ var (
|
||||
log.Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String())
|
||||
}
|
||||
|
||||
var grpcListener net.Listener
|
||||
var httpListener net.Listener
|
||||
if tlsEnabled {
|
||||
httpListener = certManager.Listener()
|
||||
if signalPort == 443 {
|
||||
// running gRPC and HTTP cert manager on the same port
|
||||
serveHTTP(httpListener, certManager.HTTPHandler(grpcHandlerFunc(grpcServer)))
|
||||
log.Infof("running HTTP server (LetsEncrypt challenge handler) and gRPC server on the same port: %s", httpListener.Addr().String())
|
||||
} else {
|
||||
serveHTTP(httpListener, certManager.HTTPHandler(nil))
|
||||
log.Infof("running HTTP server (LetsEncrypt challenge handler): %s", httpListener.Addr().String())
|
||||
}
|
||||
}
|
||||
|
||||
if signalPort != 443 || !tlsEnabled {
|
||||
grpcListener, err = serveGRPC(grpcServer, signalPort)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Infof("running gRPC server: %s", grpcListener.Addr().String())
|
||||
}
|
||||
|
||||
log.Infof("signal server version %s", version.NetbirdVersion())
|
||||
log.Infof("started Signal Service")
|
||||
|
||||
@ -190,6 +173,58 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, error) {
|
||||
var (
|
||||
err error
|
||||
certManager *autocert.Manager
|
||||
tlsConfig *tls.Config
|
||||
)
|
||||
|
||||
if signalLetsencryptDomain == "" && signalCertFile == "" && signalCertKey == "" {
|
||||
log.Infof("running without TLS")
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
if signalLetsencryptDomain != "" {
|
||||
certManager, err = encryption.CreateCertManager(signalSSLDir, signalLetsencryptDomain)
|
||||
if err != nil {
|
||||
return nil, certManager, err
|
||||
}
|
||||
tlsConfig = certManager.TLSConfig()
|
||||
log.Infof("setting up TLS with LetsEncrypt.")
|
||||
} else {
|
||||
if signalCertFile == "" || signalCertKey == "" {
|
||||
log.Errorf("both cert-file and cert-key must be provided when not using LetsEncrypt")
|
||||
return nil, certManager, errors.New("both cert-file and cert-key must be provided when not using LetsEncrypt")
|
||||
}
|
||||
|
||||
tlsConfig, err = loadTLSConfig(signalCertFile, signalCertKey)
|
||||
if err != nil {
|
||||
log.Errorf("cannot load TLS credentials: %v", err)
|
||||
return nil, certManager, err
|
||||
}
|
||||
log.Infof("setting up TLS with custom certificates.")
|
||||
}
|
||||
|
||||
transportCredentials := credentials.NewTLS(tlsConfig)
|
||||
|
||||
return []grpc.ServerOption{grpc.Creds(transportCredentials)}, certManager, err
|
||||
}
|
||||
|
||||
func startServerWithCertManager(certManager *autocert.Manager, grpcRootHandler http.Handler) {
|
||||
// a call to certManager.Listener() always creates a new listener so we do it once
|
||||
httpListener := certManager.Listener()
|
||||
if signalPort == 443 {
|
||||
// running gRPC and HTTP cert manager on the same port
|
||||
serveHTTP(httpListener, certManager.HTTPHandler(grpcRootHandler))
|
||||
log.Infof("running HTTP server (LetsEncrypt challenge handler) and gRPC server on the same port: %s", httpListener.Addr().String())
|
||||
} else {
|
||||
// Start the HTTP cert manager server separately
|
||||
serveHTTP(httpListener, certManager.HTTPHandler(nil))
|
||||
log.Infof("running HTTP server (LetsEncrypt challenge handler): %s", httpListener.Addr().String())
|
||||
}
|
||||
}
|
||||
|
||||
func grpcHandlerFunc(grpcServer *grpc.Server) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
grpcHeader := strings.HasPrefix(r.Header.Get("Content-Type"), "application/grpc") ||
|
||||
@ -232,95 +267,29 @@ func serveGRPC(grpcServer *grpc.Server, port int) (net.Listener, error) {
|
||||
return listener, nil
|
||||
}
|
||||
|
||||
func cpFile(src, dst string) error {
|
||||
var err error
|
||||
var srcfd *os.File
|
||||
var dstfd *os.File
|
||||
var srcinfo os.FileInfo
|
||||
|
||||
if srcfd, err = os.Open(src); err != nil {
|
||||
return err
|
||||
}
|
||||
defer srcfd.Close()
|
||||
|
||||
if dstfd, err = os.Create(dst); err != nil {
|
||||
return err
|
||||
}
|
||||
defer dstfd.Close()
|
||||
|
||||
if _, err = io.Copy(dstfd, srcfd); err != nil {
|
||||
return err
|
||||
}
|
||||
if srcinfo, err = os.Stat(src); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.Chmod(dst, srcinfo.Mode())
|
||||
}
|
||||
|
||||
func copySymLink(source, dest string) error {
|
||||
link, err := os.Readlink(source)
|
||||
func loadTLSConfig(certFile string, certKey string) (*tls.Config, error) {
|
||||
// Load server's certificate and private key
|
||||
serverCert, err := tls.LoadX509KeyPair(certFile, certKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.Symlink(link, dest)
|
||||
}
|
||||
|
||||
func cpDir(src string, dst string) error {
|
||||
var err error
|
||||
var fds []os.DirEntry
|
||||
var srcinfo os.FileInfo
|
||||
|
||||
if srcinfo, err = os.Stat(src); err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = os.MkdirAll(dst, srcinfo.Mode()); err != nil {
|
||||
return err
|
||||
// NewDefaultAppMetrics the credentials and return it
|
||||
config := &tls.Config{
|
||||
Certificates: []tls.Certificate{serverCert},
|
||||
ClientAuth: tls.NoClientCert,
|
||||
NextProtos: []string{
|
||||
"h2", "http/1.1", // enable HTTP/2
|
||||
},
|
||||
}
|
||||
|
||||
if fds, err = os.ReadDir(src); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, fd := range fds {
|
||||
srcfp := path.Join(src, fd.Name())
|
||||
dstfp := path.Join(dst, fd.Name())
|
||||
|
||||
fileInfo, err := os.Stat(srcfp)
|
||||
if err != nil {
|
||||
log.Fatalf("Couldn't get fileInfo; %v", err)
|
||||
}
|
||||
|
||||
switch fileInfo.Mode() & os.ModeType {
|
||||
case os.ModeSymlink:
|
||||
if err = copySymLink(srcfp, dstfp); err != nil {
|
||||
log.Fatalf("Failed to copy from %s to %s; %v", srcfp, dstfp, err)
|
||||
}
|
||||
case os.ModeDir:
|
||||
if err = cpDir(srcfp, dstfp); err != nil {
|
||||
log.Fatalf("Failed to copy from %s to %s; %v", srcfp, dstfp, err)
|
||||
}
|
||||
default:
|
||||
if err = cpFile(srcfp, dstfp); err != nil {
|
||||
log.Fatalf("Failed to copy from %s to %s; %v", srcfp, dstfp, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateToNetbird(oldPath, newPath string) bool {
|
||||
_, errOld := os.Stat(oldPath)
|
||||
_, errNew := os.Stat(newPath)
|
||||
|
||||
if errors.Is(errOld, fs.ErrNotExist) || errNew == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
runCmd.PersistentFlags().IntVar(&signalPort, "port", 80, "Server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise")
|
||||
runCmd.Flags().StringVar(&signalSSLDir, "ssl-dir", defaultSignalSSLDir, "server ssl directory location. *Required only for Let's Encrypt certificates.")
|
||||
runCmd.Flags().StringVar(&signalLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
|
||||
runCmd.Flags().StringVar(&signalCertFile, "cert-file", "", "Location of your SSL certificate. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
|
||||
runCmd.Flags().StringVar(&signalCertKey, "cert-key", "", "Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
|
||||
}
|
||||
|
@ -15,6 +15,7 @@ type AppMetrics struct {
|
||||
Deregistrations metric.Int64Counter
|
||||
RegistrationFailures metric.Int64Counter
|
||||
RegistrationDelay metric.Float64Histogram
|
||||
GetRegistrationDelay metric.Float64Histogram
|
||||
|
||||
MessagesForwarded metric.Int64Counter
|
||||
MessageForwardFailures metric.Int64Counter
|
||||
@ -54,6 +55,12 @@ func NewAppMetrics(meter metric.Meter) (*AppMetrics, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
getRegistrationDelay, err := meter.Float64Histogram("get_registration_delay_milliseconds",
|
||||
metric.WithExplicitBucketBoundaries(getStandardBucketBoundaries()...))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
messagesForwarded, err := meter.Int64Counter("messages_forwarded_total")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -80,6 +87,7 @@ func NewAppMetrics(meter metric.Meter) (*AppMetrics, error) {
|
||||
Deregistrations: deregistrations,
|
||||
RegistrationFailures: registrationFailures,
|
||||
RegistrationDelay: registrationDelay,
|
||||
GetRegistrationDelay: getRegistrationDelay,
|
||||
|
||||
MessagesForwarded: messagesForwarded,
|
||||
MessageForwardFailures: messageForwardFailures,
|
||||
|
@ -26,10 +26,10 @@ type Metrics struct {
|
||||
}
|
||||
|
||||
// NewServer initializes and returns a new Metrics instance
|
||||
func NewServer(port int, endpoint string) *Metrics {
|
||||
func NewServer(port int, endpoint string) (*Metrics, error) {
|
||||
exporter, err := prometheus.New()
|
||||
if err != nil {
|
||||
return nil
|
||||
return nil, err
|
||||
}
|
||||
|
||||
provider := metric.NewMeterProvider(metric.WithReader(exporter))
|
||||
@ -57,7 +57,7 @@ func NewServer(port int, endpoint string) *Metrics {
|
||||
provider: provider,
|
||||
Endpoint: endpoint,
|
||||
Server: server,
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Shutdown stops the metrics server
|
||||
|
@ -23,11 +23,17 @@ const (
|
||||
labelTypeError = "error"
|
||||
labelTypeNotConnected = "not_connected"
|
||||
labelTypeNotRegistered = "not_registered"
|
||||
labelTypeStream = "stream"
|
||||
labelTypeMessage = "message"
|
||||
|
||||
labelError = "error"
|
||||
labelErrorMissingId = "missing_id"
|
||||
labelErrorMissingMeta = "missing_meta"
|
||||
labelErrorFailedHeader = "failed_header"
|
||||
|
||||
labelRegistrationStatus = "status"
|
||||
labelRegistrationFound = "found"
|
||||
labelRegistrationNotFound = "not_found"
|
||||
)
|
||||
|
||||
// Server an instance of a Signal server
|
||||
@ -61,7 +67,11 @@ func (s *Server) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto.
|
||||
return nil, fmt.Errorf("peer %s is not registered", msg.Key)
|
||||
}
|
||||
|
||||
getRegistrationStart := time.Now()
|
||||
|
||||
if dstPeer, found := s.registry.Get(msg.RemoteKey); found {
|
||||
s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeMessage), attribute.String(labelRegistrationStatus, labelRegistrationFound)))
|
||||
start := time.Now()
|
||||
//forward the message to the target peer
|
||||
if err := dstPeer.Stream.Send(msg); err != nil {
|
||||
log.Errorf("error while forwarding message from peer [%s] to peer [%s] %v", msg.Key, msg.RemoteKey, err)
|
||||
@ -69,9 +79,11 @@ func (s *Server) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto.
|
||||
|
||||
s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeError)))
|
||||
} else {
|
||||
s.metrics.MessageForwardLatency.Record(ctx, float64(time.Since(start).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeMessage)))
|
||||
s.metrics.MessagesForwarded.Add(context.Background(), 1)
|
||||
}
|
||||
} else {
|
||||
s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeMessage), attribute.String(labelRegistrationStatus, labelRegistrationNotFound)))
|
||||
log.Debugf("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", msg.Key, msg.RemoteKey)
|
||||
//todo respond to the sender?
|
||||
|
||||
@ -118,28 +130,30 @@ func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer)
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
start := time.Now()
|
||||
|
||||
log.Debugf("received a new message from peer [%s] to peer [%s]", p.Id, msg.RemoteKey)
|
||||
|
||||
getRegistrationStart := time.Now()
|
||||
|
||||
// lookup the target peer where the message is going to
|
||||
if dstPeer, found := s.registry.Get(msg.RemoteKey); found {
|
||||
s.metrics.GetRegistrationDelay.Record(stream.Context(), float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationFound)))
|
||||
start := time.Now()
|
||||
//forward the message to the target peer
|
||||
if err := dstPeer.Stream.Send(msg); err != nil {
|
||||
log.Errorf("error while forwarding message from peer [%s] to peer [%s] %v", p.Id, msg.RemoteKey, err)
|
||||
//todo respond to the sender?
|
||||
|
||||
// in milliseconds
|
||||
s.metrics.MessageForwardLatency.Record(stream.Context(), float64(time.Since(start).Nanoseconds())/1e6)
|
||||
s.metrics.MessagesForwarded.Add(stream.Context(), 1)
|
||||
} else {
|
||||
s.metrics.MessageForwardFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelType, labelTypeError)))
|
||||
} else {
|
||||
// in milliseconds
|
||||
s.metrics.MessageForwardLatency.Record(stream.Context(), float64(time.Since(start).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream)))
|
||||
s.metrics.MessagesForwarded.Add(stream.Context(), 1)
|
||||
}
|
||||
} else {
|
||||
s.metrics.GetRegistrationDelay.Record(stream.Context(), float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationNotFound)))
|
||||
s.metrics.MessageForwardFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelType, labelTypeNotConnected)))
|
||||
log.Debugf("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", p.Id, msg.RemoteKey)
|
||||
//todo respond to the sender?
|
||||
|
||||
s.metrics.MessageForwardFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelType, labelTypeNotConnected)))
|
||||
}
|
||||
}
|
||||
<-stream.Context().Done()
|
||||
|
@ -2,12 +2,18 @@ package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
@ -35,3 +41,40 @@ func WithCustomDialer() grpc.DialOption {
|
||||
return conn, nil
|
||||
})
|
||||
}
|
||||
|
||||
// grpcDialBackoff is the backoff mechanism for the grpc calls
|
||||
func Backoff(ctx context.Context) backoff.BackOff {
|
||||
b := backoff.NewExponentialBackOff()
|
||||
b.MaxElapsedTime = 10 * time.Second
|
||||
b.Clock = backoff.SystemClock
|
||||
return backoff.WithContext(b, ctx)
|
||||
}
|
||||
|
||||
func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
|
||||
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
|
||||
|
||||
if tlsEnabled {
|
||||
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{}))
|
||||
}
|
||||
|
||||
connCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
conn, err := grpc.DialContext(
|
||||
connCtx,
|
||||
addr,
|
||||
transportOption,
|
||||
WithCustomDialer(),
|
||||
grpc.WithBlock(),
|
||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||
Time: 30 * time.Second,
|
||||
Timeout: 10 * time.Second,
|
||||
}),
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("DialContext error: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
@ -18,8 +19,9 @@ func InitLog(logLevel string, logPath string) error {
|
||||
log.Errorf("Failed parsing log-level %s: %s", logLevel, err)
|
||||
return err
|
||||
}
|
||||
customOutputs := []string{"console", "syslog"};
|
||||
|
||||
if logPath != "" && logPath != "console" {
|
||||
if logPath != "" && !slices.Contains(customOutputs, logPath) {
|
||||
lumberjackLogger := &lumberjack.Logger{
|
||||
// Log file absolute path, os agnostic
|
||||
Filename: filepath.ToSlash(logPath),
|
||||
@ -29,6 +31,8 @@ func InitLog(logLevel string, logPath string) error {
|
||||
Compress: true,
|
||||
}
|
||||
log.SetOutput(io.Writer(lumberjackLogger))
|
||||
} else if logPath == "syslog" {
|
||||
AddSyslogHook()
|
||||
}
|
||||
|
||||
if os.Getenv("NB_LOG_FORMAT") == "json" {
|
||||
|
20
util/syslog_nonwindows.go
Normal file
20
util/syslog_nonwindows.go
Normal file
@ -0,0 +1,20 @@
|
||||
//go:build !windows
|
||||
// +build !windows
|
||||
|
||||
package util
|
||||
|
||||
import (
|
||||
"log/syslog"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
lSyslog "github.com/sirupsen/logrus/hooks/syslog"
|
||||
)
|
||||
|
||||
func AddSyslogHook() {
|
||||
hook, err := lSyslog.NewSyslogHook("", "", syslog.LOG_INFO, "")
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("Failed creating syslog hook: %s", err)
|
||||
}
|
||||
log.AddHook(hook)
|
||||
}
|
6
util/syslog_windows.go
Normal file
6
util/syslog_windows.go
Normal file
@ -0,0 +1,6 @@
|
||||
package util
|
||||
|
||||
func AddSyslogHook() {
|
||||
// The syslog package is not available for Windows. This adapter is needed
|
||||
// to handle windows build.
|
||||
}
|
1
versioninfo.json
Normal file
1
versioninfo.json
Normal file
@ -0,0 +1 @@
|
||||
{}
|
Loading…
Reference in New Issue
Block a user