Merge branch 'main' into feature/relay-integration

This commit is contained in:
Zoltán Papp 2024-07-24 13:40:25 +02:00
commit 4802b83ef9
59 changed files with 1404 additions and 677 deletions

View File

@ -10,8 +10,10 @@ on:
env: env:
SIGN_PIPE_VER: "v0.0.11" SIGN_PIPE_VER: "v0.0.12"
GORELEASER_VER: "v1.14.1" GORELEASER_VER: "v1.14.1"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
concurrency: concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }} group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
@ -23,6 +25,13 @@ jobs:
env: env:
flags: "" flags: ""
steps: steps:
- name: Parse semver string
id: semver_parser
uses: booxmedialtd/ws-action-parse-semver@v1
with:
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
version_extractor_regex: '\/v(.*)$'
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }} - if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV run: echo "flags=--snapshot" >> $GITHUB_ENV
- -
@ -68,18 +77,18 @@ jobs:
- name: Install OS build dependencies - name: Install OS build dependencies
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
- name: Install rsrc - name: Install goversioninfo
run: go install github.com/akavel/rsrc@v0.10.2 run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
- name: Generate windows rsrc amd64 - name: Generate windows syso 386
run: rsrc -arch amd64 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_amd64.syso run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-build ${{ github.run_id }} -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -product-version ${{ steps.semver_parser.outputs.fullversion }}.${{ github.run_id }} -o client/resources_windows_386.syso
- name: Generate windows rsrc arm64 - name: Generate windows syso arm
run: rsrc -arch arm64 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_arm64.syso run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-build ${{ github.run_id }} -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -product-version ${{ steps.semver_parser.outputs.fullversion }}.${{ github.run_id }} -o client/resources_windows_arm.syso
- name: Generate windows rsrc arm - name: Generate windows syso arm64
run: rsrc -arch arm -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_arm.syso run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-build ${{ github.run_id }} -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -product-version ${{ steps.semver_parser.outputs.fullversion }}.${{ github.run_id }} -o client/resources_windows_arm64.syso
- name: Generate windows rsrc 386 - name: Generate windows syso amd64
run: rsrc -arch 386 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_386.syso run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-build ${{ github.run_id }} -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -product-version ${{ steps.semver_parser.outputs.fullversion }}.${{ github.run_id }} -o client/resources_windows_amd64.syso
-
name: Run GoReleaser - name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4 uses: goreleaser/goreleaser-action@v4
with: with:
version: ${{ env.GORELEASER_VER }} version: ${{ env.GORELEASER_VER }}
@ -121,6 +130,13 @@ jobs:
release_ui: release_ui:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Parse semver string
id: semver_parser
uses: booxmedialtd/ws-action-parse-semver@v1
with:
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
version_extractor_regex: '\/v(.*)$'
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }} - if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout - name: Checkout
@ -151,10 +167,11 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64 run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
- name: Install rsrc - name: Install goversioninfo
run: go install github.com/akavel/rsrc@v0.10.2 run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
- name: Generate windows rsrc - name: Generate windows syso amd64
run: rsrc -arch amd64 -ico client/ui/netbird.ico -manifest client/ui/manifest.xml -o client/ui/resources_windows_amd64.syso run: goversioninfo -64 -icon client/ui/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-build ${{ github.run_id }} -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -product-version ${{ steps.semver_parser.outputs.fullversion }}.${{ github.run_id }} -o client/ui/resources_windows_amd64.syso
- name: Run GoReleaser - name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4 uses: goreleaser/goreleaser-action@v4
with: with:

View File

@ -26,7 +26,7 @@ var downCmd = &cobra.Command{
return err return err
} }
ctx, cancel := context.WithTimeout(context.Background(), time.Second*3) ctx, cancel := context.WithTimeout(context.Background(), time.Second*7)
defer cancel() defer cancel()
conn, err := DialClientGRPCServer(ctx, daemonAddr) conn, err := DialClientGRPCServer(ctx, daemonAddr)

View File

@ -121,7 +121,7 @@ func init() {
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name") rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Netbird config file location") rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Netbird config file location")
rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level") rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level")
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout") rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout. If syslog is specified the log will be sent to syslog daemon.")
rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)") rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)")
rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.") rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.")
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device") rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")

View File

@ -337,7 +337,6 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decode
if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) { if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) {
return rule.drop, true return rule.drop, true
} }
return rule.drop, true
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6: case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
return rule.drop, true return rule.drop, true
} }

View File

@ -15,6 +15,12 @@ type hostManager interface {
restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error
} }
type SystemDNSSettings struct {
Domains []string
ServerIP string
ServerPort int
}
type HostDNSConfig struct { type HostDNSConfig struct {
Domains []DomainConfig `json:"domains"` Domains []DomainConfig `json:"domains"`
RouteAll bool `json:"routeAll"` RouteAll bool `json:"routeAll"`

View File

@ -7,6 +7,7 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
"net"
"net/netip" "net/netip"
"os/exec" "os/exec"
"strconv" "strconv"
@ -18,7 +19,7 @@ import (
const ( const (
netbirdDNSStateKeyFormat = "State:/Network/Service/NetBird-%s/DNS" netbirdDNSStateKeyFormat = "State:/Network/Service/NetBird-%s/DNS"
globalIPv4State = "State:/Network/Global/IPv4" globalIPv4State = "State:/Network/Global/IPv4"
primaryServiceSetupKeyFormat = "Setup:/Network/Service/%s/DNS" primaryServiceStateKeyFormat = "State:/Network/Service/%s/DNS"
keySupplementalMatchDomains = "SupplementalMatchDomains" keySupplementalMatchDomains = "SupplementalMatchDomains"
keySupplementalMatchDomainsNoSearch = "SupplementalMatchDomainsNoSearch" keySupplementalMatchDomainsNoSearch = "SupplementalMatchDomainsNoSearch"
keyServerAddresses = "ServerAddresses" keyServerAddresses = "ServerAddresses"
@ -28,12 +29,12 @@ const (
scutilPath = "/usr/sbin/scutil" scutilPath = "/usr/sbin/scutil"
searchSuffix = "Search" searchSuffix = "Search"
matchSuffix = "Match" matchSuffix = "Match"
localSuffix = "Local"
) )
type systemConfigurator struct { type systemConfigurator struct {
// primaryServiceID primary interface in the system. AKA the interface with the default route createdKeys map[string]struct{}
primaryServiceID string systemDNSSettings SystemDNSSettings
createdKeys map[string]struct{}
} }
func newHostManager() (hostManager, error) { func newHostManager() (hostManager, error) {
@ -49,20 +50,6 @@ func (s *systemConfigurator) supportCustomPort() bool {
func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error { func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
var err error var err error
if config.RouteAll {
err = s.addDNSSetupForAll(config.ServerIP, config.ServerPort)
if err != nil {
return fmt.Errorf("add dns setup for all: %w", err)
}
} else if s.primaryServiceID != "" {
err = s.removeKeyFromSystemConfig(getKeyWithInput(primaryServiceSetupKeyFormat, s.primaryServiceID))
if err != nil {
return fmt.Errorf("remote key from system config: %w", err)
}
s.primaryServiceID = ""
log.Infof("removed %s:%d as main DNS resolver for this peer", config.ServerIP, config.ServerPort)
}
// create a file for unclean shutdown detection // create a file for unclean shutdown detection
if err := createUncleanShutdownIndicator(); err != nil { if err := createUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to create unclean shutdown file: %s", err) log.Errorf("failed to create unclean shutdown file: %s", err)
@ -73,6 +60,19 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
matchDomains []string matchDomains []string
) )
err = s.recordSystemDNSSettings(true)
if err != nil {
log.Errorf("unable to update record of System's DNS config: %s", err.Error())
}
if config.RouteAll {
searchDomains = append(searchDomains, "\"\"")
err = s.addLocalDNS()
if err != nil {
log.Infof("failed to enable split DNS")
}
}
for _, dConf := range config.Domains { for _, dConf := range config.Domains {
if dConf.Disabled { if dConf.Disabled {
continue continue
@ -110,23 +110,17 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
} }
func (s *systemConfigurator) restoreHostDNS() error { func (s *systemConfigurator) restoreHostDNS() error {
lines := "" keys := s.getRemovableKeysWithDefaults()
for key := range s.createdKeys { for _, key := range keys {
lines += buildRemoveKeyOperation(key)
keyType := "search" keyType := "search"
if strings.Contains(key, matchSuffix) { if strings.Contains(key, matchSuffix) {
keyType = "match" keyType = "match"
} }
log.Infof("removing %s domains from system", keyType) log.Infof("removing %s domains from system", keyType)
} err := s.removeKeyFromSystemConfig(key)
if s.primaryServiceID != "" { if err != nil {
lines += buildRemoveKeyOperation(getKeyWithInput(primaryServiceSetupKeyFormat, s.primaryServiceID)) log.Errorf("failed to remove %s domains from system: %s", keyType, err)
log.Infof("restoring DNS resolver configuration for system") }
}
_, err := runSystemConfigCommand(wrapCommand(lines))
if err != nil {
log.Errorf("got an error while cleaning the system configuration: %s", err)
return fmt.Errorf("clean system: %w", err)
} }
if err := removeUncleanShutdownIndicator(); err != nil { if err := removeUncleanShutdownIndicator(); err != nil {
@ -136,6 +130,19 @@ func (s *systemConfigurator) restoreHostDNS() error {
return nil return nil
} }
func (s *systemConfigurator) getRemovableKeysWithDefaults() []string {
if len(s.createdKeys) == 0 {
// return defaults for startup calls
return []string{getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix), getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)}
}
keys := make([]string, 0, len(s.createdKeys))
for key := range s.createdKeys {
keys = append(keys, key)
}
return keys
}
func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error { func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
line := buildRemoveKeyOperation(key) line := buildRemoveKeyOperation(key)
_, err := runSystemConfigCommand(wrapCommand(line)) _, err := runSystemConfigCommand(wrapCommand(line))
@ -148,6 +155,97 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
return nil return nil
} }
func (s *systemConfigurator) addLocalDNS() error {
if s.systemDNSSettings.ServerIP == "" || len(s.systemDNSSettings.Domains) == 0 {
err := s.recordSystemDNSSettings(true)
log.Errorf("Unable to get system DNS configuration")
return err
}
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
if s.systemDNSSettings.ServerIP != "" && len(s.systemDNSSettings.Domains) != 0 {
err := s.addSearchDomains(localKey, strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort)
if err != nil {
return fmt.Errorf("couldn't add local network DNS conf: %w", err)
}
} else {
log.Info("Not enabling local DNS server")
}
return nil
}
func (s *systemConfigurator) recordSystemDNSSettings(force bool) error {
if s.systemDNSSettings.ServerIP != "" && len(s.systemDNSSettings.Domains) != 0 && !force {
return nil
}
systemDNSSettings, err := s.getSystemDNSSettings()
if err != nil {
return fmt.Errorf("couldn't get current DNS config: %w", err)
}
s.systemDNSSettings = systemDNSSettings
return nil
}
func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
primaryServiceKey, _, err := s.getPrimaryService()
if err != nil || primaryServiceKey == "" {
return SystemDNSSettings{}, fmt.Errorf("couldn't find the primary service key: %w", err)
}
dnsServiceKey := getKeyWithInput(primaryServiceStateKeyFormat, primaryServiceKey)
line := buildCommandLine("show", dnsServiceKey, "")
stdinCommands := wrapCommand(line)
b, err := runSystemConfigCommand(stdinCommands)
if err != nil {
return SystemDNSSettings{}, fmt.Errorf("sending the command: %w", err)
}
var dnsSettings SystemDNSSettings
inSearchDomainsArray := false
inServerAddressesArray := false
scanner := bufio.NewScanner(bytes.NewReader(b))
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
switch {
case strings.HasPrefix(line, "DomainName :"):
domainName := strings.TrimSpace(strings.Split(line, ":")[1])
dnsSettings.Domains = append(dnsSettings.Domains, domainName)
case line == "SearchDomains : <array> {":
inSearchDomainsArray = true
continue
case line == "ServerAddresses : <array> {":
inServerAddressesArray = true
continue
case line == "}":
inSearchDomainsArray = false
inServerAddressesArray = false
}
if inSearchDomainsArray {
searchDomain := strings.Split(line, " : ")[1]
dnsSettings.Domains = append(dnsSettings.Domains, searchDomain)
} else if inServerAddressesArray {
address := strings.Split(line, " : ")[1]
if ip := net.ParseIP(address); ip != nil && ip.To4() != nil {
dnsSettings.ServerIP = address
inServerAddressesArray = false // Stop reading after finding the first IPv4 address
}
}
}
if err := scanner.Err(); err != nil {
return dnsSettings, err
}
// default to 53 port
dnsSettings.ServerPort = 53
return dnsSettings, nil
}
func (s *systemConfigurator) addSearchDomains(key, domains string, ip string, port int) error { func (s *systemConfigurator) addSearchDomains(key, domains string, ip string, port int) error {
err := s.addDNSState(key, domains, ip, port, true) err := s.addDNSState(key, domains, ip, port, true)
if err != nil { if err != nil {
@ -194,23 +292,6 @@ func (s *systemConfigurator) addDNSState(state, domains, dnsServer string, port
return nil return nil
} }
func (s *systemConfigurator) addDNSSetupForAll(dnsServer string, port int) error {
primaryServiceKey, existingNameserver, err := s.getPrimaryService()
if err != nil || primaryServiceKey == "" {
return fmt.Errorf("couldn't find the primary service key: %w", err)
}
err = s.addDNSSetup(getKeyWithInput(primaryServiceSetupKeyFormat, primaryServiceKey), dnsServer, port, existingNameserver)
if err != nil {
return fmt.Errorf("add dns setup: %w", err)
}
log.Infof("configured %s:%d as main DNS resolver for this peer", dnsServer, port)
s.primaryServiceID = primaryServiceKey
return nil
}
func (s *systemConfigurator) getPrimaryService() (string, string, error) { func (s *systemConfigurator) getPrimaryService() (string, string, error) {
line := buildCommandLine("show", globalIPv4State, "") line := buildCommandLine("show", globalIPv4State, "")
stdinCommands := wrapCommand(line) stdinCommands := wrapCommand(line)
@ -239,19 +320,6 @@ func (s *systemConfigurator) getPrimaryService() (string, string, error) {
return primaryService, router, nil return primaryService, router, nil
} }
func (s *systemConfigurator) addDNSSetup(setupKey, dnsServer string, port int, existingDNSServer string) error {
lines := buildAddCommandLine(keySupplementalMatchDomainsNoSearch, digitSymbol+strconv.Itoa(0))
lines += buildAddCommandLine(keyServerAddresses, arraySymbol+dnsServer+" "+existingDNSServer)
lines += buildAddCommandLine(keyServerPort, digitSymbol+strconv.Itoa(port))
addDomainCommand := buildCreateStateWithOperation(setupKey, lines)
stdinCommands := wrapCommand(addDomainCommand)
_, err := runSystemConfigCommand(stdinCommands)
if err != nil {
return fmt.Errorf("applying dns setup, error: %w", err)
}
return nil
}
func (s *systemConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error { func (s *systemConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error {
if err := s.restoreHostDNS(); err != nil { if err := s.restoreHostDNS(); err != nil {
return fmt.Errorf("restoring dns via scutil: %w", err) return fmt.Errorf("restoring dns via scutil: %w", err)

View File

@ -24,7 +24,7 @@ const (
probeTimeout = 2 * time.Second probeTimeout = 2 * time.Second
) )
const testRecord = "." const testRecord = "com."
type upstreamClient interface { type upstreamClient interface {
exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error) exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
@ -42,6 +42,7 @@ type upstreamResolverBase struct {
upstreamServers []string upstreamServers []string
disabled bool disabled bool
failsCount atomic.Int32 failsCount atomic.Int32
successCount atomic.Int32
failsTillDeact int32 failsTillDeact int32
mutex sync.Mutex mutex sync.Mutex
reactivatePeriod time.Duration reactivatePeriod time.Duration
@ -124,6 +125,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return return
} }
u.successCount.Add(1)
log.Tracef("took %s to query the upstream %s", t, upstream) log.Tracef("took %s to query the upstream %s", t, upstream)
err = w.WriteMsg(rm) err = w.WriteMsg(rm)
@ -172,6 +174,11 @@ func (u *upstreamResolverBase) probeAvailability() {
default: default:
} }
// avoid probe if upstreams could resolve at least one query and fails count is less than failsTillDeact
if u.successCount.Load() > 0 && u.failsCount.Load() < u.failsTillDeact {
return
}
var success bool var success bool
var mu sync.Mutex var mu sync.Mutex
var wg sync.WaitGroup var wg sync.WaitGroup
@ -183,7 +190,7 @@ func (u *upstreamResolverBase) probeAvailability() {
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
err := u.testNameserver(upstream) err := u.testNameserver(upstream, 500*time.Millisecond)
if err != nil { if err != nil {
errors = multierror.Append(errors, err) errors = multierror.Append(errors, err)
log.Warnf("probing upstream nameserver %s: %s", upstream, err) log.Warnf("probing upstream nameserver %s: %s", upstream, err)
@ -224,7 +231,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
} }
for _, upstream := range u.upstreamServers { for _, upstream := range u.upstreamServers {
if err := u.testNameserver(upstream); err != nil { if err := u.testNameserver(upstream, probeTimeout); err != nil {
log.Tracef("upstream check for %s: %s", upstream, err) log.Tracef("upstream check for %s: %s", upstream, err)
} else { } else {
// at least one upstream server is available, stop probing // at least one upstream server is available, stop probing
@ -244,6 +251,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServers) log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServers)
u.failsCount.Store(0) u.failsCount.Store(0)
u.successCount.Add(1)
u.reactivate() u.reactivate()
u.disabled = false u.disabled = false
} }
@ -265,13 +273,14 @@ func (u *upstreamResolverBase) disable(err error) {
} }
log.Warnf("Upstream resolving is Disabled for %v", reactivatePeriod) log.Warnf("Upstream resolving is Disabled for %v", reactivatePeriod)
u.successCount.Store(0)
u.deactivate(err) u.deactivate(err)
u.disabled = true u.disabled = true
go u.waitUntilResponse() go u.waitUntilResponse()
} }
func (u *upstreamResolverBase) testNameserver(server string) error { func (u *upstreamResolverBase) testNameserver(server string, timeout time.Duration) error {
ctx, cancel := context.WithTimeout(u.ctx, probeTimeout) ctx, cancel := context.WithTimeout(u.ctx, timeout)
defer cancel() defer cancel()
r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA) r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA)

View File

@ -4,6 +4,7 @@ package networkmonitor
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"syscall" "syscall"
"unsafe" "unsafe"
@ -21,11 +22,20 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
return fmt.Errorf("failed to open routing socket: %v", err) return fmt.Errorf("failed to open routing socket: %v", err)
} }
defer func() { defer func() {
if err := unix.Close(fd); err != nil { err := unix.Close(fd)
if err != nil && !errors.Is(err, unix.EBADF) {
log.Errorf("Network monitor: failed to close routing socket: %v", err) log.Errorf("Network monitor: failed to close routing socket: %v", err)
} }
}() }()
go func() {
<-ctx.Done()
err := unix.Close(fd)
if err != nil && !errors.Is(err, unix.EBADF) {
log.Debugf("Network monitor: closed routing socket")
}
}()
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
@ -34,7 +44,9 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
buf := make([]byte, 2048) buf := make([]byte, 2048)
n, err := unix.Read(fd, buf) n, err := unix.Read(fd, buf)
if err != nil { if err != nil {
log.Errorf("Network monitor: failed to read from routing socket: %v", err) if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) {
log.Errorf("Network monitor: failed to read from routing socket: %v", err)
}
continue continue
} }
if n < unix.SizeofRtMsghdr { if n < unix.SizeofRtMsghdr {

View File

@ -99,6 +99,11 @@ func routeChanged(nexthop systemops.Nexthop, intf *net.Interface, routes []syste
return false return false
} }
if isSoftInterface(nexthop.Intf.Name) {
log.Tracef("network monitor: ignoring default route change for soft interface %s", nexthop.Intf.Name)
return false
}
unspec := getUnspecifiedPrefix(nexthop.IP) unspec := getUnspecifiedPrefix(nexthop.IP)
defaultRoutes, foundMatchingRoute := processRoutes(nexthop, intf, routes, unspec) defaultRoutes, foundMatchingRoute := processRoutes(nexthop, intf, routes, unspec)
@ -119,7 +124,7 @@ func getUnspecifiedPrefix(ip netip.Addr) netip.Prefix {
return netip.PrefixFrom(netip.IPv4Unspecified(), 0) return netip.PrefixFrom(netip.IPv4Unspecified(), 0)
} }
func processRoutes(nexthop systemops.Nexthop, intf *net.Interface, routes []systemops.Route, unspec netip.Prefix) ([]string, bool) { func processRoutes(nexthop systemops.Nexthop, nexthopIntf *net.Interface, routes []systemops.Route, unspec netip.Prefix) ([]string, bool) {
var defaultRoutes []string var defaultRoutes []string
foundMatchingRoute := false foundMatchingRoute := false
@ -128,7 +133,7 @@ func processRoutes(nexthop systemops.Nexthop, intf *net.Interface, routes []syst
routeInfo := formatRouteInfo(r) routeInfo := formatRouteInfo(r)
defaultRoutes = append(defaultRoutes, routeInfo) defaultRoutes = append(defaultRoutes, routeInfo)
if r.Nexthop == nexthop.IP && compareIntf(r.Interface, intf) == 0 { if r.Nexthop == nexthop.IP && compareIntf(r.Interface, nexthopIntf) == 0 {
foundMatchingRoute = true foundMatchingRoute = true
log.Debugf("network monitor: found matching default route: %s", routeInfo) log.Debugf("network monitor: found matching default route: %s", routeInfo)
} }
@ -232,14 +237,18 @@ func stateFromInt(state uint8) string {
} }
func compareIntf(a, b *net.Interface) int { func compareIntf(a, b *net.Interface) int {
if a == nil && b == nil { switch {
case a == nil && b == nil:
return 0 return 0
} case a == nil:
if a == nil {
return -1 return -1
} case b == nil:
if b == nil {
return 1 return 1
default:
return a.Index - b.Index
} }
return a.Index - b.Index }
func isSoftInterface(name string) bool {
return strings.Contains(strings.ToLower(name), "isatap") || strings.Contains(strings.ToLower(name), "teredo")
} }

View File

@ -3,6 +3,7 @@ package routemanager
import ( import (
"context" "context"
"fmt" "fmt"
"reflect"
"time" "time"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
@ -303,22 +304,33 @@ func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) {
}() }()
} }
func (c *clientNetwork) handleUpdate(update routesUpdate) { func (c *clientNetwork) handleUpdate(update routesUpdate) bool {
isUpdateMapDifferent := false
updateMap := make(map[route.ID]*route.Route) updateMap := make(map[route.ID]*route.Route)
for _, r := range update.routes { for _, r := range update.routes {
updateMap[r.ID] = r updateMap[r.ID] = r
} }
if len(c.routes) != len(updateMap) {
isUpdateMapDifferent = true
}
for id, r := range c.routes { for id, r := range c.routes {
_, found := updateMap[id] _, found := updateMap[id]
if !found { if !found {
close(c.routePeersNotifiers[r.Peer]) close(c.routePeersNotifiers[r.Peer])
delete(c.routePeersNotifiers, r.Peer) delete(c.routePeersNotifiers, r.Peer)
isUpdateMapDifferent = true
continue
}
if !reflect.DeepEqual(c.routes[id], updateMap[id]) {
isUpdateMapDifferent = true
} }
} }
c.routes = updateMap c.routes = updateMap
return isUpdateMapDifferent
} }
// peersStateAndUpdateWatcher is the main point of reacting on client network routing events. // peersStateAndUpdateWatcher is the main point of reacting on client network routing events.
@ -345,13 +357,19 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
log.Debugf("Received a new client network route update for [%v]", c.handler) log.Debugf("Received a new client network route update for [%v]", c.handler)
c.handleUpdate(update) // hash update somehow
isTrueRouteUpdate := c.handleUpdate(update)
c.updateSerial = update.updateSerial c.updateSerial = update.updateSerial
err := c.recalculateRouteAndUpdatePeerAndSystem() if isTrueRouteUpdate {
if err != nil { log.Debug("Client network update contains different routes, recalculating routes")
log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err) err := c.recalculateRouteAndUpdatePeerAndSystem()
if err != nil {
log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err)
}
} else {
log.Debug("Route update is not different, skipping route recalculation")
} }
c.startPeersStatusChangeWatcher() c.startPeersStatusChangeWatcher()

View File

@ -16,6 +16,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/routemanager/vars"
@ -50,7 +51,7 @@ type DefaultManager struct {
statusRecorder *peer.Status statusRecorder *peer.Status
wgInterface iface.IWGIface wgInterface iface.IWGIface
pubKey string pubKey string
notifier *notifier notifier *notifier.Notifier
routeRefCounter *refcounter.RouteRefCounter routeRefCounter *refcounter.RouteRefCounter
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter allowedIPsRefCounter *refcounter.AllowedIPsRefCounter
dnsRouteInterval time.Duration dnsRouteInterval time.Duration
@ -65,7 +66,8 @@ func NewManager(
initialRoutes []*route.Route, initialRoutes []*route.Route,
) *DefaultManager { ) *DefaultManager {
mCTX, cancel := context.WithCancel(ctx) mCTX, cancel := context.WithCancel(ctx)
sysOps := systemops.NewSysOps(wgInterface) notifier := notifier.NewNotifier()
sysOps := systemops.NewSysOps(wgInterface, notifier)
dm := &DefaultManager{ dm := &DefaultManager{
ctx: mCTX, ctx: mCTX,
@ -77,7 +79,7 @@ func NewManager(
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
wgInterface: wgInterface, wgInterface: wgInterface,
pubKey: pubKey, pubKey: pubKey,
notifier: newNotifier(), notifier: notifier,
} }
dm.routeRefCounter = refcounter.New( dm.routeRefCounter = refcounter.New(
@ -107,7 +109,7 @@ func NewManager(
if runtime.GOOS == "android" { if runtime.GOOS == "android" {
cr := dm.clientRoutes(initialRoutes) cr := dm.clientRoutes(initialRoutes)
dm.notifier.setInitialClientRoutes(cr) dm.notifier.SetInitialClientRoutes(cr)
} }
return dm return dm
} }
@ -186,7 +188,7 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap) filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap)
m.updateClientNetworks(updateSerial, filteredClientRoutes) m.updateClientNetworks(updateSerial, filteredClientRoutes)
m.notifier.onNewRoutes(filteredClientRoutes) m.notifier.OnNewRoutes(filteredClientRoutes)
if m.serverRouter != nil { if m.serverRouter != nil {
err := m.serverRouter.updateRoutes(newServerRoutesMap) err := m.serverRouter.updateRoutes(newServerRoutesMap)
@ -199,14 +201,14 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
} }
} }
// SetRouteChangeListener set RouteListener for route change notifier // SetRouteChangeListener set RouteListener for route change Notifier
func (m *DefaultManager) SetRouteChangeListener(listener listener.NetworkChangeListener) { func (m *DefaultManager) SetRouteChangeListener(listener listener.NetworkChangeListener) {
m.notifier.setListener(listener) m.notifier.SetListener(listener)
} }
// InitialRouteRange return the list of initial routes. It used by mobile systems // InitialRouteRange return the list of initial routes. It used by mobile systems
func (m *DefaultManager) InitialRouteRange() []string { func (m *DefaultManager) InitialRouteRange() []string {
return m.notifier.getInitialRouteRanges() return m.notifier.GetInitialRouteRanges()
} }
// GetRouteSelector returns the route selector // GetRouteSelector returns the route selector
@ -226,7 +228,7 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
networks = m.routeSelector.FilterSelected(networks) networks = m.routeSelector.FilterSelected(networks)
m.notifier.onNewRoutes(networks) m.notifier.OnNewRoutes(networks)
m.stopObsoleteClients(networks) m.stopObsoleteClients(networks)

View File

@ -1,6 +1,7 @@
package routemanager package notifier
import ( import (
"net/netip"
"runtime" "runtime"
"sort" "sort"
"strings" "strings"
@ -10,7 +11,7 @@ import (
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
type notifier struct { type Notifier struct {
initialRouteRanges []string initialRouteRanges []string
routeRanges []string routeRanges []string
@ -18,17 +19,17 @@ type notifier struct {
listenerMux sync.Mutex listenerMux sync.Mutex
} }
func newNotifier() *notifier { func NewNotifier() *Notifier {
return &notifier{} return &Notifier{}
} }
func (n *notifier) setListener(listener listener.NetworkChangeListener) { func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
n.listenerMux.Lock() n.listenerMux.Lock()
defer n.listenerMux.Unlock() defer n.listenerMux.Unlock()
n.listener = listener n.listener = listener
} }
func (n *notifier) setInitialClientRoutes(clientRoutes []*route.Route) { func (n *Notifier) SetInitialClientRoutes(clientRoutes []*route.Route) {
nets := make([]string, 0) nets := make([]string, 0)
for _, r := range clientRoutes { for _, r := range clientRoutes {
nets = append(nets, r.Network.String()) nets = append(nets, r.Network.String())
@ -37,7 +38,10 @@ func (n *notifier) setInitialClientRoutes(clientRoutes []*route.Route) {
n.initialRouteRanges = nets n.initialRouteRanges = nets
} }
func (n *notifier) onNewRoutes(idMap route.HAMap) { func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
if runtime.GOOS != "android" {
return
}
newNets := make([]string, 0) newNets := make([]string, 0)
for _, routes := range idMap { for _, routes := range idMap {
for _, r := range routes { for _, r := range routes {
@ -62,7 +66,30 @@ func (n *notifier) onNewRoutes(idMap route.HAMap) {
n.notify() n.notify()
} }
func (n *notifier) notify() { func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
newNets := make([]string, 0)
for _, prefix := range prefixes {
newNets = append(newNets, prefix.String())
}
sort.Strings(newNets)
switch runtime.GOOS {
case "android":
if !n.hasDiff(n.initialRouteRanges, newNets) {
return
}
default:
if !n.hasDiff(n.routeRanges, newNets) {
return
}
}
n.routeRanges = newNets
n.notify()
}
func (n *Notifier) notify() {
n.listenerMux.Lock() n.listenerMux.Lock()
defer n.listenerMux.Unlock() defer n.listenerMux.Unlock()
if n.listener == nil { if n.listener == nil {
@ -74,7 +101,7 @@ func (n *notifier) notify() {
}(n.listener) }(n.listener)
} }
func (n *notifier) hasDiff(a []string, b []string) bool { func (n *Notifier) hasDiff(a []string, b []string) bool {
if len(a) != len(b) { if len(a) != len(b) {
return true return true
} }
@ -86,7 +113,7 @@ func (n *notifier) hasDiff(a []string, b []string) bool {
return false return false
} }
func (n *notifier) getInitialRouteRanges() []string { func (n *Notifier) GetInitialRouteRanges() []string {
return addIPv6RangeIfNeeded(n.initialRouteRanges) return addIPv6RangeIfNeeded(n.initialRouteRanges)
} }

View File

@ -3,7 +3,9 @@ package systemops
import ( import (
"net" "net"
"net/netip" "net/netip"
"sync"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
) )
@ -18,10 +20,19 @@ type ExclusionCounter = refcounter.Counter[any, Nexthop]
type SysOps struct { type SysOps struct {
refCounter *ExclusionCounter refCounter *ExclusionCounter
wgInterface iface.IWGIface wgInterface iface.IWGIface
// prefixes is tracking all the current added prefixes im memory
// (this is used in iOS as all route updates require a full table update)
//nolint
prefixes map[netip.Prefix]struct{}
//nolint
mu sync.Mutex
// notifier is used to notify the system of route changes (also used on mobile)
notifier *notifier.Notifier
} }
func NewSysOps(wgInterface iface.IWGIface) *SysOps { func NewSysOps(wgInterface iface.IWGIface, notifier *notifier.Notifier) *SysOps {
return &SysOps{ return &SysOps{
wgInterface: wgInterface, wgInterface: wgInterface,
notifier: notifier,
} }
} }

View File

@ -1,4 +1,4 @@
//go:build ios || android //go:build android
package systemops package systemops

View File

@ -36,7 +36,7 @@ func TestConcurrentRoutes(t *testing.T) {
baseIP := netip.MustParseAddr("192.0.2.0") baseIP := netip.MustParseAddr("192.0.2.0")
intf := &net.Interface{Name: "lo0"} intf := &net.Interface{Name: "lo0"}
r := NewSysOps(nil) r := NewSysOps(nil, nil)
var wg sync.WaitGroup var wg sync.WaitGroup
for i := 0; i < 1024; i++ { for i := 0; i < 1024; i++ {

View File

@ -50,7 +50,7 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbn
nexthop, err := r.addRouteToNonVPNIntf(prefix, r.wgInterface, initialNexthop) nexthop, err := r.addRouteToNonVPNIntf(prefix, r.wgInterface, initialNexthop)
if errors.Is(err, vars.ErrRouteNotAllowed) || errors.Is(err, vars.ErrRouteNotFound) { if errors.Is(err, vars.ErrRouteNotAllowed) || errors.Is(err, vars.ErrRouteNotFound) {
log.Tracef("Adding for prefix %s: %v", prefix, err) log.Tracef("Adding for prefix %s: %v", prefix, err)
// These errors are not critical but also we should not track and try to remove the routes either. // These errors are not critical, but also we should not track and try to remove the routes either.
return nexthop, refcounter.ErrIgnore return nexthop, refcounter.ErrIgnore
} }
return nexthop, err return nexthop, err
@ -135,6 +135,11 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf iface.IWGIfac
return Nexthop{}, vars.ErrRouteNotAllowed return Nexthop{}, vars.ErrRouteNotAllowed
} }
// Check if the prefix is part of any local subnets
if isLocal, subnet := r.isPrefixInLocalSubnets(prefix); isLocal {
return Nexthop{}, fmt.Errorf("prefix %s is part of local subnet %s: %w", prefix, subnet, vars.ErrRouteNotAllowed)
}
// Determine the exit interface and next hop for the prefix, so we can add a specific route // Determine the exit interface and next hop for the prefix, so we can add a specific route
nexthop, err := GetNextHop(addr) nexthop, err := GetNextHop(addr)
if err != nil { if err != nil {
@ -167,6 +172,36 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf iface.IWGIfac
return exitNextHop, nil return exitNextHop, nil
} }
func (r *SysOps) isPrefixInLocalSubnets(prefix netip.Prefix) (bool, *net.IPNet) {
localInterfaces, err := net.Interfaces()
if err != nil {
log.Errorf("Failed to get local interfaces: %v", err)
return false, nil
}
for _, intf := range localInterfaces {
addrs, err := intf.Addrs()
if err != nil {
log.Errorf("Failed to get addresses for interface %s: %v", intf.Name, err)
continue
}
for _, addr := range addrs {
ipnet, ok := addr.(*net.IPNet)
if !ok {
log.Errorf("Failed to convert address to IPNet: %v", addr)
continue
}
if ipnet.Contains(prefix.Addr().AsSlice()) {
return true, ipnet
}
}
}
return false, nil
}
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix // genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
// in two /1 prefixes to avoid replacing the existing default route // in two /1 prefixes to avoid replacing the existing default route
func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) error { func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {

View File

@ -68,7 +68,7 @@ func TestAddRemoveRoutes(t *testing.T) {
err = wgInterface.Create() err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface") require.NoError(t, err, "should create testing wireguard interface")
r := NewSysOps(wgInterface) r := NewSysOps(wgInterface, nil)
_, _, err = r.SetupRouting(nil) _, _, err = r.SetupRouting(nil)
require.NoError(t, err) require.NoError(t, err)
@ -224,7 +224,7 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
require.NoError(t, err, "InterfaceByName should not return err") require.NoError(t, err, "InterfaceByName should not return err")
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()} intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
r := NewSysOps(wgInterface) r := NewSysOps(wgInterface, nil)
// Prepare the environment // Prepare the environment
if testCase.preExistingPrefix.IsValid() { if testCase.preExistingPrefix.IsValid() {
@ -379,7 +379,7 @@ func setupTestEnv(t *testing.T) {
assert.NoError(t, wgInterface.Close()) assert.NoError(t, wgInterface.Close())
}) })
r := NewSysOps(wgInterface) r := NewSysOps(wgInterface, nil)
_, _, err := r.SetupRouting(nil) _, _, err := r.SetupRouting(nil)
require.NoError(t, err, "setupRouting should not return err") require.NoError(t, err, "setupRouting should not return err")
t.Cleanup(func() { t.Cleanup(func() {

View File

@ -0,0 +1,64 @@
//go:build ios
package systemops
import (
"net"
"net/netip"
"runtime"
log "github.com/sirupsen/logrus"
nbnet "github.com/netbirdio/netbird/util/net"
)
func (r *SysOps) SetupRouting([]net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
r.mu.Lock()
defer r.mu.Unlock()
r.prefixes = make(map[netip.Prefix]struct{})
return nil, nil, nil
}
func (r *SysOps) CleanupRouting() error {
r.mu.Lock()
defer r.mu.Unlock()
r.prefixes = make(map[netip.Prefix]struct{})
r.notify()
return nil
}
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, _ *net.Interface) error {
r.mu.Lock()
defer r.mu.Unlock()
r.prefixes[prefix] = struct{}{}
r.notify()
return nil
}
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, _ *net.Interface) error {
r.mu.Lock()
defer r.mu.Unlock()
delete(r.prefixes, prefix)
r.notify()
return nil
}
func EnableIPForwarding() error {
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil
}
func IsAddrRouted(netip.Addr, []netip.Prefix) (bool, netip.Prefix) {
return false, netip.Prefix{}
}
func (r *SysOps) notify() {
prefixes := make([]netip.Prefix, 0, len(r.prefixes))
for prefix := range r.prefixes {
prefixes = append(prefixes, prefix)
}
r.notifier.OnNewPrefixes(prefixes)
}

View File

@ -73,7 +73,7 @@ var testCases = []testCase{
{ {
name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence
destination: "10.0.0.2:53", destination: "10.0.0.2:53",
expectedSourceIP: "10.0.0.1", expectedSourceIP: "127.0.0.1",
expectedDestPrefix: "10.0.0.0/8", expectedDestPrefix: "10.0.0.0/8",
expectedNextHop: "0.0.0.0", expectedNextHop: "0.0.0.0",
expectedInterface: "Loopback Pseudo-Interface 1", expectedInterface: "Loopback Pseudo-Interface 1",
@ -110,7 +110,7 @@ var testCases = []testCase{
{ {
name: "To more specific route (local) without custom dialer via physical interface", name: "To more specific route (local) without custom dialer via physical interface",
destination: "127.0.10.2:53", destination: "127.0.10.2:53",
expectedSourceIP: "10.0.0.1", expectedSourceIP: "127.0.0.1",
expectedDestPrefix: "127.0.0.0/8", expectedDestPrefix: "127.0.0.0/8",
expectedNextHop: "0.0.0.0", expectedNextHop: "0.0.0.0",
expectedInterface: "Loopback Pseudo-Interface 1", expectedInterface: "Loopback Pseudo-Interface 1",
@ -181,31 +181,6 @@ func testRoute(t *testing.T, destination string, dialer dialer) *FindNetRouteOut
return combinedOutput return combinedOutput
} }
func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string {
t.Helper()
ip, ipNet, err := net.ParseCIDR(ipAddressCIDR)
require.NoError(t, err)
subnetMaskSize, _ := ipNet.Mask.Size()
script := fmt.Sprintf(`New-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -PrefixLength %d -PolicyStore ActiveStore -Confirm:$False`, interfaceName, ip.String(), subnetMaskSize)
_, err = exec.Command("powershell", "-Command", script).CombinedOutput()
require.NoError(t, err, "Failed to assign IP address to loopback adapter")
// Wait for the IP address to be applied
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
err = waitForIPAddress(ctx, interfaceName, ip.String())
require.NoError(t, err, "IP address not applied within timeout")
t.Cleanup(func() {
script = fmt.Sprintf(`Remove-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -Confirm:$False`, interfaceName, ip.String())
_, err = exec.Command("powershell", "-Command", script).CombinedOutput()
require.NoError(t, err, "Failed to remove IP address from loopback adapter")
})
return interfaceName
}
func fetchOriginalGateway() (*RouteInfo, error) { func fetchOriginalGateway() (*RouteInfo, error) {
cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object Nexthop, RouteMetric, InterfaceAlias | ConvertTo-Json") cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object Nexthop, RouteMetric, InterfaceAlias | ConvertTo-Json")
output, err := cmd.CombinedOutput() output, err := cmd.CombinedOutput()
@ -231,30 +206,6 @@ func verifyOutput(t *testing.T, output *FindNetRouteOutput, sourceIP, destPrefix
assert.Equal(t, intf, output.InterfaceAlias, "Interface mismatch") assert.Equal(t, intf, output.InterfaceAlias, "Interface mismatch")
} }
func waitForIPAddress(ctx context.Context, interfaceAlias, expectedIPAddress string) error {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
out, err := exec.Command("powershell", "-Command", fmt.Sprintf(`Get-NetIPAddress -InterfaceAlias "%s" | Select-Object -ExpandProperty IPAddress`, interfaceAlias)).CombinedOutput()
if err != nil {
return err
}
ipAddresses := strings.Split(strings.TrimSpace(string(out)), "\n")
for _, ip := range ipAddresses {
if strings.TrimSpace(ip) == expectedIPAddress {
return nil
}
}
}
}
}
func combineOutputs(outputs []FindNetRouteOutput) *FindNetRouteOutput { func combineOutputs(outputs []FindNetRouteOutput) *FindNetRouteOutput {
var combined FindNetRouteOutput var combined FindNetRouteOutput
@ -285,5 +236,25 @@ func combineOutputs(outputs []FindNetRouteOutput) *FindNetRouteOutput {
func setupDummyInterfacesAndRoutes(t *testing.T) { func setupDummyInterfacesAndRoutes(t *testing.T) {
t.Helper() t.Helper()
createAndSetupDummyInterface(t, "Loopback Pseudo-Interface 1", "10.0.0.1/8") addDummyRoute(t, "10.0.0.0/8")
}
func addDummyRoute(t *testing.T, dstCIDR string) {
t.Helper()
script := fmt.Sprintf(`New-NetRoute -DestinationPrefix "%s" -InterfaceIndex 1 -PolicyStore ActiveStore`, dstCIDR)
output, err := exec.Command("powershell", "-Command", script).CombinedOutput()
if err != nil {
t.Logf("Failed to add dummy route: %v\nOutput: %s", err, output)
t.FailNow()
}
t.Cleanup(func() {
script = fmt.Sprintf(`Remove-NetRoute -DestinationPrefix "%s" -InterfaceIndex 1 -Confirm:$false`, dstCIDR)
output, err := exec.Command("powershell", "-Command", script).CombinedOutput()
if err != nil {
t.Logf("Failed to remove dummy route: %v\nOutput: %s", err, output)
}
})
} }

View File

@ -19,6 +19,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter" "github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@ -47,6 +48,7 @@ type CustomLogger interface {
type selectRoute struct { type selectRoute struct {
NetID string NetID string
Network netip.Prefix Network netip.Prefix
Domains domain.List
Selected bool Selected bool
} }
@ -279,6 +281,7 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
route := &selectRoute{ route := &selectRoute{
NetID: string(id), NetID: string(id),
Network: rt[0].Network, Network: rt[0].Network,
Domains: rt[0].Domains,
Selected: routeSelector.IsSelected(id), Selected: routeSelector.IsSelected(id),
} }
routes = append(routes, route) routes = append(routes, route)
@ -299,17 +302,40 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
return iPrefix < jPrefix return iPrefix < jPrefix
}) })
resolvedDomains := c.recorder.GetResolvedDomainsStates()
return prepareRouteSelectionDetails(routes, resolvedDomains), nil
}
func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[domain.Domain][]netip.Prefix) *RoutesSelectionDetails {
var routeSelection []RoutesSelectionInfo var routeSelection []RoutesSelectionInfo
for _, r := range routes { for _, r := range routes {
domainList := make([]DomainInfo, 0)
for _, d := range r.Domains {
domainResp := DomainInfo{
Domain: d.SafeString(),
}
if prefixes, exists := resolvedDomains[d]; exists {
var ipStrings []string
for _, prefix := range prefixes {
ipStrings = append(ipStrings, prefix.Addr().String())
}
domainResp.ResolvedIPs = strings.Join(ipStrings, ", ")
}
domainList = append(domainList, domainResp)
}
domainDetails := DomainDetails{items: domainList}
routeSelection = append(routeSelection, RoutesSelectionInfo{ routeSelection = append(routeSelection, RoutesSelectionInfo{
ID: r.NetID, ID: r.NetID,
Network: r.Network.String(), Network: r.Network.String(),
Domains: &domainDetails,
Selected: r.Selected, Selected: r.Selected,
}) })
} }
routeSelectionDetails := RoutesSelectionDetails{items: routeSelection} routeSelectionDetails := RoutesSelectionDetails{items: routeSelection}
return &routeSelectionDetails, nil return &routeSelectionDetails
} }
func (c *Client) SelectRoute(id string) error { func (c *Client) SelectRoute(id string) error {

View File

@ -16,9 +16,25 @@ type RoutesSelectionDetails struct {
type RoutesSelectionInfo struct { type RoutesSelectionInfo struct {
ID string ID string
Network string Network string
Domains *DomainDetails
Selected bool Selected bool
} }
type DomainCollection interface {
Add(s DomainInfo) DomainCollection
Get(i int) *DomainInfo
Size() int
}
type DomainDetails struct {
items []DomainInfo
}
type DomainInfo struct {
Domain string
ResolvedIPs string
}
// Add new PeerInfo to the collection // Add new PeerInfo to the collection
func (array RoutesSelectionDetails) Add(s RoutesSelectionInfo) RoutesSelectionDetails { func (array RoutesSelectionDetails) Add(s RoutesSelectionInfo) RoutesSelectionDetails {
array.items = append(array.items, s) array.items = append(array.items, s)
@ -34,3 +50,16 @@ func (array RoutesSelectionDetails) Get(i int) *RoutesSelectionInfo {
func (array RoutesSelectionDetails) Size() int { func (array RoutesSelectionDetails) Size() int {
return len(array.items) return len(array.items)
} }
func (array DomainDetails) Add(s DomainInfo) DomainCollection {
array.items = append(array.items, s)
return array
}
func (array DomainDetails) Get(i int) *DomainInfo {
return &array.items[i]
}
func (array DomainDetails) Size() int {
return len(array.items)
}

View File

@ -582,7 +582,7 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
} }
// Down engine work in the daemon. // Down engine work in the daemon.
func (s *Server) Down(_ context.Context, _ *proto.DownRequest) (*proto.DownResponse, error) { func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownResponse, error) {
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
@ -593,7 +593,25 @@ func (s *Server) Down(_ context.Context, _ *proto.DownRequest) (*proto.DownRespo
state := internal.CtxGetState(s.rootCtx) state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusIdle) state.Set(internal.StatusIdle)
return &proto.DownResponse{}, nil maxWaitTime := 5 * time.Second
timeout := time.After(maxWaitTime)
engine := s.connectClient.Engine()
for {
if !engine.IsWGIfaceUp() {
return &proto.DownResponse{}, nil
}
select {
case <-ctx.Done():
return &proto.DownResponse{}, nil
case <-timeout:
return nil, fmt.Errorf("failed to shut down properly")
default:
time.Sleep(100 * time.Millisecond)
}
}
} }
// Status returns the daemon status // Status returns the daemon status

View File

@ -8,6 +8,7 @@ import (
"context" "context"
"os" "os"
"os/exec" "os/exec"
"regexp"
"runtime" "runtime"
"strings" "strings"
"time" "time"
@ -89,9 +90,17 @@ func _getInfo() string {
func sysInfo() (serialNumber string, productName string, manufacturer string) { func sysInfo() (serialNumber string, productName string, manufacturer string) {
var si sysinfo.SysInfo var si sysinfo.SysInfo
si.GetSysInfo() si.GetSysInfo()
isascii := regexp.MustCompile("^[[:ascii:]]+$")
serial := si.Chassis.Serial serial := si.Chassis.Serial
if (serial == "Default string" || serial == "") && si.Product.Serial != "" { if (serial == "Default string" || serial == "") && si.Product.Serial != "" {
serial = si.Product.Serial serial = si.Product.Serial
} }
return serial, si.Product.Name, si.Product.Vendor if (!isascii.MatchString(serial)) && si.Board.Serial != "" {
serial = si.Board.Serial
}
name := si.Product.Name
if (!isascii.MatchString(name)) && si.Board.Name != "" {
name = si.Board.Name
}
return serial, name, si.Product.Vendor
} }

View File

@ -15,7 +15,6 @@ import (
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"syscall"
"time" "time"
"unicode" "unicode"
@ -34,6 +33,7 @@ import (
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )
@ -62,8 +62,25 @@ func main() {
var errorMSG string var errorMSG string
flag.StringVar(&errorMSG, "error-msg", "", "displays a error message window") flag.StringVar(&errorMSG, "error-msg", "", "displays a error message window")
tmpDir := "/tmp"
if runtime.GOOS == "windows" {
tmpDir = os.TempDir()
}
var saveLogsInFile bool
flag.BoolVar(&saveLogsInFile, "use-log-file", false, fmt.Sprintf("save logs in a file: %s/netbird-ui-PID.log", tmpDir))
flag.Parse() flag.Parse()
if saveLogsInFile {
logFile := path.Join(tmpDir, fmt.Sprintf("netbird-ui-%d.log", os.Getpid()))
err := util.InitLog("trace", logFile)
if err != nil {
log.Errorf("error while initializing log: %v", err)
return
}
}
a := app.NewWithID("NetBird") a := app.NewWithID("NetBird")
a.SetIcon(fyne.NewStaticResource("netbird", iconDisconnectedPNG)) a.SetIcon(fyne.NewStaticResource("netbird", iconDisconnectedPNG))
@ -76,8 +93,12 @@ func main() {
if showSettings || showRoutes { if showSettings || showRoutes {
a.Run() a.Run()
} else { } else {
if err := checkPIDFile(); err != nil { running, err := isAnotherProcessRunning()
log.Errorf("check PID file: %v", err) if err != nil {
log.Errorf("error while checking process: %v", err)
}
if running {
log.Warn("another process is running")
return return
} }
client.setDefaultFonts() client.setDefaultFonts()
@ -861,104 +882,3 @@ func openURL(url string) error {
} }
return err return err
} }
// checkPIDFile exists and return error, or write new.
func checkPIDFile() error {
pidFile := path.Join(os.TempDir(), "wiretrustee-ui.pid")
if piddata, err := os.ReadFile(pidFile); err == nil {
if pid, err := strconv.Atoi(string(piddata)); err == nil {
if process, err := os.FindProcess(pid); err == nil {
if err := process.Signal(syscall.Signal(0)); err == nil {
return fmt.Errorf("process already exists: %d", pid)
}
}
}
}
return os.WriteFile(pidFile, []byte(fmt.Sprintf("%d", os.Getpid())), 0o664) //nolint:gosec
}
func (s *serviceClient) setDefaultFonts() {
var (
defaultFontPath string
)
//TODO: Linux Multiple Language Support
switch runtime.GOOS {
case "darwin":
defaultFontPath = "/Library/Fonts/Arial Unicode.ttf"
case "windows":
fontPath := s.getWindowsFontFilePath()
defaultFontPath = fontPath
}
_, err := os.Stat(defaultFontPath)
if err == nil {
os.Setenv("FYNE_FONT", defaultFontPath)
}
}
func (s *serviceClient) getWindowsFontFilePath() (fontPath string) {
/*
https://learn.microsoft.com/en-us/windows/apps/design/globalizing/loc-international-fonts
https://learn.microsoft.com/en-us/typography/fonts/windows_11_font_list
*/
var (
fontFolder string = "C:/Windows/Fonts"
fontMapping = map[string]string{
"default": "Segoeui.ttf",
"zh-CN": "Msyh.ttc",
"am-ET": "Ebrima.ttf",
"nirmala": "Nirmala.ttf",
"chr-CHER-US": "Gadugi.ttf",
"zh-HK": "Msjh.ttc",
"zh-TW": "Msjh.ttc",
"ja-JP": "Yugothm.ttc",
"km-KH": "Leelawui.ttf",
"ko-KR": "Malgun.ttf",
"th-TH": "Leelawui.ttf",
"ti-ET": "Ebrima.ttf",
}
nirMalaLang = []string{
"as-IN",
"bn-BD",
"bn-IN",
"gu-IN",
"hi-IN",
"kn-IN",
"kok-IN",
"ml-IN",
"mr-IN",
"ne-NP",
"or-IN",
"pa-IN",
"si-LK",
"ta-IN",
"te-IN",
}
)
cmd := exec.Command("powershell", "-Command", "(Get-Culture).Name")
output, err := cmd.Output()
if err != nil {
log.Errorf("Failed to get Windows default language setting: %v", err)
fontPath = path.Join(fontFolder, fontMapping["default"])
return
}
defaultLanguage := strings.TrimSpace(string(output))
for _, lang := range nirMalaLang {
if defaultLanguage == lang {
fontPath = path.Join(fontFolder, fontMapping["nirmala"])
return
}
}
if font, ok := fontMapping[defaultLanguage]; ok {
fontPath = path.Join(fontFolder, font)
} else {
fontPath = path.Join(fontFolder, fontMapping["default"])
}
return
}

26
client/ui/font_bsd.go Normal file
View File

@ -0,0 +1,26 @@
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
package main
import (
"os"
"runtime"
log "github.com/sirupsen/logrus"
)
const defaultFontPath = "/Library/Fonts/Arial Unicode.ttf"
func (s *serviceClient) setDefaultFonts() {
// TODO: add other bsd paths
if runtime.GOOS != "darwin" {
return
}
if _, err := os.Stat(defaultFontPath); err != nil {
log.Errorf("Failed to find default font file: %v", err)
return
}
os.Setenv("FYNE_FONT", defaultFontPath)
}

7
client/ui/font_linux.go Normal file
View File

@ -0,0 +1,7 @@
//go:build !386
package main
func (s *serviceClient) setDefaultFonts() {
//TODO: Linux Multiple Language Support
}

91
client/ui/font_windows.go Normal file
View File

@ -0,0 +1,91 @@
package main
import (
"os"
"path"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
)
func (s *serviceClient) setDefaultFonts() {
defaultFontPath := s.getWindowsFontFilePath()
if _, err := os.Stat(defaultFontPath); err != nil {
log.Errorf("Failed to find default font file: %v", err)
return
}
os.Setenv("FYNE_FONT", defaultFontPath)
}
func (s *serviceClient) getWindowsFontFilePath() string {
var (
fontFolder = "C:/Windows/Fonts"
fontMapping = map[string]string{
"default": "Segoeui.ttf",
"zh-CN": "Msyh.ttc",
"am-ET": "Ebrima.ttf",
"nirmala": "Nirmala.ttf",
"chr-CHER-US": "Gadugi.ttf",
"zh-HK": "Msjh.ttc",
"zh-TW": "Msjh.ttc",
"ja-JP": "Yugothm.ttc",
"km-KH": "Leelawui.ttf",
"ko-KR": "Malgun.ttf",
"th-TH": "Leelawui.ttf",
"ti-ET": "Ebrima.ttf",
}
nirMalaLang = []string{
"as-IN",
"bn-BD",
"bn-IN",
"gu-IN",
"hi-IN",
"kn-IN",
"kok-IN",
"ml-IN",
"mr-IN",
"ne-NP",
"or-IN",
"pa-IN",
"si-LK",
"ta-IN",
"te-IN",
}
)
// getUserDefaultLocaleName.Call() panics if the func is not found
defer func() {
if r := recover(); r != nil {
log.Errorf("Recovered from panic: %v", r)
}
}()
kernel32 := windows.NewLazySystemDLL("kernel32.dll")
getUserDefaultLocaleName := kernel32.NewProc("GetUserDefaultLocaleName")
buf := make([]uint16, 85) // LOCALE_NAME_MAX_LENGTH is usually 85
r, _, err := getUserDefaultLocaleName.Call(uintptr(unsafe.Pointer(&buf[0])), uintptr(len(buf)))
// returns 0 on failure, err is always non-nil
// https://learn.microsoft.com/en-us/windows/win32/api/winnls/nf-winnls-getuserdefaultlocalename
if r == 0 {
log.Errorf("GetUserDefaultLocaleName call failed: %v", err)
return path.Join(fontFolder, fontMapping["default"])
}
defaultLanguage := windows.UTF16ToString(buf)
for _, lang := range nirMalaLang {
if defaultLanguage == lang {
return path.Join(fontFolder, fontMapping["nirmala"])
}
}
if font, ok := fontMapping[defaultLanguage]; ok {
return path.Join(fontFolder, font)
}
return path.Join(fontFolder, fontMapping["default"])
}

37
client/ui/process.go Normal file
View File

@ -0,0 +1,37 @@
package main
import (
"os"
"path/filepath"
"strings"
"github.com/shirou/gopsutil/v3/process"
)
func isAnotherProcessRunning() (bool, error) {
processes, err := process.Processes()
if err != nil {
return false, err
}
pid := os.Getpid()
processName := strings.ToLower(filepath.Base(os.Args[0]))
for _, p := range processes {
if int(p.Pid) == pid {
continue
}
runningProcessPath, err := p.Exe()
// most errors are related to short-lived processes
if err != nil {
continue
}
if strings.Contains(strings.ToLower(runningProcessPath), processName) && isProcessOwnedByCurrentUser(p) {
return true, nil
}
}
return false, nil
}

View File

@ -0,0 +1,26 @@
//go:build !windows
package main
import (
"os"
"github.com/shirou/gopsutil/v3/process"
log "github.com/sirupsen/logrus"
)
func isProcessOwnedByCurrentUser(p *process.Process) bool {
currentUserID := os.Getuid()
uids, err := p.Uids()
if err != nil {
log.Errorf("get process uids: %v", err)
return false
}
for _, id := range uids {
log.Debugf("checking process uid: %d", id)
if int(id) == currentUserID {
return true
}
}
return false
}

View File

@ -0,0 +1,24 @@
package main
import (
"os/user"
"github.com/shirou/gopsutil/v3/process"
log "github.com/sirupsen/logrus"
)
func isProcessOwnedByCurrentUser(p *process.Process) bool {
processUsername, err := p.Username()
if err != nil {
log.Errorf("get process username error: %v", err)
return false
}
currUser, err := user.Current()
if err != nil {
log.Errorf("get current user error: %v", err)
return false
}
return processUsername == currUser.Username
}

10
go.mod
View File

@ -19,12 +19,12 @@ require (
github.com/spf13/cobra v1.7.0 github.com/spf13/cobra v1.7.0
github.com/spf13/pflag v1.0.5 github.com/spf13/pflag v1.0.5
github.com/vishvananda/netlink v1.2.1-beta.2 github.com/vishvananda/netlink v1.2.1-beta.2
golang.org/x/crypto v0.23.0 golang.org/x/crypto v0.24.0
golang.org/x/sys v0.20.0 golang.org/x/sys v0.21.0
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
golang.zx2c4.com/wireguard/windows v0.5.3 golang.zx2c4.com/wireguard/windows v0.5.3
google.golang.org/grpc v1.64.0 google.golang.org/grpc v1.64.1
google.golang.org/protobuf v1.34.1 google.golang.org/protobuf v1.34.1
gopkg.in/natefinch/lumberjack.v2 v2.0.0 gopkg.in/natefinch/lumberjack.v2 v2.0.0
) )
@ -85,10 +85,10 @@ require (
goauthentik.io/api/v3 v3.2023051.3 goauthentik.io/api/v3 v3.2023051.3
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028 golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028
golang.org/x/net v0.25.0 golang.org/x/net v0.26.0
golang.org/x/oauth2 v0.19.0 golang.org/x/oauth2 v0.19.0
golang.org/x/sync v0.7.0 golang.org/x/sync v0.7.0
golang.org/x/term v0.20.0 golang.org/x/term v0.21.0
google.golang.org/api v0.177.0 google.golang.org/api v0.177.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/postgres v1.5.7 gorm.io/driver/postgres v1.5.7

20
go.sum
View File

@ -545,8 +545,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y
golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE=
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg=
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM= golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM=
@ -592,8 +592,8 @@ golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE= golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE=
golang.org/x/oauth2 v0.19.0 h1:9+E/EZBCbTLNrbN35fHv/a/d/mOBatymz1zbtQrXpIg= golang.org/x/oauth2 v0.19.0 h1:9+E/EZBCbTLNrbN35fHv/a/d/mOBatymz1zbtQrXpIg=
@ -650,8 +650,8 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
@ -659,8 +659,8 @@ golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU=
golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
golang.org/x/term v0.20.0 h1:VnkxpohqXaOBYJtBmEppKUG6mXpi+4O6purfc2+sMhw= golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA=
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0=
golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
@ -719,8 +719,8 @@ google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyac
google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY=
google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk=
google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc=
google.golang.org/grpc v1.64.0 h1:KH3VH9y/MgNQg1dE7b3XfVK0GsPSIzJwdF617gUSbvY= google.golang.org/grpc v1.64.1 h1:LKtvyfbX3UGVPFcGqJ9ItpVWW6oN/2XqTxfAnwRRXiA=
google.golang.org/grpc v1.64.0/go.mod h1:oxjF8E3FBnjp+/gVFYdWacaLDx9na1aqy9oovLpxQYg= google.golang.org/grpc v1.64.1/go.mod h1:hiQF4LFZelK2WKaP6W0L92zGHtiQdZxk8CrSdvyjeP0=
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=

View File

@ -64,7 +64,7 @@ func (t *wgTunDevice) Create(routes []string, dns string, searchDomains []string
t.wrapper = newDeviceWrapper(tunDevice) t.wrapper = newDeviceWrapper(tunDevice)
log.Debugf("attaching to interface %v", name) log.Debugf("attaching to interface %v", name)
t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(device.LogLevelSilent, "[wiretrustee] ")) t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] "))
// without this property mobile devices can discover remote endpoints if the configured one was wrong. // without this property mobile devices can discover remote endpoints if the configured one was wrong.
// this helps with support for the older NetBird clients that had a hardcoded direct mode // this helps with support for the older NetBird clients that had a hardcoded direct mode
// t.device.DisableSomeRoamingForBrokenMobileSemantics() // t.device.DisableSomeRoamingForBrokenMobileSemantics()

View File

@ -49,7 +49,7 @@ func (t *tunDevice) Create() (wgConfigurer, error) {
t.device = device.NewDevice( t.device = device.NewDevice(
t.wrapper, t.wrapper,
t.iceBind, t.iceBind,
device.NewLogger(device.LogLevelSilent, "[netbird] "), device.NewLogger(wgLogLevel(), "[netbird] "),
) )
err = t.assignAddr() err = t.assignAddr()

View File

@ -64,7 +64,7 @@ func (t *tunDevice) Create() (wgConfigurer, error) {
t.wrapper = newDeviceWrapper(tunDevice) t.wrapper = newDeviceWrapper(tunDevice)
log.Debug("Attaching to interface") log.Debug("Attaching to interface")
t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(device.LogLevelSilent, "[wiretrustee] ")) t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] "))
// without this property mobile devices can discover remote endpoints if the configured one was wrong. // without this property mobile devices can discover remote endpoints if the configured one was wrong.
// this helps with support for the older NetBird clients that had a hardcoded direct mode // this helps with support for the older NetBird clients that had a hardcoded direct mode
// t.device.DisableSomeRoamingForBrokenMobileSemantics() // t.device.DisableSomeRoamingForBrokenMobileSemantics()

View File

@ -54,7 +54,7 @@ func (t *tunNetstackDevice) Create() (wgConfigurer, error) {
t.device = device.NewDevice( t.device = device.NewDevice(
t.wrapper, t.wrapper,
t.iceBind, t.iceBind,
device.NewLogger(device.LogLevelSilent, "[netbird] "), device.NewLogger(wgLogLevel(), "[netbird] "),
) )
t.configurer = newWGUSPConfigurer(t.device, t.name) t.configurer = newWGUSPConfigurer(t.device, t.name)

View File

@ -57,7 +57,7 @@ func (t *tunUSPDevice) Create() (wgConfigurer, error) {
t.device = device.NewDevice( t.device = device.NewDevice(
t.wrapper, t.wrapper,
t.iceBind, t.iceBind,
device.NewLogger(device.LogLevelSilent, "[netbird] "), device.NewLogger(wgLogLevel(), "[netbird] "),
) )
err = t.assignAddr() err = t.assignAddr()

View File

@ -41,6 +41,7 @@ func newTunDevice(name string, address WGAddress, port int, key string, mtu int,
} }
func (t *tunDevice) Create() (wgConfigurer, error) { func (t *tunDevice) Create() (wgConfigurer, error) {
log.Info("create tun interface")
tunDevice, err := tun.CreateTUN(t.name, t.mtu) tunDevice, err := tun.CreateTUN(t.name, t.mtu)
if err != nil { if err != nil {
return nil, err return nil, err
@ -52,7 +53,7 @@ func (t *tunDevice) Create() (wgConfigurer, error) {
t.device = device.NewDevice( t.device = device.NewDevice(
t.wrapper, t.wrapper,
t.iceBind, t.iceBind,
device.NewLogger(device.LogLevelSilent, "[netbird] "), device.NewLogger(wgLogLevel(), "[netbird] "),
) )
luid := winipcfg.LUID(t.nativeTunDevice.LUID()) luid := winipcfg.LUID(t.nativeTunDevice.LUID())

15
iface/wg_log.go Normal file
View File

@ -0,0 +1,15 @@
package iface
import (
"os"
"golang.zx2c4.com/wireguard/device"
)
func wgLogLevel() int {
if os.Getenv("NB_WG_DEBUG") == "true" {
return device.LogLevelVerbose
} else {
return device.LogLevelSilent
}
}

View File

@ -2,7 +2,6 @@ package client
import ( import (
"context" "context"
"crypto/tls"
"fmt" "fmt"
"io" "io"
"sync" "sync"
@ -11,15 +10,11 @@ import (
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status" gstatus "google.golang.org/grpc/status"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/connectivity" "google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
"github.com/cenkalti/backoff/v4"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
@ -51,26 +46,21 @@ type GrpcClient struct {
// NewClient creates a new client to Management service // NewClient creates a new client to Management service
func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) { func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) {
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) var conn *grpc.ClientConn
if tlsEnabled { operation := func() error {
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{})) var err error
conn, err = nbgrpc.CreateConnection(addr, tlsEnabled)
if err != nil {
log.Printf("createConnection error: %v", err)
return err
}
return nil
} }
mgmCtx, cancel := context.WithTimeout(ctx, ConnectTimeout) err := backoff.Retry(operation, nbgrpc.Backoff(ctx))
defer cancel()
conn, err := grpc.DialContext(
mgmCtx,
addr,
transportOption,
nbgrpc.WithCustomDialer(),
grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second,
Timeout: 10 * time.Second,
}))
if err != nil { if err != nil {
log.Errorf("failed creating connection to Management Service %v", err) log.Errorf("failed creating connection to Management Service: %v", err)
return nil, err return nil, err
} }
@ -326,25 +316,44 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro
if !c.ready() { if !c.ready() {
return nil, fmt.Errorf(errMsgNoMgmtConnection) return nil, fmt.Errorf(errMsgNoMgmtConnection)
} }
loginReq, err := encryption.EncryptMessage(serverKey, c.key, req) loginReq, err := encryption.EncryptMessage(serverKey, c.key, req)
if err != nil { if err != nil {
log.Errorf("failed to encrypt message: %s", err) log.Errorf("failed to encrypt message: %s", err)
return nil, err return nil, err
} }
mgmCtx, cancel := context.WithTimeout(c.ctx, ConnectTimeout)
defer cancel() var resp *proto.EncryptedMessage
resp, err := c.realClient.Login(mgmCtx, &proto.EncryptedMessage{ operation := func() error {
WgPubKey: c.key.PublicKey().String(), mgmCtx, cancel := context.WithTimeout(context.Background(), ConnectTimeout)
Body: loginReq, defer cancel()
})
var err error
resp, err = c.realClient.Login(mgmCtx, &proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: loginReq,
})
if err != nil {
// retry only on context canceled
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.Canceled {
return err
}
return backoff.Permanent(err)
}
return nil
}
err = backoff.Retry(operation, nbgrpc.Backoff(c.ctx))
if err != nil { if err != nil {
log.Errorf("failed to login to Management Service: %v", err)
return nil, err return nil, err
} }
loginResp := &proto.LoginResponse{} loginResp := &proto.LoginResponse{}
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, loginResp) err = encryption.DecryptMessage(serverKey, c.key, resp.Body, loginResp)
if err != nil { if err != nil {
log.Errorf("failed to decrypt registration message: %s", err) log.Errorf("failed to decrypt login response: %s", err)
return nil, err return nil, err
} }

View File

@ -69,6 +69,7 @@ type AccountManager interface {
ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error)
SaveUser(ctx context.Context, accountID, initiatorUserID string, update *User) (*UserInfo, error) SaveUser(ctx context.Context, accountID, initiatorUserID string, update *User) (*UserInfo, error)
SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error)
SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error)
GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error)
GetAccountByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (*Account, error) GetAccountByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (*Account, error)
GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error) GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
@ -95,6 +96,7 @@ type AccountManager interface {
GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error)
GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error)
SaveGroup(ctx context.Context, accountID, userID string, group *nbgroup.Group) error SaveGroup(ctx context.Context, accountID, userID string, group *nbgroup.Group) error
SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error
DeleteGroup(ctx context.Context, accountId, userId, groupID string) error DeleteGroup(ctx context.Context, accountId, userId, groupID string) error
ListGroups(ctx context.Context, accountId string) ([]*nbgroup.Group, error) ListGroups(ctx context.Context, accountId string) ([]*nbgroup.Group, error)
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error

View File

@ -746,3 +746,11 @@ func (s *FileStore) Close(ctx context.Context) error {
func (s *FileStore) GetStoreEngine() StoreEngine { func (s *FileStore) GetStoreEngine() StoreEngine {
return FileStoreEngine return FileStoreEngine
} }
func (s *FileStore) SaveUsers(accountID string, users map[string]*User) error {
return status.Errorf(status.Internal, "SaveUsers is not implemented")
}
func (s *FileStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error {
return status.Errorf(status.Internal, "SaveGroups is not implemented")
}

View File

@ -112,61 +112,85 @@ func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName,
func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *nbgroup.Group) error { func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *nbgroup.Group) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock() defer unlock()
return am.SaveGroups(ctx, accountID, userID, []*nbgroup.Group{newGroup})
}
// SaveGroups adds new groups to the account.
// Note: This function does not acquire the global lock.
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error {
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return err return err
} }
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI { var eventsToStore []func()
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
}
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI { for _, newGroup := range newGroups {
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
}
existingGroup, err := account.FindGroupByName(newGroup.Name) if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
if err != nil { existingGroup, err := account.FindGroupByName(newGroup.Name)
s, ok := status.FromError(err) if err != nil {
if !ok || s.ErrorType != status.NotFound { s, ok := status.FromError(err)
return err if !ok || s.ErrorType != status.NotFound {
return err
}
}
// Avoid duplicate groups only for the API issued groups.
// Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of.
if existingGroup != nil {
return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
}
newGroup.ID = xid.New().String()
}
for _, peerID := range newGroup.Peers {
if account.Peers[peerID] == nil {
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
} }
} }
// avoid duplicate groups only for the API issued groups. Integration or JWT groups can be duplicated as they are oldGroup := account.Groups[newGroup.ID]
// coming from the IdP that we don't have control of. account.Groups[newGroup.ID] = newGroup
if existingGroup != nil {
return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
}
newGroup.ID = xid.New().String() events := am.prepareGroupEvents(ctx, userID, accountID, newGroup, oldGroup, account)
eventsToStore = append(eventsToStore, events...)
} }
for _, peerID := range newGroup.Peers {
if account.Peers[peerID] == nil {
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
}
}
oldGroup, exists := account.Groups[newGroup.ID]
account.Groups[newGroup.ID] = newGroup
account.Network.IncSerial() account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil { if err = am.Store.SaveGroups(account.Id, account.Groups); err != nil {
return err return err
} }
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account)
// the following snippet tracks the activity and stores the group events in the event store. for _, storeEvent := range eventsToStore {
// It has to happen after all the operations have been successfully performed. storeEvent()
}
return nil
}
// prepareGroupEvents prepares a list of event functions to be stored.
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup, oldGroup *nbgroup.Group, account *Account) []func() {
var eventsToStore []func()
addedPeers := make([]string, 0) addedPeers := make([]string, 0)
removedPeers := make([]string, 0) removedPeers := make([]string, 0)
if exists {
if oldGroup != nil {
addedPeers = difference(newGroup.Peers, oldGroup.Peers) addedPeers = difference(newGroup.Peers, oldGroup.Peers)
removedPeers = difference(oldGroup.Peers, newGroup.Peers) removedPeers = difference(oldGroup.Peers, newGroup.Peers)
} else { } else {
addedPeers = append(addedPeers, newGroup.Peers...) addedPeers = append(addedPeers, newGroup.Peers...)
am.StoreEvent(ctx, userID, newGroup.ID, accountID, activity.GroupCreated, newGroup.EventMeta()) eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, userID, newGroup.ID, accountID, activity.GroupCreated, newGroup.EventMeta())
})
} }
for _, p := range addedPeers { for _, p := range addedPeers {
@ -175,11 +199,14 @@ func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userI
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID) log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
continue continue
} }
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupAddedToPeer, peerCopy := peer // copy to avoid closure issues
map[string]any{ eventsToStore = append(eventsToStore, func() {
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peer.IP.String(), am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupAddedToPeer,
"peer_fqdn": peer.FQDN(am.GetDNSDomain()), map[string]any{
}) "group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(),
"peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()),
})
})
} }
for _, p := range removedPeers { for _, p := range removedPeers {
@ -188,14 +215,17 @@ func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userI
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID) log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
continue continue
} }
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupRemovedFromPeer, peerCopy := peer // copy to avoid closure issues
map[string]any{ eventsToStore = append(eventsToStore, func() {
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peer.IP.String(), am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupRemovedFromPeer,
"peer_fqdn": peer.FQDN(am.GetDNSDomain()), map[string]any{
}) "group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(),
"peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()),
})
})
} }
return nil return eventsToStore
} }
// difference returns the elements in `a` that aren't in `b`. // difference returns the elements in `a` that aren't in `b`.

View File

@ -40,6 +40,7 @@ type MockAccountManager struct {
GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*group.Group, error) GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*group.Group, error)
GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*group.Group, error) GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*group.Group, error)
SaveGroupFunc func(ctx context.Context, accountID, userID string, group *group.Group) error SaveGroupFunc func(ctx context.Context, accountID, userID string, group *group.Group) error
SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error
DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error
ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error) ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error)
GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
@ -64,6 +65,7 @@ type MockAccountManager struct {
ListSetupKeysFunc func(ctx context.Context, accountID, userID string) ([]*server.SetupKey, error) ListSetupKeysFunc func(ctx context.Context, accountID, userID string) ([]*server.SetupKey, error)
SaveUserFunc func(ctx context.Context, accountID, userID string, user *server.User) (*server.UserInfo, error) SaveUserFunc func(ctx context.Context, accountID, userID string, user *server.User) (*server.UserInfo, error)
SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error) SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error)
SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*server.User, addIfNotExists bool) ([]*server.UserInfo, error)
DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error
@ -308,6 +310,14 @@ func (am *MockAccountManager) SaveGroup(ctx context.Context, accountID, userID s
return status.Errorf(codes.Unimplemented, "method SaveGroup is not implemented") return status.Errorf(codes.Unimplemented, "method SaveGroup is not implemented")
} }
// SaveGroups mock implementation of SaveGroups from server.AccountManager interface
func (am *MockAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*group.Group) error {
if am.SaveGroupsFunc != nil {
return am.SaveGroupsFunc(ctx, accountID, userID, groups)
}
return status.Errorf(codes.Unimplemented, "method SaveGroups is not implemented")
}
// DeleteGroup mock implementation of DeleteGroup from server.AccountManager interface // DeleteGroup mock implementation of DeleteGroup from server.AccountManager interface
func (am *MockAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error { func (am *MockAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error {
if am.DeleteGroupFunc != nil { if am.DeleteGroupFunc != nil {
@ -502,6 +512,14 @@ func (am *MockAccountManager) SaveOrAddUser(ctx context.Context, accountID, user
return nil, status.Errorf(codes.Unimplemented, "method SaveOrAddUser is not implemented") return nil, status.Errorf(codes.Unimplemented, "method SaveOrAddUser is not implemented")
} }
// SaveOrAddUsers mocks SaveOrAddUsers of the AccountManager interface
func (am *MockAccountManager) SaveOrAddUsers(ctx context.Context, accountID, userID string, users []*server.User, addIfNotExists bool) ([]*server.UserInfo, error) {
if am.SaveOrAddUsersFunc != nil {
return am.SaveOrAddUsersFunc(ctx, accountID, userID, users, addIfNotExists)
}
return nil, status.Errorf(codes.Unimplemented, "method SaveOrAddUsers is not implemented")
}
// DeleteUser mocks DeleteUser of the AccountManager interface // DeleteUser mocks DeleteUser of the AccountManager interface
func (am *MockAccountManager) DeleteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error { func (am *MockAccountManager) DeleteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error {
if am.DeleteUserFunc != nil { if am.DeleteUserFunc != nil {

View File

@ -274,10 +274,15 @@ func (s *SqlStore) GetInstallationID() string {
func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error { func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
var peerCopy nbpeer.Peer var peerCopy nbpeer.Peer
peerCopy.Status = &peerStatus peerCopy.Status = &peerStatus
result := s.db.Model(&nbpeer.Peer{}).
Where("account_id = ? AND id = ?", accountID, peerID).
Updates(peerCopy)
fieldsToUpdate := []string{
"peer_status_last_seen", "peer_status_connected",
"peer_status_login_expired", "peer_status_required_approval",
}
result := s.db.Model(&nbpeer.Peer{}).
Select(fieldsToUpdate).
Where("account_id = ? AND id = ?", accountID, peerID).
Updates(&peerCopy)
if result.Error != nil { if result.Error != nil {
return result.Error return result.Error
} }
@ -311,6 +316,34 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P
return nil return nil
} }
// SaveUsers saves the given list of users to the database.
// It updates existing users if a conflict occurs.
func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error {
usersToSave := make([]User, 0, len(users))
for _, user := range users {
user.AccountID = accountID
for id, pat := range user.PATs {
pat.ID = id
user.PATsG = append(user.PATsG, *pat)
}
usersToSave = append(usersToSave, *user)
}
return s.db.Session(&gorm.Session{FullSaveAssociations: true}).
Clauses(clause.OnConflict{UpdateAll: true}).
Create(&usersToSave).Error
}
// SaveGroups saves the given list of groups to the database.
// It updates existing groups if a conflict occurs.
func (s *SqlStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error {
groupsToSave := make([]nbgroup.Group, 0, len(groups))
for _, group := range groups {
group.AccountID = accountID
groupsToSave = append(groupsToSave, *group)
}
return s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&groupsToSave).Error
}
// DeleteHashedPAT2TokenIDIndex is noop in SqlStore // DeleteHashedPAT2TokenIDIndex is noop in SqlStore
func (s *SqlStore) DeleteHashedPAT2TokenIDIndex(hashedToken string) error { func (s *SqlStore) DeleteHashedPAT2TokenIDIndex(hashedToken string) error {
return nil return nil
@ -662,11 +695,7 @@ func NewSqliteStore(ctx context.Context, dataDir string, metrics telemetry.AppMe
} }
file := filepath.Join(dataDir, storeStr) file := filepath.Join(dataDir, storeStr)
db, err := gorm.Open(sqlite.Open(file), &gorm.Config{ db, err := gorm.Open(sqlite.Open(file), getGormConfig())
Logger: logger.Default.LogMode(logger.Silent),
CreateBatchSize: 400,
PrepareStmt: true,
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -676,10 +705,7 @@ func NewSqliteStore(ctx context.Context, dataDir string, metrics telemetry.AppMe
// NewPostgresqlStore creates a new Postgres store. // NewPostgresqlStore creates a new Postgres store.
func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) { func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) {
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{ db, err := gorm.Open(postgres.Open(dsn), getGormConfig())
Logger: logger.Default.LogMode(logger.Silent),
PrepareStmt: true,
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -687,6 +713,14 @@ func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMe
return NewSqlStore(ctx, db, PostgresStoreEngine, metrics) return NewSqlStore(ctx, db, PostgresStoreEngine, metrics)
} }
func getGormConfig() *gorm.Config {
return &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
CreateBatchSize: 400,
PrepareStmt: true,
}
}
// newPostgresStore initializes a new Postgres store. // newPostgresStore initializes a new Postgres store.
func newPostgresStore(ctx context.Context, metrics telemetry.AppMetrics) (Store, error) { func newPostgresStore(ctx context.Context, metrics telemetry.AppMetrics) (Store, error) {
dsn, ok := os.LookupEnv(postgresDsnEnv) dsn, ok := os.LookupEnv(postgresDsnEnv)

View File

@ -41,11 +41,22 @@ func TestSqlite_NewStore(t *testing.T) {
} }
func TestSqlite_SaveAccount_Large(t *testing.T) { func TestSqlite_SaveAccount_Large(t *testing.T) {
if runtime.GOOS == "windows" { if runtime.GOOS != "linux" && os.Getenv("CI") == "true" || runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet") t.Skip("skip large test on non-linux OS due to environment restrictions")
} }
t.Run("SQLite", func(t *testing.T) {
store := newSqliteStore(t)
runLargeTest(t, store)
})
// create store outside to have a better time counter for the test
store := newPostgresqlStore(t)
t.Run("PostgreSQL", func(t *testing.T) {
runLargeTest(t, store)
})
}
store := newSqliteStore(t) func runLargeTest(t *testing.T, store Store) {
t.Helper()
account := newAccountWithId(context.Background(), "account_id", "testuser", "") account := newAccountWithId(context.Background(), "account_id", "testuser", "")
groupALL, err := account.GetGroupAll() groupALL, err := account.GetGroupAll()
@ -54,7 +65,7 @@ func TestSqlite_SaveAccount_Large(t *testing.T) {
} }
setupKey := GenerateDefaultSetupKey() setupKey := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey account.SetupKeys[setupKey.Key] = setupKey
const numPerAccount = 2000 const numPerAccount = 6000
for n := 0; n < numPerAccount; n++ { for n := 0; n < numPerAccount; n++ {
netIP := randomIPv4() netIP := randomIPv4()
peerID := fmt.Sprintf("%s-peer-%d", account.Id, n) peerID := fmt.Sprintf("%s-peer-%d", account.Id, n)
@ -362,7 +373,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// save status of non-existing peer // save status of non-existing peer
newStatus := nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()} newStatus := nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()}
err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus) err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus)
assert.Error(t, err) assert.Error(t, err)
parsedErr, ok := status.FromError(err) parsedErr, ok := status.FromError(err)
@ -377,7 +388,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) {
IP: net.IP{127, 0, 0, 1}, IP: net.IP{127, 0, 0, 1},
Meta: nbpeer.PeerSystemMeta{}, Meta: nbpeer.PeerSystemMeta{},
Name: "peer name", Name: "peer name",
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()}, Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
} }
err = store.SaveAccount(context.Background(), account) err = store.SaveAccount(context.Background(), account)

View File

@ -12,6 +12,7 @@ import (
"strings" "strings"
"time" "time"
nbgroup "github.com/netbirdio/netbird/management/server/group"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"gorm.io/gorm" "gorm.io/gorm"
@ -41,6 +42,8 @@ type Store interface {
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
SaveAccount(ctx context.Context, account *Account) error SaveAccount(ctx context.Context, account *Account) error
SaveUsers(accountID string, users map[string]*User) error
SaveGroups(accountID string, groups map[string]*nbgroup.Group) error
DeleteHashedPAT2TokenIDIndex(hashedToken string) error DeleteHashedPAT2TokenIDIndex(hashedToken string) error
DeleteTokenID2UserIDIndex(tokenID string) error DeleteTokenID2UserIDIndex(tokenID string) error
GetInstallationID() string GetInstallationID() string

View File

@ -740,7 +740,7 @@ func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID strin
return pats, nil return pats, nil
} }
// SaveUser saves updates to the given user. If the user doesn't exit it will throw status.NotFound error. // SaveUser saves updates to the given user. If the user doesn't exist, it will throw status.NotFound error.
func (am *DefaultAccountManager) SaveUser(ctx context.Context, accountID, initiatorUserID string, update *User) (*UserInfo, error) { func (am *DefaultAccountManager) SaveUser(ctx context.Context, accountID, initiatorUserID string, update *User) (*UserInfo, error) {
return am.SaveOrAddUser(ctx, accountID, initiatorUserID, update, false) // false means do not create user and throw status.NotFound return am.SaveOrAddUser(ctx, accountID, initiatorUserID, update, false) // false means do not create user and throw status.NotFound
} }
@ -748,11 +748,31 @@ func (am *DefaultAccountManager) SaveUser(ctx context.Context, accountID, initia
// SaveOrAddUser updates the given user. If addIfNotExists is set to true it will add user when no exist // SaveOrAddUser updates the given user. If addIfNotExists is set to true it will add user when no exist
// Only User.AutoGroups, User.Role, and User.Blocked fields are allowed to be updated for now. // Only User.AutoGroups, User.Role, and User.Blocked fields are allowed to be updated for now.
func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) { func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) {
if update == nil {
return nil, status.Errorf(status.InvalidArgument, "provided user update is nil")
}
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock() defer unlock()
if update == nil { updatedUsers, err := am.SaveOrAddUsers(ctx, accountID, initiatorUserID, []*User{update}, addIfNotExists)
return nil, status.Errorf(status.InvalidArgument, "provided user update is nil") if err != nil {
return nil, err
}
if len(updatedUsers) == 0 {
return nil, status.Errorf(status.Internal, "user was not updated")
}
return updatedUsers[0], nil
}
// SaveOrAddUsers updates existing users or adds new users to the account.
// Note: This function does not acquire the global lock.
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error) {
if len(updates) == 0 {
return nil, nil //nolint:nilnil
} }
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
@ -769,144 +789,200 @@ func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, i
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are authorized to perform user update operations") return nil, status.Errorf(status.PermissionDenied, "only users with admin power are authorized to perform user update operations")
} }
oldUser := account.Users[update.Id] updatedUsers := make([]*UserInfo, 0, len(updates))
if oldUser == nil { var (
if !addIfNotExists { expiredPeers []*nbpeer.Peer
return nil, status.Errorf(status.NotFound, "user to update doesn't exist") eventsToStore []func()
)
for _, update := range updates {
if update == nil {
return nil, status.Errorf(status.InvalidArgument, "provided user update is nil")
} }
// when addIfNotExists is set to true the newUser will use all fields from the update input
oldUser = update
}
if initiatorUser.HasAdminPower() && initiatorUserID == update.Id && oldUser.Blocked != update.Blocked { oldUser := account.Users[update.Id]
return nil, status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves") if oldUser == nil {
} if !addIfNotExists {
return nil, status.Errorf(status.NotFound, "user to update doesn't exist: %s", update.Id)
if initiatorUser.HasAdminPower() && initiatorUserID == update.Id && update.Role != initiatorUser.Role { }
return nil, status.Errorf(status.PermissionDenied, "admins can't change their role") // when addIfNotExists is set to true, the newUser will use all fields from the update input
} oldUser = update
if initiatorUser.Role == UserRoleAdmin && oldUser.Role == UserRoleOwner && update.Role != oldUser.Role {
return nil, status.Errorf(status.PermissionDenied, "only owners can remove owner role from their user")
}
if initiatorUser.Role == UserRoleAdmin && oldUser.Role == UserRoleOwner && update.IsBlocked() && !oldUser.IsBlocked() {
return nil, status.Errorf(status.PermissionDenied, "unable to block owner user")
}
if initiatorUser.Role == UserRoleAdmin && update.Role == UserRoleOwner && update.Role != oldUser.Role {
return nil, status.Errorf(status.PermissionDenied, "only owners can add owner role to other users")
}
if oldUser.IsServiceUser && update.Role == UserRoleOwner {
return nil, status.Errorf(status.PermissionDenied, "can't update a service user with owner role")
}
transferedOwnerRole := false
if initiatorUser.Role == UserRoleOwner && initiatorUserID != update.Id && update.Role == UserRoleOwner {
newInitiatorUser := initiatorUser.Copy()
newInitiatorUser.Role = UserRoleAdmin
account.Users[initiatorUserID] = newInitiatorUser
transferedOwnerRole = true
}
// only auto groups, revoked status, and integration reference can be updated for now
newUser := oldUser.Copy()
newUser.Role = update.Role
newUser.Blocked = update.Blocked
// these two fields can't be set via API, only via direct call to the method
newUser.Issued = update.Issued
newUser.IntegrationReference = update.IntegrationReference
for _, newGroupID := range update.AutoGroups {
if _, ok := account.Groups[newGroupID]; !ok {
return nil, status.Errorf(status.InvalidArgument, "provided group ID %s in the user %s update doesn't exist",
newGroupID, update.Id)
} }
}
newUser.AutoGroups = update.AutoGroups
account.Users[newUser.Id] = newUser if err := validateUserUpdate(account, initiatorUser, oldUser, update); err != nil {
if !oldUser.IsBlocked() && update.IsBlocked() {
// expire peers that belong to the user who's getting blocked
blockedPeers, err := account.FindUserPeers(update.Id)
if err != nil {
return nil, err return nil, err
} }
if err := am.expireAndUpdatePeers(ctx, account, blockedPeers); err != nil { // only auto groups, revoked status, and integration reference can be updated for now
newUser := oldUser.Copy()
newUser.Role = update.Role
newUser.Blocked = update.Blocked
newUser.AutoGroups = update.AutoGroups
// these two fields can't be set via API, only via direct call to the method
newUser.Issued = update.Issued
newUser.IntegrationReference = update.IntegrationReference
transferredOwnerRole := handleOwnerRoleTransfer(account, initiatorUser, update)
account.Users[newUser.Id] = newUser
if !oldUser.IsBlocked() && update.IsBlocked() {
// expire peers that belong to the user who's getting blocked
blockedPeers, err := account.FindUserPeers(update.Id)
if err != nil {
return nil, err
}
expiredPeers = append(expiredPeers, blockedPeers...)
}
if update.AutoGroups != nil && account.Settings.GroupsPropagationEnabled {
removedGroups := difference(oldUser.AutoGroups, update.AutoGroups)
// need force update all auto groups in any case they will not be duplicated
account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...)
account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...)
}
events := am.prepareUserUpdateEvents(ctx, initiatorUser.Id, oldUser, newUser, account, transferredOwnerRole)
eventsToStore = append(eventsToStore, events...)
updatedUserInfo, err := getUserInfo(ctx, am, newUser, account)
if err != nil {
return nil, err
}
updatedUsers = append(updatedUsers, updatedUserInfo)
}
if len(expiredPeers) > 0 {
if err := am.expireAndUpdatePeers(ctx, account, expiredPeers); err != nil {
log.WithContext(ctx).Errorf("failed update expired peers: %s", err) log.WithContext(ctx).Errorf("failed update expired peers: %s", err)
return nil, err return nil, err
} }
} }
if update.AutoGroups != nil && account.Settings.GroupsPropagationEnabled { account.Network.IncSerial()
removedGroups := difference(oldUser.AutoGroups, update.AutoGroups) if err = am.Store.SaveUsers(account.Id, account.Users); err != nil {
// need force update all auto groups in any case they will not be duplicated return nil, err
account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...) }
account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...)
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return nil, err
}
if account.Settings.GroupsPropagationEnabled {
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account)
} else { }
if err = am.Store.SaveAccount(ctx, account); err != nil {
return nil, err for _, storeEvent := range eventsToStore {
storeEvent()
}
return updatedUsers, nil
}
// prepareUserUpdateEvents prepares a list user update events based on the changes between the old and new user data.
func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, transferredOwnerRole bool) []func() {
var eventsToStore []func()
if oldUser.IsBlocked() != newUser.IsBlocked() {
if newUser.IsBlocked() {
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.UserBlocked, nil)
})
} else {
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.UserUnblocked, nil)
})
} }
} }
defer func() { switch {
if oldUser.IsBlocked() != update.IsBlocked() { case transferredOwnerRole:
if update.IsBlocked() { eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.UserBlocked, nil) am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.TransferredOwnerRole, nil)
})
case oldUser.Role != newUser.Role:
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.UserRoleUpdated, map[string]any{"role": newUser.Role})
})
}
if newUser.AutoGroups != nil {
removedGroups := difference(oldUser.AutoGroups, newUser.AutoGroups)
addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups)
for _, g := range removedGroups {
group := account.GetGroup(g)
if group != nil {
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupRemovedFromUser,
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
})
} else { } else {
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.UserUnblocked, nil) log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
} }
} }
for _, g := range addedGroups {
switch { group := account.GetGroup(g)
case transferedOwnerRole: if group != nil {
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.TransferredOwnerRole, nil) eventsToStore = append(eventsToStore, func() {
case oldUser.Role != newUser.Role: am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupAddedToUser,
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.UserRoleUpdated, map[string]any{"role": newUser.Role})
default:
}
if update.AutoGroups != nil {
removedGroups := difference(oldUser.AutoGroups, update.AutoGroups)
addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups)
for _, g := range removedGroups {
group := account.GetGroup(g)
if group != nil {
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.GroupRemovedFromUser,
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName}) map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
} else { })
log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
}
}
for _, g := range addedGroups {
group := account.GetGroup(g)
if group != nil {
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.GroupAddedToUser,
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
}
} }
} }
}() }
if !isNil(am.idpManager) && !newUser.IsServiceUser { return eventsToStore
userData, err := am.lookupUserInCache(ctx, newUser.Id, account) }
func handleOwnerRoleTransfer(account *Account, initiatorUser, update *User) bool {
if initiatorUser.Role == UserRoleOwner && initiatorUser.Id != update.Id && update.Role == UserRoleOwner {
newInitiatorUser := initiatorUser.Copy()
newInitiatorUser.Role = UserRoleAdmin
account.Users[initiatorUser.Id] = newInitiatorUser
return true
}
return false
}
// getUserInfo retrieves the UserInfo for a given User and Account.
// If the AccountManager has a non-nil idpManager and the User is not a service user,
// it will attempt to look up the UserData from the cache.
func getUserInfo(ctx context.Context, am *DefaultAccountManager, user *User, account *Account) (*UserInfo, error) {
if !isNil(am.idpManager) && !user.IsServiceUser {
userData, err := am.lookupUserInCache(ctx, user.Id, account)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return newUser.ToUserInfo(userData, account.Settings) return user.ToUserInfo(userData, account.Settings)
} }
return newUser.ToUserInfo(nil, account.Settings) return user.ToUserInfo(nil, account.Settings)
}
// validateUserUpdate validates the update operation for a user.
func validateUserUpdate(account *Account, initiatorUser, oldUser, update *User) error {
if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && oldUser.Blocked != update.Blocked {
return status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves")
}
if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && update.Role != initiatorUser.Role {
return status.Errorf(status.PermissionDenied, "admins can't change their role")
}
if initiatorUser.Role == UserRoleAdmin && oldUser.Role == UserRoleOwner && update.Role != oldUser.Role {
return status.Errorf(status.PermissionDenied, "only owners can remove owner role from their user")
}
if initiatorUser.Role == UserRoleAdmin && oldUser.Role == UserRoleOwner && update.IsBlocked() && !oldUser.IsBlocked() {
return status.Errorf(status.PermissionDenied, "unable to block owner user")
}
if initiatorUser.Role == UserRoleAdmin && update.Role == UserRoleOwner && update.Role != oldUser.Role {
return status.Errorf(status.PermissionDenied, "only owners can add owner role to other users")
}
if oldUser.IsServiceUser && update.Role == UserRoleOwner {
return status.Errorf(status.PermissionDenied, "can't update a service user with owner role")
}
for _, newGroupID := range update.AutoGroups {
if _, ok := account.Groups[newGroupID]; !ok {
return status.Errorf(status.InvalidArgument, "provided group ID %s in the user %s update doesn't exist",
newGroupID, update.Id)
}
}
return nil
} }
// GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist // GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist
@ -937,7 +1013,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, u
userObj := account.Users[userID] userObj := account.Users[userID]
if account.Domain != lowerDomain && userObj.Role == UserRoleOwner { if lowerDomain != "" && account.Domain != lowerDomain && userObj.Role == UserRoleOwner {
account.Domain = lowerDomain account.Domain = lowerDomain
err = am.Store.SaveAccount(ctx, account) err = am.Store.SaveAccount(ctx, account)
if err != nil { if err != nil {

View File

@ -18,6 +18,8 @@ Flags:
--letsencrypt-domain string a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS --letsencrypt-domain string a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS
--port int Server port to listen on (e.g. 10000) (default 10000) --port int Server port to listen on (e.g. 10000) (default 10000)
--ssl-dir string server ssl directory location. *Required only for Let's Encrypt certificates. (default "/var/lib/netbird/") --ssl-dir string server ssl directory location. *Required only for Let's Encrypt certificates. (default "/var/lib/netbird/")
--cert-file string Location of your SSL certificate. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect
--cert-key string Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect
Global Flags: Global Flags:
--log-file string sets Netbird log path. If console is specified the the log will be output to stdout (default "/var/log/netbird/signal.log") --log-file string sets Netbird log path. If console is specified the the log will be output to stdout (default "/var/log/netbird/signal.log")
@ -90,6 +92,9 @@ The Signal Server exposes the following metrics in Prometheus format:
- **registration_delay_milliseconds**: A Histogram metric that measures the time - **registration_delay_milliseconds**: A Histogram metric that measures the time
it took to register a peer in it took to register a peer in
milliseconds. milliseconds.
- **get_registration_delay_milliseconds**: A Histogram metric that measures the time
it took to get a peer registration in
milliseconds.
- **messages_forwarded_total**: A Counter metric that counts the total number of - **messages_forwarded_total**: A Counter metric that counts the total number of
messages forwarded between peers. messages forwarded between peers.
- **message_forward_failures_total**: A Counter metric that counts the total - **message_forward_failures_total**: A Counter metric that counts the total

View File

@ -2,7 +2,6 @@ package client
import ( import (
"context" "context"
"crypto/tls"
"fmt" "fmt"
"io" "io"
"sync" "sync"
@ -14,9 +13,6 @@ import (
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity" "google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
@ -64,28 +60,21 @@ func (c *GrpcClient) Close() error {
// NewClient creates a new Signal client // NewClient creates a new Signal client
func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) { func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) {
var conn *grpc.ClientConn
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) operation := func() error {
var err error
if tlsEnabled { conn, err = nbgrpc.CreateConnection(addr, tlsEnabled)
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{})) if err != nil {
log.Printf("createConnection error: %v", err)
return err
}
return nil
} }
sigCtx, cancel := context.WithTimeout(ctx, client.ConnectTimeout) err := backoff.Retry(operation, nbgrpc.Backoff(ctx))
defer cancel()
conn, err := grpc.DialContext(
sigCtx,
addr,
transportOption,
nbgrpc.WithCustomDialer(),
grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second,
Timeout: 10 * time.Second,
}))
if err != nil { if err != nil {
log.Errorf("failed to connect to the signalling server %v", err) log.Errorf("failed to connect to the signalling server: %v", err)
return nil, err return nil, err
} }
@ -408,7 +397,7 @@ func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient,
if err != nil { if err != nil {
log.Errorf("error while handling message of Peer [key: %s] error: [%s]", msg.Key, err.Error()) log.Errorf("error while handling message of Peer [key: %s] error: [%s]", msg.Key, err.Error())
//todo send something?? // todo send something??
} }
} }
} }

View File

@ -2,15 +2,12 @@ package cmd
import ( import (
"context" "context"
"crypto/tls"
"errors" "errors"
"flag" "flag"
"fmt" "fmt"
"io"
"io/fs"
"net" "net"
"net/http" "net/http"
"os"
"path"
"strings" "strings"
"time" "time"
@ -41,7 +38,8 @@ var (
signalLetsencryptDomain string signalLetsencryptDomain string
signalSSLDir string signalSSLDir string
defaultSignalSSLDir string defaultSignalSSLDir string
tlsEnabled bool signalCertFile string
signalCertKey string
signalKaep = grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{ signalKaep = grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{
MinTime: 5 * time.Second, MinTime: 5 * time.Second,
@ -56,12 +54,22 @@ var (
}) })
runCmd = &cobra.Command{ runCmd = &cobra.Command{
Use: "run", Use: "run",
Short: "start NetBird Signal Server daemon", Short: "start NetBird Signal Server daemon",
SilenceUsage: true,
PreRun: func(cmd *cobra.Command, args []string) { PreRun: func(cmd *cobra.Command, args []string) {
err := util.InitLog(logLevel, logFile)
if err != nil {
log.Fatalf("failed initializing log %v", err)
}
flag.Parse()
// detect whether user specified a port // detect whether user specified a port
userPort := cmd.Flag("port").Changed userPort := cmd.Flag("port").Changed
if signalLetsencryptDomain != "" {
tlsEnabled := false
if signalLetsencryptDomain != "" || (signalCertFile != "" && signalCertKey != "") {
tlsEnabled = true tlsEnabled = true
} }
@ -77,33 +85,12 @@ var (
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
flag.Parse() flag.Parse()
err := util.InitLog(logLevel, logFile) opts, certManager, err := getTLSConfigurations()
if err != nil { if err != nil {
log.Fatalf("failed initializing log %v", err) return err
} }
if signalSSLDir == "" { metricsServer, err := metrics.NewServer(metricsPort, "")
oldPath := "/var/lib/wiretrustee"
if migrateToNetbird(oldPath, defaultSignalSSLDir) {
if err := cpDir(oldPath, defaultSignalSSLDir); err != nil {
log.Fatal(err)
}
}
}
var opts []grpc.ServerOption
var certManager *autocert.Manager
if tlsEnabled {
// Let's encrypt enabled -> generate certificate automatically
certManager, err = encryption.CreateCertManager(signalSSLDir, signalLetsencryptDomain)
if err != nil {
return err
}
transportCredentials := credentials.NewTLS(certManager.TLSConfig())
opts = append(opts, grpc.Creds(transportCredentials))
}
metricsServer := metrics.NewServer(metricsPort, "")
if err != nil { if err != nil {
return fmt.Errorf("setup metrics: %v", err) return fmt.Errorf("setup metrics: %v", err)
} }
@ -124,7 +111,25 @@ var (
} }
proto.RegisterSignalExchangeServer(grpcServer, srv) proto.RegisterSignalExchangeServer(grpcServer, srv)
grpcRootHandler := grpcHandlerFunc(grpcServer)
if certManager != nil {
startServerWithCertManager(certManager, grpcRootHandler)
}
var compatListener net.Listener var compatListener net.Listener
var grpcListener net.Listener
var httpListener net.Listener
// If certManager is configured and signalPort == 443, then the gRPC server has already been started
if certManager == nil || signalPort != 443 {
grpcListener, err = serveGRPC(grpcServer, signalPort)
if err != nil {
return err
}
log.Infof("running gRPC server: %s", grpcListener.Addr().String())
}
if signalPort != 10000 { if signalPort != 10000 {
// The Signal gRPC server was running on port 10000 previously. Old agents that are already connected to Signal // The Signal gRPC server was running on port 10000 previously. Old agents that are already connected to Signal
// are using port 10000. For compatibility purposes we keep running a 2nd gRPC server on port 10000. // are using port 10000. For compatibility purposes we keep running a 2nd gRPC server on port 10000.
@ -135,28 +140,6 @@ var (
log.Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String()) log.Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String())
} }
var grpcListener net.Listener
var httpListener net.Listener
if tlsEnabled {
httpListener = certManager.Listener()
if signalPort == 443 {
// running gRPC and HTTP cert manager on the same port
serveHTTP(httpListener, certManager.HTTPHandler(grpcHandlerFunc(grpcServer)))
log.Infof("running HTTP server (LetsEncrypt challenge handler) and gRPC server on the same port: %s", httpListener.Addr().String())
} else {
serveHTTP(httpListener, certManager.HTTPHandler(nil))
log.Infof("running HTTP server (LetsEncrypt challenge handler): %s", httpListener.Addr().String())
}
}
if signalPort != 443 || !tlsEnabled {
grpcListener, err = serveGRPC(grpcServer, signalPort)
if err != nil {
return err
}
log.Infof("running gRPC server: %s", grpcListener.Addr().String())
}
log.Infof("signal server version %s", version.NetbirdVersion()) log.Infof("signal server version %s", version.NetbirdVersion())
log.Infof("started Signal Service") log.Infof("started Signal Service")
@ -190,6 +173,58 @@ var (
} }
) )
func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, error) {
var (
err error
certManager *autocert.Manager
tlsConfig *tls.Config
)
if signalLetsencryptDomain == "" && signalCertFile == "" && signalCertKey == "" {
log.Infof("running without TLS")
return nil, nil, nil
}
if signalLetsencryptDomain != "" {
certManager, err = encryption.CreateCertManager(signalSSLDir, signalLetsencryptDomain)
if err != nil {
return nil, certManager, err
}
tlsConfig = certManager.TLSConfig()
log.Infof("setting up TLS with LetsEncrypt.")
} else {
if signalCertFile == "" || signalCertKey == "" {
log.Errorf("both cert-file and cert-key must be provided when not using LetsEncrypt")
return nil, certManager, errors.New("both cert-file and cert-key must be provided when not using LetsEncrypt")
}
tlsConfig, err = loadTLSConfig(signalCertFile, signalCertKey)
if err != nil {
log.Errorf("cannot load TLS credentials: %v", err)
return nil, certManager, err
}
log.Infof("setting up TLS with custom certificates.")
}
transportCredentials := credentials.NewTLS(tlsConfig)
return []grpc.ServerOption{grpc.Creds(transportCredentials)}, certManager, err
}
func startServerWithCertManager(certManager *autocert.Manager, grpcRootHandler http.Handler) {
// a call to certManager.Listener() always creates a new listener so we do it once
httpListener := certManager.Listener()
if signalPort == 443 {
// running gRPC and HTTP cert manager on the same port
serveHTTP(httpListener, certManager.HTTPHandler(grpcRootHandler))
log.Infof("running HTTP server (LetsEncrypt challenge handler) and gRPC server on the same port: %s", httpListener.Addr().String())
} else {
// Start the HTTP cert manager server separately
serveHTTP(httpListener, certManager.HTTPHandler(nil))
log.Infof("running HTTP server (LetsEncrypt challenge handler): %s", httpListener.Addr().String())
}
}
func grpcHandlerFunc(grpcServer *grpc.Server) http.Handler { func grpcHandlerFunc(grpcServer *grpc.Server) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
grpcHeader := strings.HasPrefix(r.Header.Get("Content-Type"), "application/grpc") || grpcHeader := strings.HasPrefix(r.Header.Get("Content-Type"), "application/grpc") ||
@ -232,95 +267,29 @@ func serveGRPC(grpcServer *grpc.Server, port int) (net.Listener, error) {
return listener, nil return listener, nil
} }
func cpFile(src, dst string) error { func loadTLSConfig(certFile string, certKey string) (*tls.Config, error) {
var err error // Load server's certificate and private key
var srcfd *os.File serverCert, err := tls.LoadX509KeyPair(certFile, certKey)
var dstfd *os.File
var srcinfo os.FileInfo
if srcfd, err = os.Open(src); err != nil {
return err
}
defer srcfd.Close()
if dstfd, err = os.Create(dst); err != nil {
return err
}
defer dstfd.Close()
if _, err = io.Copy(dstfd, srcfd); err != nil {
return err
}
if srcinfo, err = os.Stat(src); err != nil {
return err
}
return os.Chmod(dst, srcinfo.Mode())
}
func copySymLink(source, dest string) error {
link, err := os.Readlink(source)
if err != nil { if err != nil {
return err return nil, err
}
return os.Symlink(link, dest)
}
func cpDir(src string, dst string) error {
var err error
var fds []os.DirEntry
var srcinfo os.FileInfo
if srcinfo, err = os.Stat(src); err != nil {
return err
} }
if err = os.MkdirAll(dst, srcinfo.Mode()); err != nil { // NewDefaultAppMetrics the credentials and return it
return err config := &tls.Config{
Certificates: []tls.Certificate{serverCert},
ClientAuth: tls.NoClientCert,
NextProtos: []string{
"h2", "http/1.1", // enable HTTP/2
},
} }
if fds, err = os.ReadDir(src); err != nil { return config, nil
return err
}
for _, fd := range fds {
srcfp := path.Join(src, fd.Name())
dstfp := path.Join(dst, fd.Name())
fileInfo, err := os.Stat(srcfp)
if err != nil {
log.Fatalf("Couldn't get fileInfo; %v", err)
}
switch fileInfo.Mode() & os.ModeType {
case os.ModeSymlink:
if err = copySymLink(srcfp, dstfp); err != nil {
log.Fatalf("Failed to copy from %s to %s; %v", srcfp, dstfp, err)
}
case os.ModeDir:
if err = cpDir(srcfp, dstfp); err != nil {
log.Fatalf("Failed to copy from %s to %s; %v", srcfp, dstfp, err)
}
default:
if err = cpFile(srcfp, dstfp); err != nil {
log.Fatalf("Failed to copy from %s to %s; %v", srcfp, dstfp, err)
}
}
}
return nil
}
func migrateToNetbird(oldPath, newPath string) bool {
_, errOld := os.Stat(oldPath)
_, errNew := os.Stat(newPath)
if errors.Is(errOld, fs.ErrNotExist) || errNew == nil {
return false
}
return true
} }
func init() { func init() {
runCmd.PersistentFlags().IntVar(&signalPort, "port", 80, "Server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise") runCmd.PersistentFlags().IntVar(&signalPort, "port", 80, "Server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise")
runCmd.Flags().StringVar(&signalSSLDir, "ssl-dir", defaultSignalSSLDir, "server ssl directory location. *Required only for Let's Encrypt certificates.") runCmd.Flags().StringVar(&signalSSLDir, "ssl-dir", defaultSignalSSLDir, "server ssl directory location. *Required only for Let's Encrypt certificates.")
runCmd.Flags().StringVar(&signalLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS") runCmd.Flags().StringVar(&signalLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
runCmd.Flags().StringVar(&signalCertFile, "cert-file", "", "Location of your SSL certificate. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
runCmd.Flags().StringVar(&signalCertKey, "cert-key", "", "Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
} }

View File

@ -15,6 +15,7 @@ type AppMetrics struct {
Deregistrations metric.Int64Counter Deregistrations metric.Int64Counter
RegistrationFailures metric.Int64Counter RegistrationFailures metric.Int64Counter
RegistrationDelay metric.Float64Histogram RegistrationDelay metric.Float64Histogram
GetRegistrationDelay metric.Float64Histogram
MessagesForwarded metric.Int64Counter MessagesForwarded metric.Int64Counter
MessageForwardFailures metric.Int64Counter MessageForwardFailures metric.Int64Counter
@ -54,6 +55,12 @@ func NewAppMetrics(meter metric.Meter) (*AppMetrics, error) {
return nil, err return nil, err
} }
getRegistrationDelay, err := meter.Float64Histogram("get_registration_delay_milliseconds",
metric.WithExplicitBucketBoundaries(getStandardBucketBoundaries()...))
if err != nil {
return nil, err
}
messagesForwarded, err := meter.Int64Counter("messages_forwarded_total") messagesForwarded, err := meter.Int64Counter("messages_forwarded_total")
if err != nil { if err != nil {
return nil, err return nil, err
@ -80,6 +87,7 @@ func NewAppMetrics(meter metric.Meter) (*AppMetrics, error) {
Deregistrations: deregistrations, Deregistrations: deregistrations,
RegistrationFailures: registrationFailures, RegistrationFailures: registrationFailures,
RegistrationDelay: registrationDelay, RegistrationDelay: registrationDelay,
GetRegistrationDelay: getRegistrationDelay,
MessagesForwarded: messagesForwarded, MessagesForwarded: messagesForwarded,
MessageForwardFailures: messageForwardFailures, MessageForwardFailures: messageForwardFailures,

View File

@ -26,10 +26,10 @@ type Metrics struct {
} }
// NewServer initializes and returns a new Metrics instance // NewServer initializes and returns a new Metrics instance
func NewServer(port int, endpoint string) *Metrics { func NewServer(port int, endpoint string) (*Metrics, error) {
exporter, err := prometheus.New() exporter, err := prometheus.New()
if err != nil { if err != nil {
return nil return nil, err
} }
provider := metric.NewMeterProvider(metric.WithReader(exporter)) provider := metric.NewMeterProvider(metric.WithReader(exporter))
@ -57,7 +57,7 @@ func NewServer(port int, endpoint string) *Metrics {
provider: provider, provider: provider,
Endpoint: endpoint, Endpoint: endpoint,
Server: server, Server: server,
} }, nil
} }
// Shutdown stops the metrics server // Shutdown stops the metrics server

View File

@ -23,11 +23,17 @@ const (
labelTypeError = "error" labelTypeError = "error"
labelTypeNotConnected = "not_connected" labelTypeNotConnected = "not_connected"
labelTypeNotRegistered = "not_registered" labelTypeNotRegistered = "not_registered"
labelTypeStream = "stream"
labelTypeMessage = "message"
labelError = "error" labelError = "error"
labelErrorMissingId = "missing_id" labelErrorMissingId = "missing_id"
labelErrorMissingMeta = "missing_meta" labelErrorMissingMeta = "missing_meta"
labelErrorFailedHeader = "failed_header" labelErrorFailedHeader = "failed_header"
labelRegistrationStatus = "status"
labelRegistrationFound = "found"
labelRegistrationNotFound = "not_found"
) )
// Server an instance of a Signal server // Server an instance of a Signal server
@ -61,7 +67,11 @@ func (s *Server) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto.
return nil, fmt.Errorf("peer %s is not registered", msg.Key) return nil, fmt.Errorf("peer %s is not registered", msg.Key)
} }
getRegistrationStart := time.Now()
if dstPeer, found := s.registry.Get(msg.RemoteKey); found { if dstPeer, found := s.registry.Get(msg.RemoteKey); found {
s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeMessage), attribute.String(labelRegistrationStatus, labelRegistrationFound)))
start := time.Now()
//forward the message to the target peer //forward the message to the target peer
if err := dstPeer.Stream.Send(msg); err != nil { if err := dstPeer.Stream.Send(msg); err != nil {
log.Errorf("error while forwarding message from peer [%s] to peer [%s] %v", msg.Key, msg.RemoteKey, err) log.Errorf("error while forwarding message from peer [%s] to peer [%s] %v", msg.Key, msg.RemoteKey, err)
@ -69,9 +79,11 @@ func (s *Server) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto.
s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeError))) s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeError)))
} else { } else {
s.metrics.MessageForwardLatency.Record(ctx, float64(time.Since(start).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeMessage)))
s.metrics.MessagesForwarded.Add(context.Background(), 1) s.metrics.MessagesForwarded.Add(context.Background(), 1)
} }
} else { } else {
s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeMessage), attribute.String(labelRegistrationStatus, labelRegistrationNotFound)))
log.Debugf("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", msg.Key, msg.RemoteKey) log.Debugf("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", msg.Key, msg.RemoteKey)
//todo respond to the sender? //todo respond to the sender?
@ -118,28 +130,30 @@ func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer)
} else if err != nil { } else if err != nil {
return err return err
} }
start := time.Now()
log.Debugf("received a new message from peer [%s] to peer [%s]", p.Id, msg.RemoteKey) log.Debugf("received a new message from peer [%s] to peer [%s]", p.Id, msg.RemoteKey)
getRegistrationStart := time.Now()
// lookup the target peer where the message is going to // lookup the target peer where the message is going to
if dstPeer, found := s.registry.Get(msg.RemoteKey); found { if dstPeer, found := s.registry.Get(msg.RemoteKey); found {
s.metrics.GetRegistrationDelay.Record(stream.Context(), float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationFound)))
start := time.Now()
//forward the message to the target peer //forward the message to the target peer
if err := dstPeer.Stream.Send(msg); err != nil { if err := dstPeer.Stream.Send(msg); err != nil {
log.Errorf("error while forwarding message from peer [%s] to peer [%s] %v", p.Id, msg.RemoteKey, err) log.Errorf("error while forwarding message from peer [%s] to peer [%s] %v", p.Id, msg.RemoteKey, err)
//todo respond to the sender? //todo respond to the sender?
// in milliseconds
s.metrics.MessageForwardLatency.Record(stream.Context(), float64(time.Since(start).Nanoseconds())/1e6)
s.metrics.MessagesForwarded.Add(stream.Context(), 1)
} else {
s.metrics.MessageForwardFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelType, labelTypeError))) s.metrics.MessageForwardFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelType, labelTypeError)))
} else {
// in milliseconds
s.metrics.MessageForwardLatency.Record(stream.Context(), float64(time.Since(start).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream)))
s.metrics.MessagesForwarded.Add(stream.Context(), 1)
} }
} else { } else {
s.metrics.GetRegistrationDelay.Record(stream.Context(), float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationNotFound)))
s.metrics.MessageForwardFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelType, labelTypeNotConnected)))
log.Debugf("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", p.Id, msg.RemoteKey) log.Debugf("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", p.Id, msg.RemoteKey)
//todo respond to the sender? //todo respond to the sender?
s.metrics.MessageForwardFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelType, labelTypeNotConnected)))
} }
} }
<-stream.Context().Done() <-stream.Context().Done()

View File

@ -2,12 +2,18 @@ package grpc
import ( import (
"context" "context"
"crypto/tls"
"net" "net"
"os/user" "os/user"
"runtime" "runtime"
"time"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
@ -35,3 +41,40 @@ func WithCustomDialer() grpc.DialOption {
return conn, nil return conn, nil
}) })
} }
// grpcDialBackoff is the backoff mechanism for the grpc calls
func Backoff(ctx context.Context) backoff.BackOff {
b := backoff.NewExponentialBackOff()
b.MaxElapsedTime = 10 * time.Second
b.Clock = backoff.SystemClock
return backoff.WithContext(b, ctx)
}
func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
if tlsEnabled {
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{}))
}
connCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
conn, err := grpc.DialContext(
connCtx,
addr,
transportOption,
WithCustomDialer(),
grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second,
Timeout: 10 * time.Second,
}),
)
if err != nil {
log.Printf("DialContext error: %v", err)
return nil, err
}
return conn, nil
}

View File

@ -4,6 +4,7 @@ import (
"io" "io"
"os" "os"
"path/filepath" "path/filepath"
"slices"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"gopkg.in/natefinch/lumberjack.v2" "gopkg.in/natefinch/lumberjack.v2"
@ -18,8 +19,9 @@ func InitLog(logLevel string, logPath string) error {
log.Errorf("Failed parsing log-level %s: %s", logLevel, err) log.Errorf("Failed parsing log-level %s: %s", logLevel, err)
return err return err
} }
customOutputs := []string{"console", "syslog"};
if logPath != "" && logPath != "console" { if logPath != "" && !slices.Contains(customOutputs, logPath) {
lumberjackLogger := &lumberjack.Logger{ lumberjackLogger := &lumberjack.Logger{
// Log file absolute path, os agnostic // Log file absolute path, os agnostic
Filename: filepath.ToSlash(logPath), Filename: filepath.ToSlash(logPath),
@ -29,6 +31,8 @@ func InitLog(logLevel string, logPath string) error {
Compress: true, Compress: true,
} }
log.SetOutput(io.Writer(lumberjackLogger)) log.SetOutput(io.Writer(lumberjackLogger))
} else if logPath == "syslog" {
AddSyslogHook()
} }
if os.Getenv("NB_LOG_FORMAT") == "json" { if os.Getenv("NB_LOG_FORMAT") == "json" {

20
util/syslog_nonwindows.go Normal file
View File

@ -0,0 +1,20 @@
//go:build !windows
// +build !windows
package util
import (
"log/syslog"
log "github.com/sirupsen/logrus"
lSyslog "github.com/sirupsen/logrus/hooks/syslog"
)
func AddSyslogHook() {
hook, err := lSyslog.NewSyslogHook("", "", syslog.LOG_INFO, "")
if err != nil {
log.Errorf("Failed creating syslog hook: %s", err)
}
log.AddHook(hook)
}

6
util/syslog_windows.go Normal file
View File

@ -0,0 +1,6 @@
package util
func AddSyslogHook() {
// The syslog package is not available for Windows. This adapter is needed
// to handle windows build.
}

1
versioninfo.json Normal file
View File

@ -0,0 +1 @@
{}