Merge remote-tracking branch 'origin/main' into feat/multiple-profile

Hakan Sariman 2025-05-21 16:09:57 +03:00
commit 7fae260faa
145 changed files with 6533 additions and 2928 deletions

View File

@ -37,17 +37,22 @@ If yes, which one?
**Debug output**
-To help us resolve the problem, please attach the following debug output
+To help us resolve the problem, please attach the following anonymized status output
netbird status -dA
-As well as the file created by
+Create and upload a debug bundle, and share the returned file key:
+netbird debug for 1m -AS -U
+*Uploaded files are automatically deleted after 30 days.*
+Alternatively, create the file only and attach it here manually:
netbird debug for 1m -AS
+We advise reviewing the anonymized output for any remaining personal information.
**Screenshots**
If applicable, add screenshots to help explain your problem.
@ -57,8 +62,10 @@ If applicable, add screenshots to help explain your problem.
Add any other context about the problem here.
**Have you tried these troubleshooting steps?**
+- [ ] Reviewed [client troubleshooting](https://docs.netbird.io/how-to/troubleshooting-client) (if applicable)
- [ ] Checked for newer NetBird versions
- [ ] Searched for similar issues on GitHub (including closed ones)
- [ ] Restarted the NetBird client
- [ ] Disabled other VPN software
- [ ] Checked firewall settings

View File

@ -179,6 +179,7 @@ jobs:
grep -A 7 Relay management.json | grep "rel://$CI_NETBIRD_DOMAIN:33445"
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
grep DisablePromptLogin management.json | grep 'true'
+grep LoginFlag management.json | grep 0
- name: Install modules
run: go mod tidy

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"os" "os"
"runtime"
"strings" "strings"
"time" "time"
@ -100,7 +101,7 @@ var loginCmd = &cobra.Command{
loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey,
ManagementUrl: managementURL,
-IsLinuxDesktopClient: isLinuxRunningDesktop(),
+IsUnixDesktopClient: isUnixRunningDesktop(),
Hostname: hostName,
DnsLabels: dnsLabelsReq,
}
@ -195,7 +196,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C
}
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*auth.TokenInfo, error) {
-oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isLinuxRunningDesktop())
+oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop())
if err != nil {
return nil, err
}
@ -243,7 +244,10 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
}
}
-// isLinuxRunningDesktop checks if a Linux OS is running desktop environment
-func isLinuxRunningDesktop() bool {
+// isUnixRunningDesktop checks if a Linux OS is running desktop environment
+func isUnixRunningDesktop() bool {
+if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
+return false
+}
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
}

View File

@ -40,6 +40,7 @@ const (
dnsRouteIntervalFlag = "dns-router-interval"
systemInfoFlag = "system-info"
blockLANAccessFlag = "block-lan-access"
+enableLazyConnectionFlag = "enable-lazy-connection"
uploadBundle = "upload-bundle"
uploadBundleURL = "upload-bundle-url"
)
@ -80,6 +81,7 @@ var (
blockLANAccess bool
debugUploadBundle bool
debugUploadBundleURL string
+lazyConnEnabled bool
rootCmd = &cobra.Command{
Use: "netbird",
@ -184,6 +186,7 @@ func init() {
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
+upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand.")
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle")
debugCmd.PersistentFlags().BoolVarP(&debugUploadBundle, uploadBundle, "U", false, fmt.Sprintf("Uploads the debug bundle to a server from URL defined by %s", uploadBundleURL))
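With the flag constant, variable, and registration above in place, the feature can be toggled at enrollment time, e.g. netbird up --enable-lazy-connection (following the flag name registered here); the config hunks further below persist the choice across restarts.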

View File

@ -44,7 +44,7 @@ func init() {
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4")
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
-statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(connected|disconnected), e.g., --filter-by-status connected")
+statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
}
func statusFunc(cmd *cobra.Command, args []string) error {
@ -127,12 +127,12 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
func parseFilters() error {
switch strings.ToLower(statusFilter) {
-case "", "disconnected", "connected":
+case "", "idle", "connecting", "connected":
if strings.ToLower(statusFilter) != "" {
enableDetailFlagWhenFilterFlag()
}
default:
-return fmt.Errorf("wrong status filter, should be one of connected|disconnected, got: %s", statusFilter)
+return fmt.Errorf("wrong status filter, should be one of connected|connecting|idle, got: %s", statusFilter)
}
if len(ipsFilter) > 0 {

View File

@ -194,6 +194,10 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
ic.BlockLANAccess = &blockLANAccess
}
+if cmd.Flag(enableLazyConnectionFlag).Changed {
+ic.LazyConnectionEnabled = &lazyConnEnabled
+}
providedSetupKey, err := getSetupKey()
if err != nil {
return err
@ -268,7 +272,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
NatExternalIPs: natExternalIPs,
CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0,
CustomDNSAddress: customDNSAddressConverted,
-IsLinuxDesktopClient: isLinuxRunningDesktop(),
+IsUnixDesktopClient: isUnixRunningDesktop(),
Hostname: hostName,
ExtraIFaceBlacklist: extraIFaceBlackList,
DnsLabels: dnsLabels,
@ -332,6 +336,10 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
loginRequest.BlockLanAccess = &blockLANAccess
}
+if cmd.Flag(enableLazyConnectionFlag).Changed {
+loginRequest.LazyConnectionEnabled = &lazyConnEnabled
+}
var loginErr error
var loginResp *proto.LoginResponse

View File

@ -201,14 +201,30 @@ func (c *KernelConfigurer) configure(config wgtypes.Config) error {
func (c *KernelConfigurer) Close() {
}
-func (c *KernelConfigurer) GetStats(peerKey string) (WGStats, error) {
-peer, err := c.getPeer(c.deviceName, peerKey)
+func (c *KernelConfigurer) GetStats() (map[string]WGStats, error) {
+stats := make(map[string]WGStats)
+wg, err := wgctrl.New()
if err != nil {
-return WGStats{}, fmt.Errorf("get wireguard stats: %w", err)
+return nil, fmt.Errorf("wgctl: %w", err)
}
-return WGStats{
+defer func() {
+err = wg.Close()
+if err != nil {
+log.Errorf("Got error while closing wgctl: %v", err)
+}
+}()
+wgDevice, err := wg.Device(c.deviceName)
+if err != nil {
+return nil, fmt.Errorf("get device %s: %w", c.deviceName, err)
+}
+for _, peer := range wgDevice.Peers {
+stats[peer.PublicKey.String()] = WGStats{
LastHandshake: peer.LastHandshakeTime,
TxBytes: peer.TransmitBytes,
RxBytes: peer.ReceiveBytes,
-}, nil
+}
+}
+return stats, nil
}
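For orientation, a minimal usage sketch (not part of this diff) of the new map-based API; iface is assumed to be an already-initialized *WGIface from this package:

stats, err := iface.GetStats()
if err != nil {
	log.Errorf("get wireguard stats: %v", err)
	return
}
// One entry per peer, keyed by the peer's base64-encoded public key.
for peerKey, s := range stats {
	log.Infof("%s: last handshake %s, rx %d, tx %d bytes", peerKey, s.LastHandshake, s.RxBytes, s.TxBytes)
}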

View File

@ -1,6 +1,7 @@
package configurer
import (
+"encoding/base64"
"encoding/hex"
"fmt"
"net"
@ -17,6 +18,13 @@ import (
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
const (
ipcKeyLastHandshakeTimeSec = "last_handshake_time_sec"
ipcKeyLastHandshakeTimeNsec = "last_handshake_time_nsec"
ipcKeyTxBytes = "tx_bytes"
ipcKeyRxBytes = "rx_bytes"
)
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found") var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
type WGUSPConfigurer struct { type WGUSPConfigurer struct {
@ -217,91 +225,75 @@ func (t *WGUSPConfigurer) Close() {
}
}
-func (t *WGUSPConfigurer) GetStats(peerKey string) (WGStats, error) {
+func (t *WGUSPConfigurer) GetStats() (map[string]WGStats, error) {
ipc, err := t.device.IpcGet()
if err != nil {
-return WGStats{}, fmt.Errorf("ipc get: %w", err)
+return nil, fmt.Errorf("ipc get: %w", err)
}
-stats, err := findPeerInfo(ipc, peerKey, []string{
-"last_handshake_time_sec",
-"last_handshake_time_nsec",
-"tx_bytes",
-"rx_bytes",
-})
-if err != nil {
-return WGStats{}, fmt.Errorf("find peer info: %w", err)
-}
-sec, err := strconv.ParseInt(stats["last_handshake_time_sec"], 10, 64)
-if err != nil {
-return WGStats{}, fmt.Errorf("parse handshake sec: %w", err)
-}
-nsec, err := strconv.ParseInt(stats["last_handshake_time_nsec"], 10, 64)
-if err != nil {
-return WGStats{}, fmt.Errorf("parse handshake nsec: %w", err)
-}
-txBytes, err := strconv.ParseInt(stats["tx_bytes"], 10, 64)
-if err != nil {
-return WGStats{}, fmt.Errorf("parse tx_bytes: %w", err)
-}
-rxBytes, err := strconv.ParseInt(stats["rx_bytes"], 10, 64)
-if err != nil {
-return WGStats{}, fmt.Errorf("parse rx_bytes: %w", err)
-}
-return WGStats{
-LastHandshake: time.Unix(sec, nsec),
-TxBytes: txBytes,
-RxBytes: rxBytes,
-}, nil
+return parseTransfers(ipc)
}
-func findPeerInfo(ipcInput string, peerKey string, searchConfigKeys []string) (map[string]string, error) {
-peerKeyParsed, err := wgtypes.ParseKey(peerKey)
-if err != nil {
-return nil, fmt.Errorf("parse key: %w", err)
-}
-hexKey := hex.EncodeToString(peerKeyParsed[:])
-lines := strings.Split(ipcInput, "\n")
-configFound := map[string]string{}
-foundPeer := false
+func parseTransfers(ipc string) (map[string]WGStats, error) {
+stats := make(map[string]WGStats)
+var (
+currentKey string
+currentStats WGStats
+hasPeer bool
+)
+lines := strings.Split(ipc, "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
// If we're within the details of the found peer and encounter another public key,
// this means we're starting another peer's details. So, stop.
-if strings.HasPrefix(line, "public_key=") && foundPeer {
-break
+if strings.HasPrefix(line, "public_key=") {
+peerID := strings.TrimPrefix(line, "public_key=")
+h, err := hex.DecodeString(peerID)
+if err != nil {
+return nil, fmt.Errorf("decode peerID: %w", err)
+}
+currentKey = base64.StdEncoding.EncodeToString(h)
+currentStats = WGStats{} // Reset stats for the new peer
+hasPeer = true
+stats[currentKey] = currentStats
+continue
}
-// Identify the peer with the specific public key
-if line == fmt.Sprintf("public_key=%s", hexKey) {
-foundPeer = true
+if !hasPeer {
+continue
}
-for _, key := range searchConfigKeys {
-if foundPeer && strings.HasPrefix(line, key+"=") {
-v := strings.SplitN(line, "=", 2)
-configFound[v[0]] = v[1]
+key := strings.SplitN(line, "=", 2)
+if len(key) != 2 {
+continue
}
+switch key[0] {
+case ipcKeyLastHandshakeTimeSec:
+hs, err := toLastHandshake(key[1])
+if err != nil {
+return nil, err
+}
+currentStats.LastHandshake = hs
+stats[currentKey] = currentStats
+case ipcKeyRxBytes:
+rxBytes, err := toBytes(key[1])
+if err != nil {
+return nil, fmt.Errorf("parse rx_bytes: %w", err)
+}
+currentStats.RxBytes = rxBytes
+stats[currentKey] = currentStats
+case ipcKeyTxBytes:
+TxBytes, err := toBytes(key[1])
+if err != nil {
+return nil, fmt.Errorf("parse tx_bytes: %w", err)
+}
+currentStats.TxBytes = TxBytes
+stats[currentKey] = currentStats
+}
}
}
-// todo: use multierr
-for _, key := range searchConfigKeys {
-if _, ok := configFound[key]; !ok {
-return configFound, fmt.Errorf("config key not found: %s", key)
-}
-}
-if !foundPeer {
-return nil, fmt.Errorf("%w: %s", ErrPeerNotFound, peerKey)
-}
-return configFound, nil
+return stats, nil
}
func toWgUserspaceString(wgCfg wgtypes.Config) string { func toWgUserspaceString(wgCfg wgtypes.Config) string {
@ -355,6 +347,18 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
return sb.String()
}
+func toLastHandshake(stringVar string) (time.Time, error) {
+sec, err := strconv.ParseInt(stringVar, 10, 64)
+if err != nil {
+return time.Time{}, fmt.Errorf("parse handshake sec: %w", err)
+}
+return time.Unix(sec, 0), nil
+}
+func toBytes(s string) (int64, error) {
+return strconv.ParseInt(s, 10, 64)
+}
func getFwmark() int {
if nbnet.AdvancedRouting() {
return nbnet.ControlPlaneMark
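For reference, a small sketch of the userspace IPC stream parseTransfers consumes: the WireGuard cross-platform UAPI emits one key=value pair per line with hex-encoded public keys, and the function re-keys the result by the base64 form. The sample values here are invented:

ipc := `public_key=58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376
last_handshake_time_sec=1700000000
last_handshake_time_nsec=0
tx_bytes=38333
rx_bytes=2224
errno=0`
stats, err := parseTransfers(ipc)
if err != nil {
	log.Fatalf("parse transfers: %v", err)
}
for key, s := range stats { // key is the base64-encoded public key
	fmt.Println(key, s.LastHandshake, s.TxBytes, s.RxBytes)
}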

View File

@ -2,10 +2,8 @@ package configurer
import (
"encoding/hex"
-"fmt"
"testing"
-"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
@ -34,58 +32,35 @@ errno=0
`
-func Test_findPeerInfo(t *testing.T) {
+func Test_parseTransfers(t *testing.T) {
tests := []struct {
name string
peerKey string
-searchKeys []string
-want map[string]string
-wantErr bool
+want WGStats
}{
{
name: "single",
-peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376",
-searchKeys: []string{"tx_bytes"},
-want: map[string]string{
-"tx_bytes": "38333",
+peerKey: "b85996fecc9c7f1fc6d2572a76eda11d59bcd20be8e543b15ce4bd85a8e75a33",
+want: WGStats{
+TxBytes: 0,
+RxBytes: 0,
},
-wantErr: false,
},
{
name: "multiple",
peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376",
-searchKeys: []string{"tx_bytes", "rx_bytes"},
-want: map[string]string{
-"tx_bytes": "38333",
-"rx_bytes": "2224",
+want: WGStats{
+TxBytes: 38333,
+RxBytes: 2224,
},
-wantErr: false,
},
{
name: "lastpeer",
peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
-searchKeys: []string{"tx_bytes", "rx_bytes"},
-want: map[string]string{
-"tx_bytes": "1212111",
-"rx_bytes": "1929999999",
+want: WGStats{
+TxBytes: 1212111,
+RxBytes: 1929999999,
},
-wantErr: false,
-},
-{
-name: "peer not found",
-peerKey: "1111111111111111111111111111111111111111111111111111111111111111",
-searchKeys: nil,
-want: nil,
-wantErr: true,
-},
-{
-name: "key not found",
-peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
-searchKeys: []string{"tx_bytes", "unknown_key"},
-want: map[string]string{
-"tx_bytes": "1212111",
-},
-wantErr: true,
},
}
for _, tt := range tests {
@ -96,9 +71,19 @@ func Test_findPeerInfo(t *testing.T) {
key, err := wgtypes.NewKey(res)
require.NoError(t, err)
-got, err := findPeerInfo(ipcFixture, key.String(), tt.searchKeys)
-assert.Equalf(t, tt.wantErr, err != nil, fmt.Sprintf("findPeerInfo(%v, %v, %v)", ipcFixture, key.String(), tt.searchKeys))
-assert.Equalf(t, tt.want, got, "findPeerInfo(%v, %v, %v)", ipcFixture, key.String(), tt.searchKeys)
+stats, err := parseTransfers(ipcFixture)
+if err != nil {
+require.NoError(t, err)
+return
+}
+stat, ok := stats[key.String()]
+if !ok {
+require.True(t, ok)
+return
+}
+require.Equal(t, tt.want, stat)
})
}
}

View File

@ -16,5 +16,5 @@ type WGConfigurer interface {
AddAllowedIP(peerKey string, allowedIP string) error
RemoveAllowedIP(peerKey string, allowedIP string) error
Close()
-GetStats(peerKey string) (configurer.WGStats, error)
+GetStats() (map[string]configurer.WGStats, error)
}

View File

@ -212,9 +212,9 @@ func (w *WGIface) GetWGDevice() *wgdevice.Device {
return w.tun.Device()
}
-// GetStats returns the last handshake time, rx and tx bytes for the given peer
-func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) {
-return w.configurer.GetStats(peerKey)
+// GetStats returns the last handshake time, rx and tx bytes
+func (w *WGIface) GetStats() (map[string]configurer.WGStats, error) {
+return w.configurer.GetStats()
}
func (w *WGIface) waitUntilRemoved() error {

View File

@ -24,6 +24,8 @@
!define AUTOSTART_REG_KEY "Software\Microsoft\Windows\CurrentVersion\Run"
+!define NETBIRD_DATA_DIR "$COMMONPROGRAMDATA\Netbird"
Unicode True
######################################################################
@ -49,6 +51,10 @@ ShowInstDetails Show
######################################################################
+!include "MUI2.nsh"
+!include LogicLib.nsh
+!include "nsDialogs.nsh"
!define MUI_ICON "${ICON}"
!define MUI_UNICON "${ICON}"
!define MUI_WELCOMEFINISHPAGE_BITMAP "${BANNER}"
@ -58,9 +64,6 @@ ShowInstDetails Show
!define MUI_FINISHPAGE_RUN_FUNCTION "LaunchLink"
######################################################################
-!include "MUI2.nsh"
-!include LogicLib.nsh
!define MUI_ABORTWARNING
!define MUI_UNABORTWARNING
@ -70,13 +73,16 @@ ShowInstDetails Show
!insertmacro MUI_PAGE_DIRECTORY
+; Custom page for autostart checkbox
Page custom AutostartPage AutostartPageLeave
!insertmacro MUI_PAGE_INSTFILES
!insertmacro MUI_PAGE_FINISH
+!insertmacro MUI_UNPAGE_WELCOME
+UninstPage custom un.DeleteDataPage un.DeleteDataPageLeave
!insertmacro MUI_UNPAGE_CONFIRM
!insertmacro MUI_UNPAGE_INSTFILES
@ -89,6 +95,10 @@ Page custom AutostartPage AutostartPageLeave
Var AutostartCheckbox
Var AutostartEnabled
+; Variables for uninstall data deletion option
+Var DeleteDataCheckbox
+Var DeleteDataEnabled
######################################################################
; Function to create the autostart options page
@ -104,8 +114,8 @@ Function AutostartPage
${NSD_CreateCheckbox} 0 20u 100% 10u "Start ${APP_NAME} UI automatically when Windows starts"
Pop $AutostartCheckbox
-${NSD_Check} $AutostartCheckbox ; Default to checked
-StrCpy $AutostartEnabled "1" ; Default to enabled
+${NSD_Check} $AutostartCheckbox
+StrCpy $AutostartEnabled "1"
nsDialogs::Show
FunctionEnd
@ -115,6 +125,30 @@ Function AutostartPageLeave
${NSD_GetState} $AutostartCheckbox $AutostartEnabled
FunctionEnd
+; Function to create the uninstall data deletion page
+Function un.DeleteDataPage
+!insertmacro MUI_HEADER_TEXT "Uninstall Options" "Choose whether to delete ${APP_NAME} data."
+nsDialogs::Create 1018
+Pop $0
+${If} $0 == error
+Abort
+${EndIf}
+${NSD_CreateCheckbox} 0 20u 100% 10u "Delete all ${APP_NAME} configuration and state data (${NETBIRD_DATA_DIR})"
+Pop $DeleteDataCheckbox
+${NSD_Uncheck} $DeleteDataCheckbox
+StrCpy $DeleteDataEnabled "0"
+nsDialogs::Show
+FunctionEnd
+; Function to handle leaving the data deletion page
+Function un.DeleteDataPageLeave
+${NSD_GetState} $DeleteDataCheckbox $DeleteDataEnabled
+FunctionEnd
Function GetAppFromCommand
Exch $1
Push $2
@ -225,31 +259,58 @@ SectionEnd
Section Uninstall
${INSTALL_TYPE}
+DetailPrint "Stopping Netbird service..."
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service stop'
+DetailPrint "Uninstalling Netbird service..."
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
-# kill ui client
+DetailPrint "Terminating Netbird UI process..."
ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
; Remove autostart registry entry
+DetailPrint "Removing autostart registry entry if exists..."
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
+; Handle data deletion based on checkbox
+DetailPrint "Checking if user requested data deletion..."
+${If} $DeleteDataEnabled == "1"
+DetailPrint "User opted to delete Netbird data. Removing ${NETBIRD_DATA_DIR}..."
+ClearErrors
+RMDir /r "${NETBIRD_DATA_DIR}"
+IfErrors 0 +2 ; If no errors, jump over the message
+DetailPrint "Error deleting Netbird data directory. It might be in use or already removed."
+DetailPrint "Netbird data directory removal complete."
+${Else}
+DetailPrint "User did not opt to delete Netbird data."
+${EndIf}
# wait the service uninstall take unblock the executable
+DetailPrint "Waiting for service handle to be released..."
Sleep 3000
+DetailPrint "Deleting application files..."
Delete "$INSTDIR\${UI_APP_EXE}"
Delete "$INSTDIR\${MAIN_APP_EXE}"
Delete "$INSTDIR\wintun.dll"
Delete "$INSTDIR\opengl32.dll"
+DetailPrint "Removing application directory..."
RmDir /r "$INSTDIR"
+DetailPrint "Removing shortcuts..."
SetShellVarContext all
Delete "$DESKTOP\${APP_NAME}.lnk"
Delete "$SMPROGRAMS\${APP_NAME}.lnk"
+DetailPrint "Removing registry keys..."
DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}"
DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}"
+DeleteRegKey ${REG_ROOT} "${UI_REG_APP_PATH}"
+DetailPrint "Removing application directory from PATH..."
EnVar::SetHKLM
EnVar::DeleteValue "path" "$INSTDIR"
+DetailPrint "Uninstallation finished."
SectionEnd

View File

@ -76,12 +76,6 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
d.applyPeerACLs(networkMap)
-// If we got empty rules list but management did not set the networkMap.FirewallRulesIsEmpty flag,
-// then the mgmt server is older than the client, and we need to allow all traffic for routes
-isLegacy := len(networkMap.RoutesFirewallRules) == 0 && !networkMap.RoutesFirewallRulesIsEmpty
-if err := d.firewall.SetLegacyManagement(isLegacy); err != nil {
-log.Errorf("failed to set legacy management flag: %v", err)
-}
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil {
log.Errorf("Failed to apply route ACLs: %v", err)

View File

@ -64,13 +64,8 @@ func (t TokenInfo) GetTokenToUse() string {
// and if that also fails, the authentication process is deemed unsuccessful
//
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
-func NewOAuthFlow(ctx context.Context, config *internal.Config, isLinuxDesktopClient bool) (OAuthFlow, error) {
-if runtime.GOOS == "linux" && !isLinuxDesktopClient {
-return authenticateWithDeviceCodeFlow(ctx, config)
-}
-// On FreeBSD we currently do not support desktop environments and offer only Device Code Flow (#2384)
-if runtime.GOOS == "freebsd" {
+func NewOAuthFlow(ctx context.Context, config *internal.Config, isUnixDesktopClient bool) (OAuthFlow, error) {
+if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient {
return authenticateWithDeviceCodeFlow(ctx, config)
}

View File

@ -101,8 +101,13 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience), oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
} }
if !p.providerConfig.DisablePromptLogin { if !p.providerConfig.DisablePromptLogin {
if p.providerConfig.LoginFlag.IsPromptLogin() {
params = append(params, oauth2.SetAuthURLParam("prompt", "login")) params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
} }
if p.providerConfig.LoginFlag.IsMaxAge0Login() {
params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
}
}
authURL := p.oAuthConfig.AuthCodeURL(state, params...) authURL := p.oAuthConfig.AuthCodeURL(state, params...)
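For context, per OpenID Connect, prompt=login asks the provider to re-prompt for credentials even with an active session, while max_age=0 forces re-authentication by declaring any existing session too old. A rough sketch of the URLs these two options produce (client ID and endpoint are invented):

cfg := &oauth2.Config{
	ClientID: "example-client", // hypothetical
	Endpoint: oauth2.Endpoint{AuthURL: "https://idp.example.com/authorize"}, // hypothetical
}
// A prompt-style login flag leads to ...&prompt=login
promptURL := cfg.AuthCodeURL("state", oauth2.SetAuthURLParam("prompt", "login"))
// A max-age-0 login flag leads to ...&max_age=0
maxAgeURL := cfg.AuthCodeURL("state", oauth2.SetAuthURLParam("max_age", "0"))
fmt.Println(promptURL, maxAgeURL)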

View File

@ -7,15 +7,36 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
mgm "github.com/netbirdio/netbird/management/client/common"
) )
func TestPromptLogin(t *testing.T) { func TestPromptLogin(t *testing.T) {
const (
promptLogin = "prompt=login"
maxAge0 = "max_age=0"
)
tt := []struct { tt := []struct {
name string name string
prompt bool loginFlag mgm.LoginFlag
disablePromptLogin bool
expect string
}{ }{
{"PromptLogin", true}, {
{"NoPromptLogin", false}, name: "Prompt login",
loginFlag: mgm.LoginFlagPrompt,
expect: promptLogin,
},
{
name: "Max age 0 login",
loginFlag: mgm.LoginFlagMaxAge0,
expect: maxAge0,
},
{
name: "Disable prompt login",
loginFlag: mgm.LoginFlagPrompt,
disablePromptLogin: true,
},
} }
for _, tc := range tt { for _, tc := range tt {
@ -28,7 +49,7 @@ func TestPromptLogin(t *testing.T) {
AuthorizationEndpoint: "https://test-auth-endpoint.com/authorize",
RedirectURLs: []string{"http://127.0.0.1:33992/"},
UseIDToken: true,
-DisablePromptLogin: !tc.prompt,
+LoginFlag: tc.loginFlag,
}
pkce, err := NewPKCEAuthorizationFlow(config)
if err != nil {
@ -38,11 +59,12 @@ func TestPromptLogin(t *testing.T) {
if err != nil {
t.Fatalf("Failed to request auth info: %v", err)
}
-pattern := "prompt=login"
-if tc.prompt {
-require.Contains(t, authInfo.VerificationURIComplete, pattern)
+if !tc.disablePromptLogin {
+require.Contains(t, authInfo.VerificationURIComplete, tc.expect)
} else {
-require.NotContains(t, authInfo.VerificationURIComplete, pattern)
+require.Contains(t, authInfo.VerificationURIComplete, promptLogin)
+require.NotContains(t, authInfo.VerificationURIComplete, maxAge0)
}
})
}

View File

@ -74,6 +74,8 @@ type ConfigInput struct {
DisableNotifications *bool
DNSLabels domain.List
+LazyConnectionEnabled *bool
}
// Config Configuration type
@ -138,6 +140,8 @@ type Config struct {
ClientCertKeyPath string
ClientCertKeyPair *tls.Certificate `json:"-"`
+LazyConnectionEnabled bool
}
// ReadConfig read config file and return with Config. If it is not exists create a new with default values // ReadConfig read config file and return with Config. If it is not exists create a new with default values
@ -524,6 +528,12 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
+if input.LazyConnectionEnabled != nil && *input.LazyConnectionEnabled != config.LazyConnectionEnabled {
+log.Infof("switching lazy connection to %t", *input.LazyConnectionEnabled)
+config.LazyConnectionEnabled = *input.LazyConnectionEnabled
+updated = true
+}
return updated, nil
}

client/internal/conn_mgr.go (new file, +303 lines)
View File

@ -0,0 +1,303 @@
package internal
import (
"context"
"os"
"strconv"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/client/internal/lazyconn/manager"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
"github.com/netbirdio/netbird/client/internal/peerstore"
)
// ConnMgr coordinates both lazy connections (established on-demand) and permanent peer connections.
//
// The connection manager is responsible for:
// - Managing lazy connections via the lazyConnManager
// - Maintaining a list of excluded peers that should always have permanent connections
// - Handling connection establishment based on peer signaling
//
// The implementation is not thread-safe; it is protected by engine.syncMsgMux.
type ConnMgr struct {
peerStore *peerstore.Store
statusRecorder *peer.Status
iface lazyconn.WGIface
dispatcher *dispatcher.ConnectionDispatcher
enabledLocally bool
lazyConnMgr *manager.Manager
wg sync.WaitGroup
ctx context.Context
ctxCancel context.CancelFunc
}
func NewConnMgr(engineConfig *EngineConfig, statusRecorder *peer.Status, peerStore *peerstore.Store, iface lazyconn.WGIface, dispatcher *dispatcher.ConnectionDispatcher) *ConnMgr {
e := &ConnMgr{
peerStore: peerStore,
statusRecorder: statusRecorder,
iface: iface,
dispatcher: dispatcher,
}
if engineConfig.LazyConnectionEnabled || lazyconn.IsLazyConnEnabledByEnv() {
e.enabledLocally = true
}
return e
}
// Start initializes the connection manager and starts the lazy connection manager if enabled by env var or cmd line option.
func (e *ConnMgr) Start(ctx context.Context) {
if e.lazyConnMgr != nil {
log.Errorf("lazy connection manager is already started")
return
}
if !e.enabledLocally {
log.Infof("lazy connection manager is disabled")
return
}
e.initLazyManager(ctx)
e.statusRecorder.UpdateLazyConnection(true)
}
// UpdatedRemoteFeatureFlag is called when the remote feature flag is updated.
// If enabled, it initializes and starts the lazy connection manager; there is no need to call Start() again.
// If disabled, it closes the lazy connection manager and opens connections to all peers.
func (e *ConnMgr) UpdatedRemoteFeatureFlag(ctx context.Context, enabled bool) error {
// do not disable lazy connection manager if it was enabled by env var
if e.enabledLocally {
return nil
}
if enabled {
// if the lazy connection manager is already started, do not start it again
if e.lazyConnMgr != nil {
return nil
}
log.Infof("lazy connection manager is enabled by management feature flag")
e.initLazyManager(ctx)
e.statusRecorder.UpdateLazyConnection(true)
return e.addPeersToLazyConnManager(ctx)
} else {
if e.lazyConnMgr == nil {
return nil
}
log.Infof("lazy connection manager is disabled by management feature flag")
e.closeManager(ctx)
e.statusRecorder.UpdateLazyConnection(false)
return nil
}
}
// SetExcludeList sets the list of peer IDs that should always have permanent connections.
func (e *ConnMgr) SetExcludeList(peerIDs []string) {
if e.lazyConnMgr == nil {
return
}
excludedPeers := make([]lazyconn.PeerConfig, 0, len(peerIDs))
for _, peerID := range peerIDs {
var peerConn *peer.Conn
var exists bool
if peerConn, exists = e.peerStore.PeerConn(peerID); !exists {
log.Warnf("failed to find peer conn for peerID: %s", peerID)
continue
}
lazyPeerCfg := lazyconn.PeerConfig{
PublicKey: peerID,
AllowedIPs: peerConn.WgConfig().AllowedIps,
PeerConnID: peerConn.ConnID(),
Log: peerConn.Log,
}
excludedPeers = append(excludedPeers, lazyPeerCfg)
}
added := e.lazyConnMgr.ExcludePeer(e.ctx, excludedPeers)
for _, peerID := range added {
var peerConn *peer.Conn
var exists bool
if peerConn, exists = e.peerStore.PeerConn(peerID); !exists {
// if the peer does not exist in the store, the engine will call AddPeerConn in the next step
continue
}
peerConn.Log.Infof("peer has been added to lazy connection exclude list, opening permanent connection")
if err := peerConn.Open(e.ctx); err != nil {
peerConn.Log.Errorf("failed to open connection: %v", err)
}
}
}
func (e *ConnMgr) AddPeerConn(ctx context.Context, peerKey string, conn *peer.Conn) (exists bool) {
if success := e.peerStore.AddPeerConn(peerKey, conn); !success {
return true
}
if !e.isStartedWithLazyMgr() {
if err := conn.Open(ctx); err != nil {
conn.Log.Errorf("failed to open connection: %v", err)
}
return
}
if !lazyconn.IsSupported(conn.AgentVersionString()) {
conn.Log.Warnf("peer does not support lazy connection (%s), open permanent connection", conn.AgentVersionString())
if err := conn.Open(ctx); err != nil {
conn.Log.Errorf("failed to open connection: %v", err)
}
return
}
lazyPeerCfg := lazyconn.PeerConfig{
PublicKey: peerKey,
AllowedIPs: conn.WgConfig().AllowedIps,
PeerConnID: conn.ConnID(),
Log: conn.Log,
}
excluded, err := e.lazyConnMgr.AddPeer(lazyPeerCfg)
if err != nil {
conn.Log.Errorf("failed to add peer to lazyconn manager: %v", err)
if err := conn.Open(ctx); err != nil {
conn.Log.Errorf("failed to open connection: %v", err)
}
return
}
if excluded {
conn.Log.Infof("peer is on lazy conn manager exclude list, opening connection")
if err := conn.Open(ctx); err != nil {
conn.Log.Errorf("failed to open connection: %v", err)
}
return
}
conn.Log.Infof("peer added to lazy conn manager")
return
}
func (e *ConnMgr) RemovePeerConn(peerKey string) {
conn, ok := e.peerStore.Remove(peerKey)
if !ok {
return
}
defer conn.Close()
if !e.isStartedWithLazyMgr() {
return
}
e.lazyConnMgr.RemovePeer(peerKey)
conn.Log.Infof("removed peer from lazy conn manager")
}
func (e *ConnMgr) OnSignalMsg(ctx context.Context, peerKey string) (*peer.Conn, bool) {
conn, ok := e.peerStore.PeerConn(peerKey)
if !ok {
return nil, false
}
if !e.isStartedWithLazyMgr() {
return conn, true
}
if found := e.lazyConnMgr.ActivatePeer(ctx, peerKey); found {
conn.Log.Infof("activated peer from inactive state")
if err := conn.Open(e.ctx); err != nil {
conn.Log.Errorf("failed to open connection: %v", err)
}
}
return conn, true
}
func (e *ConnMgr) Close() {
if !e.isStartedWithLazyMgr() {
return
}
e.ctxCancel()
e.wg.Wait()
e.lazyConnMgr = nil
}
func (e *ConnMgr) initLazyManager(parentCtx context.Context) {
cfg := manager.Config{
InactivityThreshold: inactivityThresholdEnv(),
}
e.lazyConnMgr = manager.NewManager(cfg, e.peerStore, e.iface, e.dispatcher)
ctx, cancel := context.WithCancel(parentCtx)
e.ctx = ctx
e.ctxCancel = cancel
e.wg.Add(1)
go func() {
defer e.wg.Done()
e.lazyConnMgr.Start(ctx)
}()
}
func (e *ConnMgr) addPeersToLazyConnManager(ctx context.Context) error {
peers := e.peerStore.PeersPubKey()
lazyPeerCfgs := make([]lazyconn.PeerConfig, 0, len(peers))
for _, peerID := range peers {
var peerConn *peer.Conn
var exists bool
if peerConn, exists = e.peerStore.PeerConn(peerID); !exists {
log.Warnf("failed to find peer conn for peerID: %s", peerID)
continue
}
lazyPeerCfg := lazyconn.PeerConfig{
PublicKey: peerID,
AllowedIPs: peerConn.WgConfig().AllowedIps,
PeerConnID: peerConn.ConnID(),
Log: peerConn.Log,
}
lazyPeerCfgs = append(lazyPeerCfgs, lazyPeerCfg)
}
return e.lazyConnMgr.AddActivePeers(ctx, lazyPeerCfgs)
}
func (e *ConnMgr) closeManager(ctx context.Context) {
if e.lazyConnMgr == nil {
return
}
e.ctxCancel()
e.wg.Wait()
e.lazyConnMgr = nil
for _, peerID := range e.peerStore.PeersPubKey() {
e.peerStore.PeerConnOpen(ctx, peerID)
}
}
func (e *ConnMgr) isStartedWithLazyMgr() bool {
return e.lazyConnMgr != nil && e.ctxCancel != nil
}
func inactivityThresholdEnv() *time.Duration {
envValue := os.Getenv(lazyconn.EnvInactivityThreshold)
if envValue == "" {
return nil
}
parsedMinutes, err := strconv.Atoi(envValue)
if err != nil || parsedMinutes <= 0 {
return nil
}
d := time.Duration(parsedMinutes) * time.Minute
return &d
}
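To summarize the intended call order, a hypothetical wiring sketch; peerKey and conn would come from the engine's network-map handling, which is not part of this diff:

func wireConnMgr(ctx context.Context, e *ConnMgr, peerKey string, conn *peer.Conn) {
	e.Start(ctx) // no-op unless lazy connections are enabled locally
	// Lazy-capable peers are parked in the lazy manager; excluded or
	// legacy peers get a permanent connection opened immediately.
	e.AddPeerConn(ctx, peerKey, conn)
	// An incoming signal message activates a parked peer on demand.
	if c, ok := e.OnSignalMsg(ctx, peerKey); ok {
		c.Log.Infof("peer connection is available")
	}
	e.Close()
}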

View File

@ -441,6 +441,7 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
DisableFirewall: config.DisableFirewall,
BlockLANAccess: config.BlockLANAccess,
+LazyConnectionEnabled: config.LazyConnectionEnabled,
}
if config.PreSharedKey != "" {
@ -481,7 +482,7 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourP
return signalClient, nil
}
-// loginToManagement creates Management Services client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
+// loginToManagement creates Management ServiceDependencies client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *Config) (*mgmProto.LoginResponse, error) {
serverPublicKey, err := client.GetServerPublicKey()

View File

@ -376,6 +376,7 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
configContent.WriteString(fmt.Sprintf("DisableFirewall: %v\n", g.internalConfig.DisableFirewall)) configContent.WriteString(fmt.Sprintf("DisableFirewall: %v\n", g.internalConfig.DisableFirewall))
configContent.WriteString(fmt.Sprintf("BlockLANAccess: %v\n", g.internalConfig.BlockLANAccess)) configContent.WriteString(fmt.Sprintf("BlockLANAccess: %v\n", g.internalConfig.BlockLANAccess))
configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled))
} }
func (g *BundleGenerator) addProf() (err error) { func (g *BundleGenerator) addProf() (err error) {

View File

@ -1,7 +1,6 @@
package dns_test
import (
-"net"
"testing"
"github.com/miekg/dns"
@ -9,6 +8,7 @@ import (
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
nbdns "github.com/netbirdio/netbird/client/internal/dns" nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dns/test"
) )
// TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order // TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order
@ -30,7 +30,7 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
r.SetQuestion("example.com.", dns.TypeA) r.SetQuestion("example.com.", dns.TypeA)
// Create test writer // Create test writer
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// Setup expectations - only highest priority handler should be called // Setup expectations - only highest priority handler should be called
dnsRouteHandler.On("ServeDNS", mock.Anything, r).Once() dnsRouteHandler.On("ServeDNS", mock.Anything, r).Once()
@ -142,7 +142,7 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
r := new(dns.Msg)
r.SetQuestion(tt.queryDomain, dns.TypeA)
-w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
+w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
chain.ServeDNS(w, r)
@ -259,7 +259,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
// Create and execute request
r := new(dns.Msg)
r.SetQuestion(tt.queryDomain, dns.TypeA)
-w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
+w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
chain.ServeDNS(w, r)
// Verify expectations
@ -316,7 +316,7 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
}).Once()
// Execute
-w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
+w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
chain.ServeDNS(w, r)
// Verify all handlers were called in order
@ -325,20 +325,6 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
handler3.AssertExpectations(t)
}
-// mockResponseWriter implements dns.ResponseWriter for testing
-type mockResponseWriter struct {
-mock.Mock
-}
-func (m *mockResponseWriter) LocalAddr() net.Addr { return nil }
-func (m *mockResponseWriter) RemoteAddr() net.Addr { return nil }
-func (m *mockResponseWriter) WriteMsg(*dns.Msg) error { return nil }
-func (m *mockResponseWriter) Write([]byte) (int, error) { return 0, nil }
-func (m *mockResponseWriter) Close() error { return nil }
-func (m *mockResponseWriter) TsigStatus() error { return nil }
-func (m *mockResponseWriter) TsigTimersOnly(bool) {}
-func (m *mockResponseWriter) Hijack() {}
func TestHandlerChain_PriorityDeregistration(t *testing.T) {
tests := []struct {
name string
@ -425,7 +411,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
// Create test request
r := new(dns.Msg)
r.SetQuestion(tt.query, dns.TypeA)
-w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
+w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// Setup expectations
for priority, handler := range handlers {
@ -471,7 +457,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain)
// Test 1: Initial state
-w1 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
+w1 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// Highest priority handler (routeHandler) should be called
routeHandler.On("ServeDNS", mock.Anything, r).Return().Once()
matchHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure others are not expected yet
@ -490,7 +476,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
// Test 2: Remove highest priority handler
chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute)
-w2 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
+w2 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// Now middle priority handler (matchHandler) should be called
matchHandler.On("ServeDNS", mock.Anything, r).Return().Once()
defaultHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure default is not expected yet
@ -506,7 +492,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
// Test 3: Remove middle priority handler
chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain)
-w3 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
+w3 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// Now lowest priority handler (defaultHandler) should be called
defaultHandler.On("ServeDNS", mock.Anything, r).Return().Once()
@ -519,7 +505,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
// Test 4: Remove last handler
chain.RemoveHandler(testDomain, nbdns.PriorityDefault)
-w4 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
+w4 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
chain.ServeDNS(w4, r) // Call ServeDNS on the now empty chain for this domain
for _, m := range mocks {
@ -675,7 +661,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
// Execute request
r := new(dns.Msg)
r.SetQuestion(tt.query, dns.TypeA)
-chain.ServeDNS(&mockResponseWriter{}, r)
+chain.ServeDNS(&test.MockResponseWriter{}, r)
// Verify each handler was called exactly as expected
for _, h := range tt.addHandlers {
@ -819,7 +805,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
r := new(dns.Msg)
r.SetQuestion(tt.query, dns.TypeA)
-w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
+w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// Setup handler expectations
for pattern, handler := range handlers {
@ -969,7 +955,7 @@ func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
handler := &nbdns.MockHandler{}
r := new(dns.Msg)
r.SetQuestion(tt.queryPattern, dns.TypeA)
-w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
+w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// First verify no handler is called before adding any
chain.ServeDNS(w, r)

View File

@ -1,130 +0,0 @@
package dns
import (
"fmt"
"strings"
"sync"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
)
type registrationMap map[string]struct{}
type localResolver struct {
registeredMap registrationMap
records sync.Map // key: string (domain_class_type), value: []dns.RR
}
func (d *localResolver) MatchSubdomains() bool {
return true
}
func (d *localResolver) stop() {
}
// String returns a string representation of the local resolver
func (d *localResolver) String() string {
return fmt.Sprintf("local resolver [%d records]", len(d.registeredMap))
}
// ID returns the unique handler ID
func (d *localResolver) id() handlerID {
return "local-resolver"
}
// ServeDNS handles a DNS request
func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) > 0 {
log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
}
replyMessage := &dns.Msg{}
replyMessage.SetReply(r)
replyMessage.RecursionAvailable = true
// lookup all records matching the question
records := d.lookupRecords(r)
if len(records) > 0 {
replyMessage.Rcode = dns.RcodeSuccess
replyMessage.Answer = append(replyMessage.Answer, records...)
} else {
replyMessage.Rcode = dns.RcodeNameError
}
err := w.WriteMsg(replyMessage)
if err != nil {
log.Debugf("got an error while writing the local resolver response, error: %v", err)
}
}
// lookupRecords fetches *all* DNS records matching the first question in r.
func (d *localResolver) lookupRecords(r *dns.Msg) []dns.RR {
if len(r.Question) == 0 {
return nil
}
question := r.Question[0]
question.Name = strings.ToLower(question.Name)
key := buildRecordKey(question.Name, question.Qclass, question.Qtype)
value, found := d.records.Load(key)
if !found {
// alternatively check if we have a cname
if question.Qtype != dns.TypeCNAME {
r.Question[0].Qtype = dns.TypeCNAME
return d.lookupRecords(r)
}
return nil
}
records, ok := value.([]dns.RR)
if !ok {
log.Errorf("failed to cast records to []dns.RR, records: %v", value)
return nil
}
// if there's more than one record, rotate them (round-robin)
if len(records) > 1 {
first := records[0]
records = append(records[1:], first)
d.records.Store(key, records)
}
return records
}
// registerRecord stores a new record by appending it to any existing list
func (d *localResolver) registerRecord(record nbdns.SimpleRecord) (string, error) {
rr, err := dns.NewRR(record.String())
if err != nil {
return "", fmt.Errorf("register record: %w", err)
}
rr.Header().Rdlength = record.Len()
header := rr.Header()
key := buildRecordKey(header.Name, header.Class, header.Rrtype)
// load any existing slice of records, then append
existing, _ := d.records.LoadOrStore(key, []dns.RR{})
records := existing.([]dns.RR)
records = append(records, rr)
// store updated slice
d.records.Store(key, records)
return key, nil
}
// deleteRecord removes *all* records under the recordKey.
func (d *localResolver) deleteRecord(recordKey string) {
d.records.Delete(dns.Fqdn(recordKey))
}
// buildRecordKey consistently generates a key: name_class_type
func buildRecordKey(name string, class, qType uint16) string {
return fmt.Sprintf("%s_%d_%d", dns.Fqdn(name), class, qType)
}
func (d *localResolver) probeAvailability() {}

View File

@ -0,0 +1,149 @@
package local
import (
"fmt"
"slices"
"strings"
"sync"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/internal/dns/types"
nbdns "github.com/netbirdio/netbird/dns"
)
type Resolver struct {
mu sync.RWMutex
records map[dns.Question][]dns.RR
}
func NewResolver() *Resolver {
return &Resolver{
records: make(map[dns.Question][]dns.RR),
}
}
func (d *Resolver) MatchSubdomains() bool {
return true
}
// String returns a string representation of the local resolver
func (d *Resolver) String() string {
return fmt.Sprintf("local resolver [%d records]", len(d.records))
}
func (d *Resolver) Stop() {}
// ID returns the unique handler ID
func (d *Resolver) ID() types.HandlerID {
return "local-resolver"
}
func (d *Resolver) ProbeAvailability() {}
// ServeDNS handles a DNS request
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 {
log.Debugf("received local resolver request with no question")
return
}
question := r.Question[0]
question.Name = strings.ToLower(dns.Fqdn(question.Name))
log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, question.Qtype, question.Qclass)
replyMessage := &dns.Msg{}
replyMessage.SetReply(r)
replyMessage.RecursionAvailable = true
// lookup all records matching the question
records := d.lookupRecords(question)
if len(records) > 0 {
replyMessage.Rcode = dns.RcodeSuccess
replyMessage.Answer = append(replyMessage.Answer, records...)
} else {
// TODO: return success if we have a different record type for the same name, relevant for search domains
replyMessage.Rcode = dns.RcodeNameError
}
if err := w.WriteMsg(replyMessage); err != nil {
log.Warnf("failed to write the local resolver response: %v", err)
}
}
// lookupRecords fetches *all* DNS records matching the first question in r.
func (d *Resolver) lookupRecords(question dns.Question) []dns.RR {
d.mu.RLock()
records, found := d.records[question]
if !found {
d.mu.RUnlock()
// alternatively check if we have a cname
if question.Qtype != dns.TypeCNAME {
question.Qtype = dns.TypeCNAME
return d.lookupRecords(question)
}
return nil
}
recordsCopy := slices.Clone(records)
d.mu.RUnlock()
// if there's more than one record, rotate them (round-robin)
if len(recordsCopy) > 1 {
d.mu.Lock()
records = d.records[question]
if len(records) > 1 {
first := records[0]
records = append(records[1:], first)
d.records[question] = records
}
d.mu.Unlock()
}
return recordsCopy
}
func (d *Resolver) Update(update []nbdns.SimpleRecord) {
d.mu.Lock()
defer d.mu.Unlock()
maps.Clear(d.records)
for _, rec := range update {
if err := d.registerRecord(rec); err != nil {
log.Warnf("failed to register the record (%s): %v", rec, err)
continue
}
}
}
// RegisterRecord stores a new record by appending it to any existing list
func (d *Resolver) RegisterRecord(record nbdns.SimpleRecord) error {
d.mu.Lock()
defer d.mu.Unlock()
return d.registerRecord(record)
}
// registerRecord performs the registration with the lock already held
func (d *Resolver) registerRecord(record nbdns.SimpleRecord) error {
rr, err := dns.NewRR(record.String())
if err != nil {
return fmt.Errorf("register record: %w", err)
}
rr.Header().Rdlength = record.Len()
header := rr.Header()
q := dns.Question{
Name: strings.ToLower(dns.Fqdn(header.Name)),
Qtype: header.Rrtype,
Qclass: header.Class,
}
d.records[q] = append(d.records[q], rr)
return nil
}
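A minimal sketch (not part of the commit) of driving the new resolver: Update keys each record by its dns.Question, and repeated ServeDNS calls rotate multi-record answers round-robin. The import of the internal local package assumes the code lives inside the netbird module, and stubWriter merely mirrors the MockResponseWriter test helper added elsewhere in this commit:

package main

import (
	"fmt"
	"net"

	"github.com/miekg/dns"

	"github.com/netbirdio/netbird/client/internal/dns/local"
	nbdns "github.com/netbirdio/netbird/dns"
)

// stubWriter is a minimal dns.ResponseWriter that captures the reply.
type stubWriter struct{ msg *dns.Msg }

func (w *stubWriter) WriteMsg(m *dns.Msg) error { w.msg = m; return nil }
func (w *stubWriter) LocalAddr() net.Addr       { return nil }
func (w *stubWriter) RemoteAddr() net.Addr      { return nil }
func (w *stubWriter) Write([]byte) (int, error) { return 0, nil }
func (w *stubWriter) Close() error              { return nil }
func (w *stubWriter) TsigStatus() error         { return nil }
func (w *stubWriter) TsigTimersOnly(bool)       {}
func (w *stubWriter) Hijack()                   {}

func main() {
	resolver := local.NewResolver()
	resolver.Update([]nbdns.SimpleRecord{
		{Name: "app.netbird.cloud.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
		{Name: "app.netbird.cloud.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"},
	})

	// Each lookup returns both A records, rotated round-robin, so clients
	// that only take the first answer get spread across the two IPs.
	for i := 0; i < 2; i++ {
		w := &stubWriter{}
		resolver.ServeDNS(w, new(dns.Msg).SetQuestion("app.netbird.cloud.", dns.TypeA))
		fmt.Println(w.msg.Answer[0])
	}
}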

View File

@ -0,0 +1,472 @@
package local
import (
"strings"
"testing"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/dns/test"
nbdns "github.com/netbirdio/netbird/dns"
)
func TestLocalResolver_ServeDNS(t *testing.T) {
recordA := nbdns.SimpleRecord{
Name: "peera.netbird.cloud.",
Type: 1,
Class: nbdns.DefaultClass,
TTL: 300,
RData: "1.2.3.4",
}
recordCNAME := nbdns.SimpleRecord{
Name: "peerb.netbird.cloud.",
Type: 5,
Class: nbdns.DefaultClass,
TTL: 300,
RData: "www.netbird.io",
}
testCases := []struct {
name string
inputRecord nbdns.SimpleRecord
inputMSG *dns.Msg
responseShouldBeNil bool
}{
{
name: "Should Resolve A Record",
inputRecord: recordA,
inputMSG: new(dns.Msg).SetQuestion(recordA.Name, dns.TypeA),
},
{
name: "Should Resolve CNAME Record",
inputRecord: recordCNAME,
inputMSG: new(dns.Msg).SetQuestion(recordCNAME.Name, dns.TypeCNAME),
},
{
name: "Should Not Write When Not Found A Record",
inputRecord: recordA,
inputMSG: new(dns.Msg).SetQuestion("not.found.com", dns.TypeA),
responseShouldBeNil: true,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
resolver := NewResolver()
_ = resolver.RegisterRecord(testCase.inputRecord)
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
resolver.ServeDNS(responseWriter, testCase.inputMSG)
if responseMSG == nil || len(responseMSG.Answer) == 0 {
if testCase.responseShouldBeNil {
return
}
t.Fatalf("should write a response message")
}
answerString := responseMSG.Answer[0].String()
if !strings.Contains(answerString, testCase.inputRecord.Name) {
t.Fatalf("answer doesn't contain the same domain name: \nWant: %s\nGot:%s", testCase.name, answerString)
}
if !strings.Contains(answerString, dns.Type(testCase.inputRecord.Type).String()) {
t.Fatalf("answer doesn't contain the correct type: \nWant: %s\nGot:%s", dns.Type(testCase.inputRecord.Type).String(), answerString)
}
if !strings.Contains(answerString, testCase.inputRecord.RData) {
t.Fatalf("answer doesn't contain the same address: \nWant: %s\nGot:%s", testCase.inputRecord.RData, answerString)
}
})
}
}
// TestLocalResolver_Update_StaleRecord verifies that updating
// a record correctly replaces the old one, preventing stale entries.
func TestLocalResolver_Update_StaleRecord(t *testing.T) {
recordName := "host.example.com."
recordType := dns.TypeA
recordClass := dns.ClassINET
record1 := nbdns.SimpleRecord{
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "1.1.1.1",
}
record2 := nbdns.SimpleRecord{
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "2.2.2.2",
}
recordKey := dns.Question{Name: recordName, Qtype: recordType, Qclass: uint16(recordClass)}
resolver := NewResolver()
update1 := []nbdns.SimpleRecord{record1}
update2 := []nbdns.SimpleRecord{record2}
// Apply first update
resolver.Update(update1)
// Verify first update
resolver.mu.RLock()
rrSlice1, found1 := resolver.records[recordKey]
resolver.mu.RUnlock()
require.True(t, found1, "Record key %s not found after first update", recordKey)
require.Len(t, rrSlice1, 1, "Should have exactly 1 record after first update")
assert.Contains(t, rrSlice1[0].String(), record1.RData, "Record after first update should be %s", record1.RData)
// Apply second update
resolver.Update(update2)
// Verify second update
resolver.mu.RLock()
rrSlice2, found2 := resolver.records[recordKey]
resolver.mu.RUnlock()
require.True(t, found2, "Record key %s not found after second update", recordKey)
require.Len(t, rrSlice2, 1, "Should have exactly 1 record after update overwriting the key")
assert.Contains(t, rrSlice2[0].String(), record2.RData, "The single record should be the updated one (%s)", record2.RData)
assert.NotContains(t, rrSlice2[0].String(), record1.RData, "The stale record (%s) should not be present", record1.RData)
}
// TestLocalResolver_MultipleRecords_SameQuestion verifies that multiple records
// with the same question are stored properly
func TestLocalResolver_MultipleRecords_SameQuestion(t *testing.T) {
resolver := NewResolver()
recordName := "multi.example.com."
recordType := dns.TypeA
// Create two records with the same name and type but different IPs
record1 := nbdns.SimpleRecord{
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1",
}
record2 := nbdns.SimpleRecord{
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2",
}
update := []nbdns.SimpleRecord{record1, record2}
// Apply update with both records
resolver.Update(update)
// Create question that matches both records
question := dns.Question{
Name: recordName,
Qtype: recordType,
Qclass: dns.ClassINET,
}
// Verify both records are stored
resolver.mu.RLock()
records, found := resolver.records[question]
resolver.mu.RUnlock()
require.True(t, found, "Records for question %v not found", question)
require.Len(t, records, 2, "Should have exactly 2 records for the same question")
// Verify both record data values are present
recordStrings := []string{records[0].String(), records[1].String()}
assert.Contains(t, recordStrings[0]+recordStrings[1], record1.RData, "First record data should be present")
assert.Contains(t, recordStrings[0]+recordStrings[1], record2.RData, "Second record data should be present")
}
// TestLocalResolver_RecordRotation verifies that records are rotated in a round-robin fashion
func TestLocalResolver_RecordRotation(t *testing.T) {
resolver := NewResolver()
recordName := "rotation.example.com."
recordType := dns.TypeA
// Create three records with the same name and type but different IPs
record1 := nbdns.SimpleRecord{
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.1",
}
record2 := nbdns.SimpleRecord{
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.2",
}
record3 := nbdns.SimpleRecord{
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.3",
}
update := []nbdns.SimpleRecord{record1, record2, record3}
// Apply update with all three records
resolver.Update(update)
msg := new(dns.Msg).SetQuestion(recordName, recordType)
// First lookup - should return the records in original order
var responses [3]*dns.Msg
// Perform three lookups to verify rotation
for i := 0; i < 3; i++ {
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responses[i] = m
return nil
},
}
resolver.ServeDNS(responseWriter, msg)
}
// Verify all three responses contain answers
for i, resp := range responses {
require.NotNil(t, resp, "Response %d should not be nil", i)
require.Len(t, resp.Answer, 3, "Response %d should have 3 answers", i)
}
// Verify the first record in each response is different due to rotation
firstRecordIPs := []string{
responses[0].Answer[0].String(),
responses[1].Answer[0].String(),
responses[2].Answer[0].String(),
}
// Each record should be different (rotated)
assert.NotEqual(t, firstRecordIPs[0], firstRecordIPs[1], "First lookup should differ from second lookup due to rotation")
assert.NotEqual(t, firstRecordIPs[1], firstRecordIPs[2], "Second lookup should differ from third lookup due to rotation")
assert.NotEqual(t, firstRecordIPs[0], firstRecordIPs[2], "First lookup should differ from third lookup due to rotation")
// After three rotations, we should have cycled through all records
assert.Contains(t, firstRecordIPs[0]+firstRecordIPs[1]+firstRecordIPs[2], record1.RData)
assert.Contains(t, firstRecordIPs[0]+firstRecordIPs[1]+firstRecordIPs[2], record2.RData)
assert.Contains(t, firstRecordIPs[0]+firstRecordIPs[1]+firstRecordIPs[2], record3.RData)
}
// TestLocalResolver_CaseInsensitiveMatching verifies that DNS record lookups are case-insensitive
func TestLocalResolver_CaseInsensitiveMatching(t *testing.T) {
resolver := NewResolver()
// Create record with lowercase name
lowerCaseRecord := nbdns.SimpleRecord{
Name: "lower.example.com.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "10.10.10.10",
}
// Create record with mixed case name
mixedCaseRecord := nbdns.SimpleRecord{
Name: "MiXeD.ExAmPlE.CoM.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "20.20.20.20",
}
// Update resolver with the records
resolver.Update([]nbdns.SimpleRecord{lowerCaseRecord, mixedCaseRecord})
testCases := []struct {
name string
queryName string
expectedRData string
shouldResolve bool
}{
{
name: "Query lowercase with lowercase record",
queryName: "lower.example.com.",
expectedRData: "10.10.10.10",
shouldResolve: true,
},
{
name: "Query uppercase with lowercase record",
queryName: "LOWER.EXAMPLE.COM.",
expectedRData: "10.10.10.10",
shouldResolve: true,
},
{
name: "Query mixed case with lowercase record",
queryName: "LoWeR.eXaMpLe.CoM.",
expectedRData: "10.10.10.10",
shouldResolve: true,
},
{
name: "Query lowercase with mixed case record",
queryName: "mixed.example.com.",
expectedRData: "20.20.20.20",
shouldResolve: true,
},
{
name: "Query uppercase with mixed case record",
queryName: "MIXED.EXAMPLE.COM.",
expectedRData: "20.20.20.20",
shouldResolve: true,
},
{
name: "Query with different casing pattern",
queryName: "mIxEd.ExaMpLe.cOm.",
expectedRData: "20.20.20.20",
shouldResolve: true,
},
{
name: "Query non-existent domain",
queryName: "nonexistent.example.com.",
shouldResolve: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var responseMSG *dns.Msg
// Create DNS query with the test case name
msg := new(dns.Msg).SetQuestion(tc.queryName, dns.TypeA)
// Create mock response writer to capture the response
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
// Perform DNS query
resolver.ServeDNS(responseWriter, msg)
// Check if we expect a successful resolution
if !tc.shouldResolve {
if responseMSG == nil || len(responseMSG.Answer) == 0 {
// Expected no answer, test passes
return
}
t.Fatalf("Expected no resolution for %s, but got answer: %v", tc.queryName, responseMSG.Answer)
}
// Verify we got a response
require.NotNil(t, responseMSG, "Should have received a response message")
require.Greater(t, len(responseMSG.Answer), 0, "Response should contain at least one answer")
// Verify the response contains the expected data
answerString := responseMSG.Answer[0].String()
assert.Contains(t, answerString, tc.expectedRData,
"Answer should contain the expected IP address %s, got: %s",
tc.expectedRData, answerString)
})
}
}
// TestLocalResolver_CNAMEFallback verifies that the resolver correctly falls back
// to checking for CNAME records when the requested record type isn't found
func TestLocalResolver_CNAMEFallback(t *testing.T) {
resolver := NewResolver()
// Create a CNAME record (but no A record for this name)
cnameRecord := nbdns.SimpleRecord{
Name: "alias.example.com.",
Type: int(dns.TypeCNAME),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "target.example.com.",
}
// Create an A record for the CNAME target
targetRecord := nbdns.SimpleRecord{
Name: "target.example.com.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "192.168.100.100",
}
// Update resolver with both records
resolver.Update([]nbdns.SimpleRecord{cnameRecord, targetRecord})
testCases := []struct {
name string
queryName string
queryType uint16
expectedType string
expectedRData string
shouldResolve bool
}{
{
name: "Directly query CNAME record",
queryName: "alias.example.com.",
queryType: dns.TypeCNAME,
expectedType: "CNAME",
expectedRData: "target.example.com.",
shouldResolve: true,
},
{
name: "Query A record but get CNAME fallback",
queryName: "alias.example.com.",
queryType: dns.TypeA,
expectedType: "CNAME",
expectedRData: "target.example.com.",
shouldResolve: true,
},
{
name: "Query AAAA record but get CNAME fallback",
queryName: "alias.example.com.",
queryType: dns.TypeAAAA,
expectedType: "CNAME",
expectedRData: "target.example.com.",
shouldResolve: true,
},
{
name: "Query direct A record",
queryName: "target.example.com.",
queryType: dns.TypeA,
expectedType: "A",
expectedRData: "192.168.100.100",
shouldResolve: true,
},
{
name: "Query non-existent name",
queryName: "nonexistent.example.com.",
queryType: dns.TypeA,
shouldResolve: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var responseMSG *dns.Msg
// Create DNS query with the test case parameters
msg := new(dns.Msg).SetQuestion(tc.queryName, tc.queryType)
// Create mock response writer to capture the response
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
// Perform DNS query
resolver.ServeDNS(responseWriter, msg)
// Check if we expect a successful resolution
if !tc.shouldResolve {
if responseMSG == nil || len(responseMSG.Answer) == 0 || responseMSG.Rcode != dns.RcodeSuccess {
// Expected no resolution, test passes
return
}
t.Fatalf("Expected no resolution for %s, but got answer: %v", tc.queryName, responseMSG.Answer)
}
// Verify we got a successful response
require.NotNil(t, responseMSG, "Should have received a response message")
require.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "Response should have success status code")
require.Greater(t, len(responseMSG.Answer), 0, "Response should contain at least one answer")
// Verify the response contains the expected data
answerString := responseMSG.Answer[0].String()
assert.Contains(t, answerString, tc.expectedType,
"Answer should be of type %s, got: %s", tc.expectedType, answerString)
assert.Contains(t, answerString, tc.expectedRData,
"Answer should contain the expected data %s, got: %s", tc.expectedRData, answerString)
})
}
}

View File

@ -1,88 +0,0 @@
package dns
import (
"strings"
"testing"
"github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns"
)
func TestLocalResolver_ServeDNS(t *testing.T) {
recordA := nbdns.SimpleRecord{
Name: "peera.netbird.cloud.",
Type: 1,
Class: nbdns.DefaultClass,
TTL: 300,
RData: "1.2.3.4",
}
recordCNAME := nbdns.SimpleRecord{
Name: "peerb.netbird.cloud.",
Type: 5,
Class: nbdns.DefaultClass,
TTL: 300,
RData: "www.netbird.io",
}
testCases := []struct {
name string
inputRecord nbdns.SimpleRecord
inputMSG *dns.Msg
responseShouldBeNil bool
}{
{
name: "Should Resolve A Record",
inputRecord: recordA,
inputMSG: new(dns.Msg).SetQuestion(recordA.Name, dns.TypeA),
},
{
name: "Should Resolve CNAME Record",
inputRecord: recordCNAME,
inputMSG: new(dns.Msg).SetQuestion(recordCNAME.Name, dns.TypeCNAME),
},
{
name: "Should Not Write When Not Found A Record",
inputRecord: recordA,
inputMSG: new(dns.Msg).SetQuestion("not.found.com", dns.TypeA),
responseShouldBeNil: true,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
resolver := &localResolver{
registeredMap: make(registrationMap),
}
_, _ = resolver.registerRecord(testCase.inputRecord)
var responseMSG *dns.Msg
responseWriter := &mockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
resolver.ServeDNS(responseWriter, testCase.inputMSG)
if responseMSG == nil || len(responseMSG.Answer) == 0 {
if testCase.responseShouldBeNil {
return
}
t.Fatalf("should write a response message")
}
answerString := responseMSG.Answer[0].String()
if !strings.Contains(answerString, testCase.inputRecord.Name) {
t.Fatalf("answer doesn't contain the same domain name: \nWant: %s\nGot:%s", testCase.name, answerString)
}
if !strings.Contains(answerString, dns.Type(testCase.inputRecord.Type).String()) {
t.Fatalf("answer doesn't contain the correct type: \nWant: %s\nGot:%s", dns.Type(testCase.inputRecord.Type).String(), answerString)
}
if !strings.Contains(answerString, testCase.inputRecord.RData) {
t.Fatalf("answer doesn't contain the same address: \nWant: %s\nGot:%s", testCase.inputRecord.RData, answerString)
}
})
}
}

View File

@ -1,26 +0,0 @@
package dns
import (
"net"
"github.com/miekg/dns"
)
type mockResponseWriter struct {
WriteMsgFunc func(m *dns.Msg) error
}
func (rw *mockResponseWriter) WriteMsg(m *dns.Msg) error {
if rw.WriteMsgFunc != nil {
return rw.WriteMsgFunc(m)
}
return nil
}
func (rw *mockResponseWriter) LocalAddr() net.Addr { return nil }
func (rw *mockResponseWriter) RemoteAddr() net.Addr { return nil }
func (rw *mockResponseWriter) Write([]byte) (int, error) { return 0, nil }
func (rw *mockResponseWriter) Close() error { return nil }
func (rw *mockResponseWriter) TsigStatus() error { return nil }
func (rw *mockResponseWriter) TsigTimersOnly(bool) {}
func (rw *mockResponseWriter) Hijack() {}

View File

@ -15,6 +15,8 @@ import (
 	"golang.org/x/exp/maps"
 	"github.com/netbirdio/netbird/client/iface/netstack"
+	"github.com/netbirdio/netbird/client/internal/dns/local"
+	"github.com/netbirdio/netbird/client/internal/dns/types"
 	"github.com/netbirdio/netbird/client/internal/listener"
 	"github.com/netbirdio/netbird/client/internal/peer"
 	"github.com/netbirdio/netbird/client/internal/statemanager"
@ -46,8 +48,6 @@ type Server interface {
 	ProbeAvailability()
 }
-type handlerID string
 type nsGroupsByDomain struct {
 	domain string
 	groups []*nbdns.NameServerGroup
@ -61,7 +61,7 @@ type DefaultServer struct {
 	mux           sync.Mutex
 	service       service
 	dnsMuxMap     registeredHandlerMap
-	localResolver *localResolver
+	localResolver *local.Resolver
 	wgInterface   WGIface
 	hostManager   hostManager
 	updateSerial  uint64
@ -84,9 +84,9 @@ type DefaultServer struct {
 type handlerWithStop interface {
 	dns.Handler
-	stop()
-	probeAvailability()
-	id() handlerID
+	Stop()
+	ProbeAvailability()
+	ID() types.HandlerID
 }
 type handlerWrapper struct {
@ -95,7 +95,7 @@ type handlerWrapper struct {
 	priority int
 }
-type registeredHandlerMap map[handlerID]handlerWrapper
+type registeredHandlerMap map[types.HandlerID]handlerWrapper
 // NewDefaultServer returns a new dns server
 func NewDefaultServer(
@ -178,9 +178,7 @@ func newDefaultServer(
 		handlerChain: handlerChain,
 		extraDomains: make(map[domain.Domain]int),
 		dnsMuxMap:    make(registeredHandlerMap),
-		localResolver: &localResolver{
-			registeredMap: make(registrationMap),
-		},
+		localResolver:  local.NewResolver(),
 		wgInterface:    wgInterface,
 		statusRecorder: statusRecorder,
 		stateManager:   stateManager,
@ -403,7 +401,7 @@ func (s *DefaultServer) ProbeAvailability() {
 		wg.Add(1)
 		go func(mux handlerWithStop) {
 			defer wg.Done()
-			mux.probeAvailability()
+			mux.ProbeAvailability()
 		}(mux.handler)
 	}
 	wg.Wait()
@ -420,7 +418,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
 		s.service.Stop()
 	}
-	localMuxUpdates, localRecordsByDomain, err := s.buildLocalHandlerUpdate(update.CustomZones)
+	localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
 	if err != nil {
 		return fmt.Errorf("local handler updater: %w", err)
 	}
@ -434,7 +432,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
 	s.updateMux(muxUpdates)
 	// register local records
-	s.updateLocalResolver(localRecordsByDomain)
+	s.localResolver.Update(localRecords)
 	s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
@ -516,11 +514,9 @@ func (s *DefaultServer) handleErrNoGroupaAll(err error) {
 	)
 }
-func (s *DefaultServer) buildLocalHandlerUpdate(
-	customZones []nbdns.CustomZone,
-) ([]handlerWrapper, map[string][]nbdns.SimpleRecord, error) {
+func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.SimpleRecord, error) {
 	var muxUpdates []handlerWrapper
-	localRecords := make(map[string][]nbdns.SimpleRecord)
+	var localRecords []nbdns.SimpleRecord
 	for _, customZone := range customZones {
 		if len(customZone.Records) == 0 {
@ -534,17 +530,13 @@ func (s *DefaultServer) buildLocalHandlerUpdate(
 			priority: PriorityMatchDomain,
 		})
-		// group all records under this domain
 		for _, record := range customZone.Records {
-			var class uint16 = dns.ClassINET
 			if record.Class != nbdns.DefaultClass {
 				log.Warnf("received an invalid class type: %s", record.Class)
 				continue
 			}
-			// zone records contain the fqdn, so we can just flatten them
-			key := buildRecordKey(record.Name, class, uint16(record.Type))
-			localRecords[key] = append(localRecords[key], record)
+			localRecords = append(localRecords, record)
 		}
 	}
@ -627,7 +619,7 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomain
 	}
 	if len(handler.upstreamServers) == 0 {
-		handler.stop()
+		handler.Stop()
 		log.Errorf("received a nameserver group with an invalid nameserver list")
 		continue
 	}
@ -656,7 +648,7 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
 	// this will introduce a short period of time when the server is not able to handle DNS requests
 	for _, existing := range s.dnsMuxMap {
 		s.deregisterHandler([]string{existing.domain}, existing.priority)
-		existing.handler.stop()
+		existing.handler.Stop()
 	}
 	muxUpdateMap := make(registeredHandlerMap)
@ -667,7 +659,7 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
 			containsRootUpdate = true
 		}
 		s.registerHandler([]string{update.domain}, update.handler, update.priority)
-		muxUpdateMap[update.handler.id()] = update
+		muxUpdateMap[update.handler.ID()] = update
 	}
 	// If there's no root update and we had a root handler, restore it
@ -683,33 +675,6 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
 	s.dnsMuxMap = muxUpdateMap
 }
-func (s *DefaultServer) updateLocalResolver(update map[string][]nbdns.SimpleRecord) {
-	// remove old records that are no longer present
-	for key := range s.localResolver.registeredMap {
-		_, found := update[key]
-		if !found {
-			s.localResolver.deleteRecord(key)
-		}
-	}
-	updatedMap := make(registrationMap)
-	for _, recs := range update {
-		for _, rec := range recs {
-			// convert the record to a dns.RR and register
-			key, err := s.localResolver.registerRecord(rec)
-			if err != nil {
-				log.Warnf("got an error while registering the record (%s), error: %v",
-					rec.String(), err)
-				continue
-			}
-			updatedMap[key] = struct{}{}
-		}
-	}
-	s.localResolver.registeredMap = updatedMap
-}
 func getNSHostPort(ns nbdns.NameServer) string {
 	return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
 }

View File

@ -23,6 +23,9 @@ import (
 	"github.com/netbirdio/netbird/client/iface/device"
 	pfmock "github.com/netbirdio/netbird/client/iface/mocks"
 	"github.com/netbirdio/netbird/client/iface/wgaddr"
+	"github.com/netbirdio/netbird/client/internal/dns/local"
+	"github.com/netbirdio/netbird/client/internal/dns/test"
+	"github.com/netbirdio/netbird/client/internal/dns/types"
 	"github.com/netbirdio/netbird/client/internal/netflow"
 	"github.com/netbirdio/netbird/client/internal/peer"
 	"github.com/netbirdio/netbird/client/internal/statemanager"
@ -107,6 +110,7 @@ func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamRe
 }
 func TestUpdateDNSServer(t *testing.T) {
 	nameServers := []nbdns.NameServer{
 		{
 			IP: netip.MustParseAddr("8.8.8.8"),
@ -120,22 +124,21 @@ func TestUpdateDNSServer(t *testing.T) {
 		},
 	}
-	dummyHandler := &localResolver{}
+	dummyHandler := local.NewResolver()
 	testCases := []struct {
 		name                string
 		initUpstreamMap     registeredHandlerMap
-		initLocalMap        registrationMap
+		initLocalRecords    []nbdns.SimpleRecord
 		initSerial          uint64
 		inputSerial         uint64
 		inputUpdate         nbdns.Config
 		shouldFail          bool
 		expectedUpstreamMap registeredHandlerMap
-		expectedLocalMap    registrationMap
+		expectedLocalQs     []dns.Question
 	}{
 		{
 			name: "Initial Config Should Succeed",
-			initLocalMap:    make(registrationMap),
 			initUpstreamMap: make(registeredHandlerMap),
 			initSerial:      0,
 			inputSerial:     1,
@ -159,30 +162,30 @@ func TestUpdateDNSServer(t *testing.T) {
 				},
 			},
 			expectedUpstreamMap: registeredHandlerMap{
-				generateDummyHandler("netbird.io", nameServers).id(): handlerWrapper{
+				generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
 					domain:   "netbird.io",
 					handler:  dummyHandler,
 					priority: PriorityMatchDomain,
 				},
-				dummyHandler.id(): handlerWrapper{
+				dummyHandler.ID(): handlerWrapper{
 					domain:   "netbird.cloud",
 					handler:  dummyHandler,
 					priority: PriorityMatchDomain,
 				},
-				generateDummyHandler(".", nameServers).id(): handlerWrapper{
+				generateDummyHandler(".", nameServers).ID(): handlerWrapper{
 					domain:   nbdns.RootZone,
 					handler:  dummyHandler,
 					priority: PriorityDefault,
 				},
 			},
-			expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
+			expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}},
 		},
 		{
 			name: "New Config Should Succeed",
-			initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
+			initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
 			initUpstreamMap: registeredHandlerMap{
-				generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{
-					domain:   buildRecordKey(zoneRecords[0].Name, 1, 1),
+				generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
+					domain:   "netbird.cloud",
 					handler:  dummyHandler,
 					priority: PriorityMatchDomain,
 				},
@ -205,7 +208,7 @@ func TestUpdateDNSServer(t *testing.T) {
 				},
 			},
 			expectedUpstreamMap: registeredHandlerMap{
-				generateDummyHandler("netbird.io", nameServers).id(): handlerWrapper{
+				generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
 					domain:   "netbird.io",
 					handler:  dummyHandler,
 					priority: PriorityMatchDomain,
@ -216,11 +219,11 @@ func TestUpdateDNSServer(t *testing.T) {
 					priority: PriorityMatchDomain,
 				},
 			},
-			expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
+			expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
 		},
 		{
 			name: "Smaller Config Serial Should Be Skipped",
-			initLocalMap:    make(registrationMap),
+			initLocalRecords: []nbdns.SimpleRecord{},
 			initUpstreamMap: make(registeredHandlerMap),
 			initSerial:      2,
 			inputSerial:     1,
@ -228,7 +231,7 @@ func TestUpdateDNSServer(t *testing.T) {
 		},
 		{
 			name: "Empty NS Group Domain Or Not Primary Element Should Fail",
-			initLocalMap:    make(registrationMap),
+			initLocalRecords: []nbdns.SimpleRecord{},
 			initUpstreamMap: make(registeredHandlerMap),
 			initSerial:      0,
 			inputSerial:     1,
@ -250,7 +253,7 @@ func TestUpdateDNSServer(t *testing.T) {
 		},
 		{
 			name: "Invalid NS Group Nameservers list Should Fail",
-			initLocalMap:    make(registrationMap),
+			initLocalRecords: []nbdns.SimpleRecord{},
 			initUpstreamMap: make(registeredHandlerMap),
 			initSerial:      0,
 			inputSerial:     1,
@ -272,7 +275,7 @@ func TestUpdateDNSServer(t *testing.T) {
 		},
 		{
 			name: "Invalid Custom Zone Records list Should Skip",
-			initLocalMap:    make(registrationMap),
+			initLocalRecords: []nbdns.SimpleRecord{},
 			initUpstreamMap: make(registeredHandlerMap),
 			initSerial:      0,
 			inputSerial:     1,
@ -290,7 +293,7 @@ func TestUpdateDNSServer(t *testing.T) {
 				},
 			},
-			expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).id(): handlerWrapper{
+			expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).ID(): handlerWrapper{
 				domain:   ".",
 				handler:  dummyHandler,
 				priority: PriorityDefault,
@ -298,9 +301,9 @@ func TestUpdateDNSServer(t *testing.T) {
 		},
 		{
 			name: "Empty Config Should Succeed and Clean Maps",
-			initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
+			initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
 			initUpstreamMap: registeredHandlerMap{
-				generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{
+				generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
 					domain:   zoneRecords[0].Name,
 					handler:  dummyHandler,
 					priority: PriorityMatchDomain,
@ -310,13 +313,13 @@ func TestUpdateDNSServer(t *testing.T) {
 			inputSerial:         1,
 			inputUpdate:         nbdns.Config{ServiceEnable: true},
 			expectedUpstreamMap: make(registeredHandlerMap),
-			expectedLocalMap:    make(registrationMap),
+			expectedLocalQs:     []dns.Question{},
 		},
 		{
 			name: "Disabled Service Should clean map",
-			initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
+			initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
 			initUpstreamMap: registeredHandlerMap{
-				generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{
+				generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
 					domain:   zoneRecords[0].Name,
 					handler:  dummyHandler,
 					priority: PriorityMatchDomain,
@ -326,7 +329,7 @@ func TestUpdateDNSServer(t *testing.T) {
 			inputSerial:         1,
 			inputUpdate:         nbdns.Config{ServiceEnable: false},
 			expectedUpstreamMap: make(registeredHandlerMap),
-			expectedLocalMap:    make(registrationMap),
+			expectedLocalQs:     []dns.Question{},
 		},
 	}
@ -377,7 +380,7 @@ func TestUpdateDNSServer(t *testing.T) {
 			}()
 			dnsServer.dnsMuxMap = testCase.initUpstreamMap
-			dnsServer.localResolver.registeredMap = testCase.initLocalMap
+			dnsServer.localResolver.Update(testCase.initLocalRecords)
 			dnsServer.updateSerial = testCase.initSerial
 			err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
@ -399,15 +402,23 @@ func TestUpdateDNSServer(t *testing.T) {
 				}
 			}
-			if len(dnsServer.localResolver.registeredMap) != len(testCase.expectedLocalMap) {
-				t.Fatalf("update local failed, registered map size is different than expected, want %d, got %d", len(testCase.expectedLocalMap), len(dnsServer.localResolver.registeredMap))
-			}
-			for key := range testCase.expectedLocalMap {
-				_, found := dnsServer.localResolver.registeredMap[key]
-				if !found {
-					t.Fatalf("update local failed, key %s was not found in the localResolver.registeredMap: %#v", key, dnsServer.localResolver.registeredMap)
-				}
-			}
+			var responseMSG *dns.Msg
+			responseWriter := &test.MockResponseWriter{
+				WriteMsgFunc: func(m *dns.Msg) error {
+					responseMSG = m
+					return nil
+				},
+			}
+			for _, q := range testCase.expectedLocalQs {
+				dnsServer.localResolver.ServeDNS(responseWriter, &dns.Msg{
+					Question: []dns.Question{q},
+				})
+			}
+			if len(testCase.expectedLocalQs) > 0 {
+				assert.NotNil(t, responseMSG, "response message should not be nil")
+				assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "response code should be success")
+				assert.NotEmpty(t, responseMSG.Answer, "response message should have answers")
+			}
 		})
 	}
@ -491,11 +502,12 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
 	dnsServer.dnsMuxMap = registeredHandlerMap{
 		"id1": handlerWrapper{
 			domain:   zoneRecords[0].Name,
-			handler:  &localResolver{},
+			handler:  &local.Resolver{},
 			priority: PriorityMatchDomain,
 		},
 	}
-	dnsServer.localResolver.registeredMap = registrationMap{"netbird.cloud": struct{}{}}
+	//dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}}
+	dnsServer.localResolver.Update([]nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}})
 	dnsServer.updateSerial = 0
 	nameServers := []nbdns.NameServer{
@ -582,7 +594,7 @@ func TestDNSServerStartStop(t *testing.T) {
 	}
 	time.Sleep(100 * time.Millisecond)
 	defer dnsServer.Stop()
-	_, err = dnsServer.localResolver.registerRecord(zoneRecords[0])
+	err = dnsServer.localResolver.RegisterRecord(zoneRecords[0])
 	if err != nil {
 		t.Error(err)
 	}
@ -632,9 +644,7 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
 	server := DefaultServer{
 		ctx:     context.Background(),
 		service: NewServiceViaMemory(&mocWGIface{}),
-		localResolver: &localResolver{
-			registeredMap: make(registrationMap),
-		},
+		localResolver: local.NewResolver(),
 		handlerChain: NewHandlerChain(),
 		hostManager:  hostManager,
 		currentConfig: HostDNSConfig{
@ -1004,7 +1014,7 @@ func TestHandlerChain_DomainPriorities(t *testing.T) {
 		t.Run(tc.name, func(t *testing.T) {
 			r := new(dns.Msg)
 			r.SetQuestion(tc.query, dns.TypeA)
-			w := &ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
+			w := &ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
 			if mh, ok := tc.expectedHandler.(*MockHandler); ok {
 				mh.On("ServeDNS", mock.Anything, r).Once()
@ -1037,9 +1047,9 @@ type mockHandler struct {
 }
 func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
-func (m *mockHandler) stop()                                 {}
-func (m *mockHandler) probeAvailability()                    {}
-func (m *mockHandler) id() handlerID                         { return handlerID(m.Id) }
+func (m *mockHandler) Stop()                                 {}
+func (m *mockHandler) ProbeAvailability()                    {}
+func (m *mockHandler) ID() types.HandlerID                   { return types.HandlerID(m.Id) }
 type mockService struct{}
@ -1113,7 +1123,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
 		name             string
 		initialHandlers  registeredHandlerMap
 		updates          []handlerWrapper
-		expectedHandlers map[string]string // map[handlerID]domain
+		expectedHandlers map[string]string // map[HandlerID]domain
 		description      string
 	}{
 		{
@ -1409,7 +1419,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
 			// Check each expected handler
 			for id, expectedDomain := range tt.expectedHandlers {
-				handler, exists := server.dnsMuxMap[handlerID(id)]
+				handler, exists := server.dnsMuxMap[types.HandlerID(id)]
 				assert.True(t, exists, "Expected handler %s not found", id)
 				if exists {
 					assert.Equal(t, expectedDomain, handler.domain,
@ -1418,9 +1428,9 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
 			}
 			// Verify no unexpected handlers exist
-			for handlerID := range server.dnsMuxMap {
-				_, expected := tt.expectedHandlers[string(handlerID)]
-				assert.True(t, expected, "Unexpected handler found: %s", handlerID)
+			for HandlerID := range server.dnsMuxMap {
+				_, expected := tt.expectedHandlers[string(HandlerID)]
+				assert.True(t, expected, "Unexpected handler found: %s", HandlerID)
 			}
 			// Verify the handlerChain state and order
@ -1696,7 +1706,7 @@ func TestExtraDomains(t *testing.T) {
 		handlerChain:   NewHandlerChain(),
 		wgInterface:    &mocWGIface{},
 		hostManager:    mockHostConfig,
-		localResolver:  &localResolver{},
+		localResolver:  &local.Resolver{},
 		service:        mockSvc,
 		statusRecorder: peer.NewRecorder("test"),
 		extraDomains:   make(map[domain.Domain]int),
@ -1781,7 +1791,7 @@ func TestExtraDomainsRefCounting(t *testing.T) {
 		ctx:            context.Background(),
 		handlerChain:   NewHandlerChain(),
 		hostManager:    mockHostConfig,
-		localResolver:  &localResolver{},
+		localResolver:  &local.Resolver{},
 		service:        mockSvc,
 		statusRecorder: peer.NewRecorder("test"),
 		extraDomains:   make(map[domain.Domain]int),
@ -1833,7 +1843,7 @@ func TestUpdateConfigWithExistingExtraDomains(t *testing.T) {
 		ctx:            context.Background(),
 		handlerChain:   NewHandlerChain(),
 		hostManager:    mockHostConfig,
-		localResolver:  &localResolver{},
+		localResolver:  &local.Resolver{},
 		service:        mockSvc,
 		statusRecorder: peer.NewRecorder("test"),
 		extraDomains:   make(map[domain.Domain]int),
@ -1916,7 +1926,7 @@ func TestDomainCaseHandling(t *testing.T) {
 		ctx:            context.Background(),
 		handlerChain:   NewHandlerChain(),
 		hostManager:    mockHostConfig,
-		localResolver:  &localResolver{},
+		localResolver:  &local.Resolver{},
 		service:        mockSvc,
 		statusRecorder: peer.NewRecorder("test"),
 		extraDomains:   make(map[domain.Domain]int),

View File

@ -30,9 +30,12 @@ const (
 	systemdDbusSetDNSMethodSuffix          = systemdDbusLinkInterface + ".SetDNS"
 	systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute"
 	systemdDbusSetDomainsMethodSuffix      = systemdDbusLinkInterface + ".SetDomains"
+	systemdDbusSetDNSSECMethodSuffix       = systemdDbusLinkInterface + ".SetDNSSEC"
 	systemdDbusResolvConfModeForeign = "foreign"
 	dbusErrorUnknownObject = "org.freedesktop.DBus.Error.UnknownObject"
+
+	dnsSecDisabled = "no"
 )
 type systemdDbusConfigurator struct {
@ -95,9 +98,13 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
 		Family:  unix.AF_INET,
 		Address: ipAs4[:],
 	}
-	err = s.callLinkMethod(systemdDbusSetDNSMethodSuffix, []systemdDbusDNSInput{defaultLinkInput})
-	if err != nil {
-		return fmt.Errorf("setting the interface DNS server %s:%d failed with error: %w", config.ServerIP, config.ServerPort, err)
+	if err = s.callLinkMethod(systemdDbusSetDNSMethodSuffix, []systemdDbusDNSInput{defaultLinkInput}); err != nil {
+		return fmt.Errorf("set interface DNS server %s:%d: %w", config.ServerIP, config.ServerPort, err)
+	}
+
+	// We don't support dnssec. On some machines this is default on so we explicitly set it to off
+	if err = s.callLinkMethod(systemdDbusSetDNSSECMethodSuffix, dnsSecDisabled); err != nil {
+		log.Warnf("failed to set DNSSEC to 'no': %v", err)
 	}
 	var (
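For reference, the SetDNSSEC call added above goes through systemd-resolved's org.freedesktop.resolve1 D-Bus API. A hedged, standalone sketch of the same call using github.com/godbus/dbus/v5 (the interface name "wt0" is an assumption; callLinkMethod in the code above wraps the equivalent call against the cached link object):

package main

import (
	"fmt"
	"log"
	"net"

	"github.com/godbus/dbus/v5"
)

func main() {
	conn, err := dbus.SystemBus()
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()

	iface, err := net.InterfaceByName("wt0") // assumption: the tunnel interface name
	if err != nil {
		log.Fatal(err)
	}

	// resolve1 exposes one Link object per interface; GetLink resolves
	// the ifindex to its object path.
	var linkPath dbus.ObjectPath
	mgr := conn.Object("org.freedesktop.resolve1", "/org/freedesktop/resolve1")
	if err := mgr.Call("org.freedesktop.resolve1.Manager.GetLink", 0, int32(iface.Index)).Store(&linkPath); err != nil {
		log.Fatal(err)
	}

	// Explicitly disable DNSSEC on the link, matching the diff above.
	link := conn.Object("org.freedesktop.resolve1", linkPath)
	if call := link.Call("org.freedesktop.resolve1.Link.SetDNSSEC", 0, "no"); call.Err != nil {
		log.Fatal(call.Err)
	}
	fmt.Println("DNSSEC disabled on", iface.Name)
}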

View File

@ -0,0 +1,26 @@
package test
import (
"net"
"github.com/miekg/dns"
)
type MockResponseWriter struct {
WriteMsgFunc func(m *dns.Msg) error
}
func (rw *MockResponseWriter) WriteMsg(m *dns.Msg) error {
if rw.WriteMsgFunc != nil {
return rw.WriteMsgFunc(m)
}
return nil
}
func (rw *MockResponseWriter) LocalAddr() net.Addr { return nil }
func (rw *MockResponseWriter) RemoteAddr() net.Addr { return nil }
func (rw *MockResponseWriter) Write([]byte) (int, error) { return 0, nil }
func (rw *MockResponseWriter) Close() error { return nil }
func (rw *MockResponseWriter) TsigStatus() error { return nil }
func (rw *MockResponseWriter) TsigTimersOnly(bool) {}
func (rw *MockResponseWriter) Hijack() {}

View File

@ -0,0 +1,3 @@
package types
type HandlerID string

View File

@ -19,6 +19,7 @@ import (
 	log "github.com/sirupsen/logrus"
 	"github.com/netbirdio/netbird/client/iface"
+	"github.com/netbirdio/netbird/client/internal/dns/types"
 	"github.com/netbirdio/netbird/client/internal/peer"
 	"github.com/netbirdio/netbird/client/proto"
 )
@ -81,21 +82,21 @@ func (u *upstreamResolverBase) String() string {
 }
 // ID returns the unique handler ID
-func (u *upstreamResolverBase) id() handlerID {
+func (u *upstreamResolverBase) ID() types.HandlerID {
 	servers := slices.Clone(u.upstreamServers)
 	slices.Sort(servers)
 	hash := sha256.New()
 	hash.Write([]byte(u.domain + ":"))
 	hash.Write([]byte(strings.Join(servers, ",")))
-	return handlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
+	return types.HandlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
 }
 func (u *upstreamResolverBase) MatchSubdomains() bool {
 	return true
 }
-func (u *upstreamResolverBase) stop() {
+func (u *upstreamResolverBase) Stop() {
 	log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
 	u.cancel()
 }
@ -198,9 +199,9 @@ func (u *upstreamResolverBase) checkUpstreamFails(err error) {
 	)
 }
-// probeAvailability tests all upstream servers simultaneously and
+// ProbeAvailability tests all upstream servers simultaneously and
 // disables the resolver if none work
-func (u *upstreamResolverBase) probeAvailability() {
+func (u *upstreamResolverBase) ProbeAvailability() {
 	u.mutex.Lock()
 	defer u.mutex.Unlock()
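The exported ID() keeps a handler's identity stable across config re-applies: it hashes the domain plus the sorted server list, so the same upstream set always maps to the same registeredHandlerMap key regardless of server order in the config. A self-contained sketch of the same derivation (the helper name upstreamID is illustrative):

package main

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"slices"
	"strings"
)

// upstreamID mirrors the derivation in upstreamResolverBase.ID():
// sorting makes the hash independent of server order.
func upstreamID(domain string, upstreamServers []string) string {
	servers := slices.Clone(upstreamServers)
	slices.Sort(servers)

	hash := sha256.New()
	hash.Write([]byte(domain + ":"))
	hash.Write([]byte(strings.Join(servers, ",")))
	return "upstream-" + hex.EncodeToString(hash.Sum(nil)[:8])
}

func main() {
	// The same servers in a different order produce the same ID.
	fmt.Println(upstreamID("netbird.io", []string{"8.8.8.8:53", "1.1.1.1:53"}))
	fmt.Println(upstreamID("netbird.io", []string{"1.1.1.1:53", "8.8.8.8:53"}))
}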

View File

@ -8,6 +8,8 @@ import (
 	"time"
 	"github.com/miekg/dns"
+
+	"github.com/netbirdio/netbird/client/internal/dns/test"
 )
 func TestUpstreamResolver_ServeDNS(t *testing.T) {
@ -66,7 +68,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
 			}
 			var responseMSG *dns.Msg
-			responseWriter := &mockResponseWriter{
+			responseWriter := &test.MockResponseWriter{
 				WriteMsgFunc: func(m *dns.Msg) error {
 					responseMSG = m
 					return nil
@ -130,7 +132,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
 	resolver.failsTillDeact = 0
 	resolver.reactivatePeriod = time.Microsecond * 100
-	responseWriter := &mockResponseWriter{
+	responseWriter := &test.MockResponseWriter{
 		WriteMsgFunc: func(m *dns.Msg) error { return nil },
 	}

View File

@ -5,7 +5,6 @@ package dns
 import (
 	"net"
-	"github.com/netbirdio/netbird/client/iface/configurer"
 	"github.com/netbirdio/netbird/client/iface/device"
 	"github.com/netbirdio/netbird/client/iface/wgaddr"
 )
@ -18,5 +17,4 @@ type WGIface interface {
 	IsUserspaceBind() bool
 	GetFilter() device.PacketFilter
 	GetDevice() *device.FilteredDevice
-	GetStats(peerKey string) (configurer.WGStats, error)
 }

View File

@ -1,7 +1,6 @@
 package dns
 import (
-	"github.com/netbirdio/netbird/client/iface/configurer"
 	"github.com/netbirdio/netbird/client/iface/device"
 	"github.com/netbirdio/netbird/client/iface/wgaddr"
 )
@ -13,6 +12,5 @@ type WGIface interface {
 	IsUserspaceBind() bool
 	GetFilter() device.PacketFilter
 	GetDevice() *device.FilteredDevice
-	GetStats(peerKey string) (configurer.WGStats, error)
 	GetInterfaceGUIDString() (string, error)
 }

View File

@ -33,6 +33,8 @@ type DNSForwarder struct {
 	dnsServer *dns.Server
 	mux       *dns.ServeMux
+	tcpServer *dns.Server
+	tcpMux    *dns.ServeMux
 	mutex      sync.RWMutex
 	fwdEntries []*ForwarderEntry
@ -50,22 +52,41 @@ func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewall.Manager
 }
 func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
-	log.Infof("listen DNS forwarder on address=%s", f.listenAddress)
+	log.Infof("starting DNS forwarder on address=%s", f.listenAddress)
-	mux := dns.NewServeMux()
-	dnsServer := &dns.Server{
+	// UDP server
+	mux := dns.NewServeMux()
+	f.mux = mux
+	f.dnsServer = &dns.Server{
 		Addr:    f.listenAddress,
 		Net:     "udp",
 		Handler: mux,
 	}
-	f.dnsServer = dnsServer
-	f.mux = mux
+	// TCP server
+	tcpMux := dns.NewServeMux()
+	f.tcpMux = tcpMux
+	f.tcpServer = &dns.Server{
+		Addr:    f.listenAddress,
+		Net:     "tcp",
+		Handler: tcpMux,
+	}
 	f.UpdateDomains(entries)
-	return dnsServer.ListenAndServe()
-}
+	errCh := make(chan error, 2)
+	go func() {
+		log.Infof("DNS UDP listener running on %s", f.listenAddress)
+		errCh <- f.dnsServer.ListenAndServe()
+	}()
+	go func() {
+		log.Infof("DNS TCP listener running on %s", f.listenAddress)
+		errCh <- f.tcpServer.ListenAndServe()
+	}()
+	// return the first error we get (e.g. bind failure or shutdown)
+	return <-errCh
+}
 func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
 	f.mutex.Lock()
 	defer f.mutex.Unlock()
@ -77,31 +98,41 @@ func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
 	}
 	oldDomains := filterDomains(f.fwdEntries)
 	for _, d := range oldDomains {
 		f.mux.HandleRemove(d.PunycodeString())
+		f.tcpMux.HandleRemove(d.PunycodeString())
 	}
 	newDomains := filterDomains(entries)
 	for _, d := range newDomains {
-		f.mux.HandleFunc(d.PunycodeString(), f.handleDNSQuery)
+		f.mux.HandleFunc(d.PunycodeString(), f.handleDNSQueryUDP)
+		f.tcpMux.HandleFunc(d.PunycodeString(), f.handleDNSQueryTCP)
 	}
 	f.fwdEntries = entries
 	log.Debugf("Updated domains from %v to %v", oldDomains, newDomains)
 }
 func (f *DNSForwarder) Close(ctx context.Context) error {
-	if f.dnsServer == nil {
-		return nil
+	var result *multierror.Error
+	if f.dnsServer != nil {
+		if err := f.dnsServer.ShutdownContext(ctx); err != nil {
+			result = multierror.Append(result, fmt.Errorf("UDP shutdown: %w", err))
+		}
 	}
-	return f.dnsServer.ShutdownContext(ctx)
+	if f.tcpServer != nil {
+		if err := f.tcpServer.ShutdownContext(ctx); err != nil {
+			result = multierror.Append(result, fmt.Errorf("TCP shutdown: %w", err))
+		}
+	}
+	return nberrors.FormatErrorOrNil(result)
 }
-func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
+func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns.Msg {
 	if len(query.Question) == 0 {
-		return
+		return nil
 	}
 	question := query.Question[0]
 	log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v",
@ -123,20 +154,53 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
 		if err := w.WriteMsg(resp); err != nil {
 			log.Errorf("failed to write DNS response: %v", err)
 		}
-		return
+		return nil
 	}
 	ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
 	defer cancel()
 	ips, err := net.DefaultResolver.LookupNetIP(ctx, network, domain)
 	if err != nil {
-		f.handleDNSError(w, resp, domain, err)
-		return
+		f.handleDNSError(w, query, resp, domain, err)
+		return nil
 	}
 	f.updateInternalState(domain, ips)
 	f.addIPsToResponse(resp, domain, ips)
-	if err := w.WriteMsg(resp); err != nil {
-		log.Errorf("failed to write DNS response: %v", err)
-	}
+	return resp
+}
+func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
+	resp := f.handleDNSQuery(w, query)
+	if resp == nil {
+		return
+	}
+	opt := query.IsEdns0()
+	maxSize := dns.MinMsgSize
+	if opt != nil {
+		// client advertised a larger EDNS0 buffer
+		maxSize = int(opt.UDPSize())
+	}
+	// if our response is too big, truncate and set the TC bit
+	if resp.Len() > maxSize {
+		resp.Truncate(maxSize)
+	}
+	if err := w.WriteMsg(resp); err != nil {
+		log.Errorf("failed to write DNS response: %v", err)
+	}
+}
+func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
+	resp := f.handleDNSQuery(w, query)
+	if resp == nil {
+		return
+	}
+	if err := w.WriteMsg(resp); err != nil {
+		log.Errorf("failed to write DNS response: %v", err)
	}
 }
@ -179,7 +243,7 @@ func (f *DNSForwarder) updateFirewall(matchingEntries []*ForwarderEntry, prefixe
 }
 // handleDNSError processes DNS lookup errors and sends an appropriate error response
-func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, resp *dns.Msg, domain string, err error) {
+func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, query, resp *dns.Msg, domain string, err error) {
 	var dnsErr *net.DNSError
 	switch {
@ -191,7 +255,7 @@ func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, resp *dns.Msg, domai
 	}
 	if dnsErr.Server != "" {
-		log.Warnf("failed to resolve query for domain=%s server=%s: %v", domain, dnsErr.Server, err)
+		log.Warnf("failed to resolve query for type=%s domain=%s server=%s: %v", dns.TypeToString[query.Question[0].Qtype], domain, dnsErr.Server, err)
 	} else {
 		log.Warnf(errResolveFailed, domain, err)
 	}
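The UDP handler now truncates oversized answers and relies on the TC bit, which is why the forwarder also needs the TCP listener: stub resolvers retry truncated responses over TCP. A minimal client-side sketch of that retry loop using github.com/miekg/dns (the server address is illustrative, not from the commit):

package main

import (
	"fmt"
	"log"

	"github.com/miekg/dns"
)

func main() {
	const server = "100.64.0.1:53" // assumption: the forwarder's listen address

	query := new(dns.Msg).SetQuestion("app.example.com.", dns.TypeA)
	query.SetEdns0(1232, false) // advertise a 1232-byte UDP buffer

	udp := &dns.Client{Net: "udp"}
	resp, _, err := udp.Exchange(query, server)
	if err != nil {
		log.Fatal(err)
	}

	// A response larger than the advertised buffer comes back truncated
	// with the TC bit set; the standard recovery is to repeat over TCP.
	if resp.Truncated {
		tcp := &dns.Client{Net: "tcp"}
		if resp, _, err = tcp.Exchange(query, server); err != nil {
			log.Fatal(err)
		}
	}

	for _, rr := range resp.Answer {
		fmt.Println(rr)
	}
}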

View File

@ -33,6 +33,7 @@ type Manager struct {
 	statusRecorder *peer.Status
 	fwRules        []firewall.Rule
+	tcpRules       []firewall.Rule
 	dnsForwarder   *DNSForwarder
 }
@ -107,6 +108,13 @@ func (m *Manager) allowDNSFirewall() error {
 	}
 	m.fwRules = dnsRules
+	tcpRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolTCP, nil, dport, firewall.ActionAccept, "")
+	if err != nil {
+		log.Errorf("failed to add allow DNS router rules, err: %v", err)
+		return err
+	}
+	m.tcpRules = tcpRules
 	return nil
 }
@ -117,7 +125,13 @@ func (m *Manager) dropDNSFirewall() error {
 			mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
 		}
 	}
+	for _, rule := range m.tcpRules {
+		if err := m.firewall.DeletePeerRule(rule); err != nil {
+			mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
+		}
+	}
 	m.fwRules = nil
+	m.tcpRules = nil
 	return nberrors.FormatErrorOrNil(mErr)
 }

View File

@@ -38,6 +38,7 @@ import (
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/networkmonitor"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
"github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/peerstore"
@@ -122,6 +123,8 @@ type EngineConfig struct {
DisableFirewall bool
BlockLANAccess bool
LazyConnectionEnabled bool
}
// Engine is a mechanism responsible for reacting to Signal and Management stream events and managing connections to the remote peers.
@@ -134,6 +137,8 @@ type Engine struct {
// peerConns is a map that holds all the peers that are known to this peer
peerStore *peerstore.Store
connMgr *ConnMgr
beforePeerHook nbnet.AddHookFunc
afterPeerHook nbnet.RemoveHookFunc
@@ -171,6 +176,7 @@ type Engine struct {
sshServer nbssh.Server
statusRecorder *peer.Status
peerConnDispatcher *dispatcher.ConnectionDispatcher
firewall firewallManager.Manager
routeManager routemanager.Manager
@@ -262,6 +268,10 @@ func (e *Engine) Stop() error {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
if e.connMgr != nil {
e.connMgr.Close()
}
// stopping network monitor first to avoid starting the engine again
if e.networkMonitor != nil {
e.networkMonitor.Stop()
@@ -297,8 +307,7 @@ func (e *Engine) Stop() error {
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{})
if err := e.removeAllPeers(); err != nil {
return fmt.Errorf("failed to remove all peers: %s", err)
}
@@ -405,8 +414,7 @@ func (e *Engine) Start() error {
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
if err = e.wgInterfaceCreate(); err != nil {
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
e.close()
return fmt.Errorf("create wg interface: %w", err)
@@ -442,6 +450,11 @@ func (e *Engine) Start() error {
NATExternalIPs: e.parseNATExternalIPMappings(),
}
e.peerConnDispatcher = dispatcher.NewConnectionDispatcher()
e.connMgr = NewConnMgr(e.config, e.statusRecorder, e.peerStore, wgIface, e.peerConnDispatcher)
e.connMgr.Start(e.ctx)
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
e.srWatcher.Start()
@@ -450,7 +463,6 @@ func (e *Engine) Start() error {
// starting network monitor at the very last to avoid disruptions
e.startNetworkMonitor()
return nil
}
@@ -550,6 +562,16 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
var modified []*mgmProto.RemotePeerConfig
for _, p := range peersUpdate {
peerPubKey := p.GetWgPubKey()
currentPeer, ok := e.peerStore.PeerConn(peerPubKey)
if !ok {
continue
}
if currentPeer.AgentVersionString() != p.AgentVersion {
modified = append(modified, p)
continue
}
allowedIPs, ok := e.peerStore.AllowedIPs(peerPubKey)
if !ok {
continue
@@ -559,8 +581,7 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
continue
}
if err := e.statusRecorder.UpdatePeerFQDN(peerPubKey, p.GetFqdn()); err != nil {
log.Warnf("error updating peer's %s fqdn in the status recorder, got error: %v", peerPubKey, err)
}
}
@@ -621,17 +642,12 @@ func (e *Engine) removePeer(peerKey string) error {
e.sshServer.RemoveAuthorizedKey(peerKey)
}
e.connMgr.RemovePeerConn(peerKey)
err := e.statusRecorder.RemovePeer(peerKey)
if err != nil {
log.Warnf("received error when removing peer %s from status recorder: %v", peerKey, err)
}
return nil
}
@@ -952,12 +968,24 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
return nil
}
if err := e.connMgr.UpdatedRemoteFeatureFlag(e.ctx, networkMap.GetPeerConfig().GetLazyConnectionEnabled()); err != nil {
log.Errorf("failed to update lazy connection feature flag: %v", err)
}
if e.firewall != nil {
if localipfw, ok := e.firewall.(localIpUpdater); ok {
if err := localipfw.UpdateLocalIPs(); err != nil {
log.Errorf("failed to update local IPs: %v", err)
}
}
// If we got empty rules list but management did not set the networkMap.FirewallRulesIsEmpty flag,
// then the mgmt server is older than the client, and we need to allow all traffic for routes.
// This needs to be toggled before applying routes.
isLegacy := len(networkMap.RoutesFirewallRules) == 0 && !networkMap.RoutesFirewallRulesIsEmpty
if err := e.firewall.SetLegacyManagement(isLegacy); err != nil {
log.Errorf("failed to set legacy management flag: %v", err)
}
}
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
@@ -976,7 +1004,8 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries)
// Ingress forward rules
forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules())
if err != nil {
log.Errorf("failed to update forward rules, err: %v", err)
}
@@ -1022,6 +1051,10 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
}
}
// must set the exclude list after the peers are added. Without it, the manager cannot figure out the peers' parameters from the store
excludedLazyPeers := e.toExcludedLazyPeers(routes, forwardingRules, networkMap.GetRemotePeers())
e.connMgr.SetExcludeList(excludedLazyPeers)
protoDNSConfig := networkMap.GetDNSConfig()
if protoDNSConfig == nil {
protoDNSConfig = &mgmProto.DNSConfig{}
@@ -1155,7 +1188,7 @@ func (e *Engine) updateOfflinePeers(offlinePeers []*mgmProto.RemotePeerConfig) {
IP: strings.Join(offlinePeer.GetAllowedIps(), ","),
PubKey: offlinePeer.GetWgPubKey(),
FQDN: offlinePeer.GetFqdn(),
ConnStatus: peer.StatusIdle,
ConnStatusUpdate: time.Now(),
Mux: new(sync.RWMutex),
}
@@ -1191,12 +1224,17 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
peerIPs = append(peerIPs, allowedNetIP)
}
conn, err := e.createPeerConn(peerKey, peerIPs, peerConfig.AgentVersion)
if err != nil {
return fmt.Errorf("create peer connection: %w", err)
}
err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn, peerIPs[0].Addr().String())
if err != nil {
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
}
if exists := e.connMgr.AddPeerConn(e.ctx, peerKey, conn); exists {
conn.Close()
return fmt.Errorf("peer already exists: %s", peerKey)
}
@@ -1205,17 +1243,10 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
conn.AddBeforeAddPeerHook(e.beforePeerHook)
conn.AddAfterRemovePeerHook(e.afterPeerHook)
}
return nil
}
func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentVersion string) (*peer.Conn, error) {
log.Debugf("creating peer connection %s", pubKey)
wgConfig := peer.WgConfig{
@@ -1231,6 +1262,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix) (*peer
config := peer.ConnConfig{
Key: pubKey,
LocalKey: e.config.WgPrivateKey.PublicKey().String(),
AgentVersion: agentVersion,
Timeout: timeout,
WgConfig: wgConfig,
LocalWgPort: e.config.WgPort,
@@ -1249,7 +1281,16 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix) (*peer
},
}
serviceDependencies := peer.ServiceDependencies{
StatusRecorder: e.statusRecorder,
Signaler: e.signaler,
IFaceDiscover: e.mobileDep.IFaceDiscover,
RelayManager: e.relayManager,
SrWatcher: e.srWatcher,
Semaphore: e.connSemaphore,
PeerConnDispatcher: e.peerConnDispatcher,
}
peerConn, err := peer.NewConn(config, serviceDependencies)
if err != nil {
return nil, err
}
@@ -1270,7 +1311,7 @@ func (e *Engine) receiveSignalEvents() {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
conn, ok := e.connMgr.OnSignalMsg(e.ctx, msg.Key)
if !ok {
return fmt.Errorf("wrongly addressed message %s", msg.Key)
}
@@ -1578,13 +1619,39 @@ func (e *Engine) getRosenpassAddr() string {
// RunHealthProbes executes health checks for Signal, Management, Relay and WireGuard services
// and updates the status recorder with the latest states.
func (e *Engine) RunHealthProbes() bool {
e.syncMsgMux.Lock()
signalHealthy := e.signal.IsHealthy()
log.Debugf("signal health check: healthy=%t", signalHealthy)
managementHealthy := e.mgmClient.IsHealthy()
log.Debugf("management health check: healthy=%t", managementHealthy)
stuns := slices.Clone(e.STUNs)
turns := slices.Clone(e.TURNs)
if e.wgInterface != nil {
stats, err := e.wgInterface.GetStats()
if err != nil {
log.Warnf("failed to get wireguard stats: %v", err)
e.syncMsgMux.Unlock()
return false
}
for _, key := range e.peerStore.PeersPubKey() {
// wgStats could be zero value, in which case we just reset the stats
wgStats, ok := stats[key]
if !ok {
continue
}
if err := e.statusRecorder.UpdateWireGuardPeerState(key, wgStats); err != nil {
log.Debugf("failed to update wg stats for peer %s: %s", key, err)
}
}
}
e.syncMsgMux.Unlock()
results := e.probeICE(stuns, turns)
e.statusRecorder.UpdateRelayStates(results)
relayHealthy := true
@@ -1596,37 +1663,16 @@
}
log.Debugf("relay health check: healthy=%t", relayHealthy)
allHealthy := signalHealthy && managementHealthy && relayHealthy
log.Debugf("all health checks completed: healthy=%t", allHealthy)
return allHealthy
}
func (e *Engine) probeICE(stuns, turns []*stun.URI) []relay.ProbeResult {
return append(
relay.ProbeAll(e.ctx, relay.ProbeSTUN, stuns),
relay.ProbeAll(e.ctx, relay.ProbeTURN, turns)...,
)
}
// restartEngine restarts the engine by cancelling the client context
@@ -1813,21 +1859,21 @@ func (e *Engine) Address() (netip.Addr, error) {
return ip.Unmap(), nil
}
func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewallManager.ForwardRule, error) {
if e.firewall == nil {
log.Warn("firewall is disabled, not updating forwarding rules")
return nil, nil
}
if len(rules) == 0 {
if e.ingressGatewayMgr == nil {
return nil, nil
}
err := e.ingressGatewayMgr.Close()
e.ingressGatewayMgr = nil
e.statusRecorder.SetIngressGwMgr(nil)
return nil, err
}
if e.ingressGatewayMgr == nil {
@@ -1878,7 +1924,33 @@ func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) error {
log.Errorf("failed to update forwarding rules: %v", err)
}
return forwardingRules, nberrors.FormatErrorOrNil(merr)
}
func (e *Engine) toExcludedLazyPeers(routes []*route.Route, rules []firewallManager.ForwardRule, peers []*mgmProto.RemotePeerConfig) []string {
excludedPeers := make([]string, 0)
for _, r := range routes {
if r.Peer == "" {
continue
}
log.Infof("exclude router peer from lazy connection: %s", r.Peer)
excludedPeers = append(excludedPeers, r.Peer)
}
for _, r := range rules {
ip := r.TranslatedAddress
for _, p := range peers {
for _, allowedIP := range p.GetAllowedIps() {
if allowedIP != ip.String() {
continue
}
log.Infof("exclude forwarder peer from lazy connection: %s", p.GetWgPubKey())
excludedPeers = append(excludedPeers, p.GetWgPubKey())
}
}
}
return excludedPeers
}
// isChecksEqual checks if two slices of checks are equal.

View File

@@ -28,8 +28,6 @@ import (
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
@@ -38,6 +36,7 @@ import (
"github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
"github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/routemanager"
@@ -53,6 +52,7 @@ import (
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route"
signal "github.com/netbirdio/netbird/signal/client"
@@ -93,7 +93,7 @@ type MockWGIface struct {
GetFilterFunc func() device.PacketFilter
GetDeviceFunc func() *device.FilteredDevice
GetWGDeviceFunc func() *wgdevice.Device
GetStatsFunc func() (map[string]configurer.WGStats, error)
GetInterfaceGUIDStringFunc func() (string, error)
GetProxyFunc func() wgproxy.Proxy
GetNetFunc func() *netstack.Net
@@ -171,8 +171,8 @@ func (m *MockWGIface) GetWGDevice() *wgdevice.Device {
return m.GetWGDeviceFunc()
}
func (m *MockWGIface) GetStats() (map[string]configurer.WGStats, error) {
return m.GetStatsFunc()
}
func (m *MockWGIface) GetProxy() wgproxy.Proxy {
@@ -378,6 +378,9 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
},
}
},
UpdatePeerFunc: func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
return nil
},
}
engine.wgInterface = wgIface
engine.routeManager = routemanager.NewManager(routemanager.ManagerConfig{
@@ -400,6 +403,8 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn})
engine.ctx = ctx
engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{})
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, wgIface, dispatcher.NewConnectionDispatcher())
engine.connMgr.Start(ctx)
type testCase struct {
name string
@@ -770,6 +775,8 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
engine.routeManager = mockRouteManager
engine.dnsServer = &dns.MockServer{}
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface, dispatcher.NewConnectionDispatcher())
engine.connMgr.Start(ctx)
defer func() {
exitErr := engine.Stop()
@@ -966,6 +973,8 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
}
engine.dnsServer = mockDNSServer
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface, dispatcher.NewConnectionDispatcher())
engine.connMgr.Start(ctx)
defer func() {
exitErr := engine.Stop()
@@ -1476,7 +1485,7 @@ func getConnectedPeers(e *Engine) int {
i := 0
for _, id := range e.peerStore.PeersPubKey() {
conn, _ := e.peerStore.PeerConn(id)
if conn.IsConnected() {
i++
}
}

View File

@@ -35,6 +35,6 @@ type wgIfaceBase interface {
GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice
GetWGDevice() *wgdevice.Device
GetStats() (map[string]configurer.WGStats, error)
GetNet() *netstack.Net
}

View File

@@ -0,0 +1,9 @@
//go:build !linux || android
package activity
import "net"
var (
listenIP = net.IP{127, 0, 0, 1}
)

View File

@@ -0,0 +1,10 @@
//go:build !android
package activity
import "net"
var (
// use this ip to avoid eBPF proxy congestion
listenIP = net.IP{127, 0, 1, 1}
)

View File

@@ -0,0 +1,106 @@
package activity
import (
"fmt"
"net"
"sync"
"sync/atomic"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/lazyconn"
)
// Listener is not a thread-safe implementation; do not call Close before ReadPackets, as doing so will block.
type Listener struct {
wgIface lazyconn.WGIface
peerCfg lazyconn.PeerConfig
conn *net.UDPConn
endpoint *net.UDPAddr
done sync.Mutex
isClosed atomic.Bool // use to avoid error log when closing the listener
}
func NewListener(wgIface lazyconn.WGIface, cfg lazyconn.PeerConfig) (*Listener, error) {
d := &Listener{
wgIface: wgIface,
peerCfg: cfg,
}
conn, err := d.newConn()
if err != nil {
return nil, fmt.Errorf("failed to creating activity listener: %v", err)
}
d.conn = conn
d.endpoint = conn.LocalAddr().(*net.UDPAddr)
if err := d.createEndpoint(); err != nil {
return nil, err
}
d.done.Lock()
cfg.Log.Infof("created activity listener: %s", conn.LocalAddr().(*net.UDPAddr).String())
return d, nil
}
func (d *Listener) ReadPackets() {
for {
n, remoteAddr, err := d.conn.ReadFromUDP(make([]byte, 1))
if err != nil {
if d.isClosed.Load() {
d.peerCfg.Log.Debugf("exit from activity listener")
} else {
d.peerCfg.Log.Errorf("failed to read from activity listener: %s", err)
}
break
}
if n < 1 {
d.peerCfg.Log.Warnf("received %d bytes from %s, too short", n, remoteAddr)
continue
}
break
}
if err := d.removeEndpoint(); err != nil {
d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
}
_ = d.conn.Close() // ignore the error; in some cases Close returns "use of closed network connection"
d.done.Unlock()
}
func (d *Listener) Close() {
d.peerCfg.Log.Infof("closing listener: %s", d.conn.LocalAddr().String())
d.isClosed.Store(true)
if err := d.conn.Close(); err != nil {
d.peerCfg.Log.Errorf("failed to close UDP listener: %s", err)
}
d.done.Lock()
}
func (d *Listener) removeEndpoint() error {
d.peerCfg.Log.Debugf("removing lazy endpoint: %s", d.endpoint.String())
return d.wgIface.RemovePeer(d.peerCfg.PublicKey)
}
func (d *Listener) createEndpoint() error {
d.peerCfg.Log.Debugf("creating lazy endpoint: %s", d.endpoint.String())
return d.wgIface.UpdatePeer(d.peerCfg.PublicKey, d.peerCfg.AllowedIPs, 0, d.endpoint, nil)
}
func (d *Listener) newConn() (*net.UDPConn, error) {
addr := &net.UDPAddr{
Port: 0,
IP: listenIP,
}
conn, err := net.ListenUDP("udp", addr)
if err != nil {
log.Errorf("failed to create activity listener on %s: %s", addr, err)
return nil, err
}
return conn, nil
}

View File

@@ -0,0 +1,41 @@
package activity
import (
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/lazyconn"
)
func TestNewListener(t *testing.T) {
peer := &MocPeer{
PeerID: "examplePublicKey1",
}
cfg := lazyconn.PeerConfig{
PublicKey: peer.PeerID,
PeerConnID: peer.ConnID(),
Log: log.WithField("peer", "examplePublicKey1"),
}
l, err := NewListener(MocWGIface{}, cfg)
if err != nil {
t.Fatalf("failed to create listener: %v", err)
}
chanClosed := make(chan struct{})
go func() {
defer close(chanClosed)
l.ReadPackets()
}()
time.Sleep(1 * time.Second)
l.Close()
select {
case <-chanClosed:
case <-time.After(time.Second):
}
}

View File

@@ -0,0 +1,95 @@
package activity
import (
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/lazyconn"
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
)
type Manager struct {
OnActivityChan chan peerid.ConnID
wgIface lazyconn.WGIface
peers map[peerid.ConnID]*Listener
done chan struct{}
mu sync.Mutex
}
func NewManager(wgIface lazyconn.WGIface) *Manager {
m := &Manager{
OnActivityChan: make(chan peerid.ConnID, 1),
wgIface: wgIface,
peers: make(map[peerid.ConnID]*Listener),
done: make(chan struct{}),
}
return m
}
func (m *Manager) MonitorPeerActivity(peerCfg lazyconn.PeerConfig) error {
m.mu.Lock()
defer m.mu.Unlock()
if _, ok := m.peers[peerCfg.PeerConnID]; ok {
log.Warnf("activity listener already exists for: %s", peerCfg.PublicKey)
return nil
}
listener, err := NewListener(m.wgIface, peerCfg)
if err != nil {
return err
}
m.peers[peerCfg.PeerConnID] = listener
go m.waitForTraffic(listener, peerCfg.PeerConnID)
return nil
}
func (m *Manager) RemovePeer(log *log.Entry, peerConnID peerid.ConnID) {
m.mu.Lock()
defer m.mu.Unlock()
listener, ok := m.peers[peerConnID]
if !ok {
return
}
log.Debugf("removing activity listener")
delete(m.peers, peerConnID)
listener.Close()
}
func (m *Manager) Close() {
m.mu.Lock()
defer m.mu.Unlock()
close(m.done)
for peerID, listener := range m.peers {
delete(m.peers, peerID)
listener.Close()
}
}
func (m *Manager) waitForTraffic(listener *Listener, peerConnID peerid.ConnID) {
listener.ReadPackets()
m.mu.Lock()
if _, ok := m.peers[peerConnID]; !ok {
m.mu.Unlock()
return
}
delete(m.peers, peerConnID)
m.mu.Unlock()
m.notify(peerConnID)
}
func (m *Manager) notify(peerConnID peerid.ConnID) {
select {
case <-m.done:
case m.OnActivityChan <- peerConnID:
}
}

View File

@@ -0,0 +1,162 @@
package activity
import (
"net"
"net/netip"
"testing"
"time"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/internal/lazyconn"
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
)
type MocPeer struct {
PeerID string
}
func (m *MocPeer) ConnID() peerid.ConnID {
return peerid.ConnID(m)
}
type MocWGIface struct {
}
func (m MocWGIface) RemovePeer(string) error {
return nil
}
func (m MocWGIface) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error {
return nil
}
func TestManager_MonitorPeerActivity(t *testing.T) {
mocWgInterface := &MocWGIface{}
peer1 := &MocPeer{
PeerID: "examplePublicKey1",
}
mgr := NewManager(mocWgInterface)
defer mgr.Close()
peerCfg1 := lazyconn.PeerConfig{
PublicKey: peer1.PeerID,
PeerConnID: peer1.ConnID(),
Log: log.WithField("peer", "examplePublicKey1"),
}
if err := mgr.MonitorPeerActivity(peerCfg1); err != nil {
t.Fatalf("failed to monitor peer activity: %v", err)
}
if err := trigger(mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String()); err != nil {
t.Fatalf("failed to trigger activity: %v", err)
}
select {
case peerConnID := <-mgr.OnActivityChan:
if peerConnID != peerCfg1.PeerConnID {
t.Fatalf("unexpected peerConnID: %v", peerConnID)
}
case <-time.After(1 * time.Second):
}
}
func TestManager_RemovePeerActivity(t *testing.T) {
mocWgInterface := &MocWGIface{}
peer1 := &MocPeer{
PeerID: "examplePublicKey1",
}
mgr := NewManager(mocWgInterface)
defer mgr.Close()
peerCfg1 := lazyconn.PeerConfig{
PublicKey: peer1.PeerID,
PeerConnID: peer1.ConnID(),
Log: log.WithField("peer", "examplePublicKey1"),
}
if err := mgr.MonitorPeerActivity(peerCfg1); err != nil {
t.Fatalf("failed to monitor peer activity: %v", err)
}
addr := mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String()
mgr.RemovePeer(peerCfg1.Log, peerCfg1.PeerConnID)
if err := trigger(addr); err != nil {
t.Fatalf("failed to trigger activity: %v", err)
}
select {
case <-mgr.OnActivityChan:
t.Fatal("should not have active activity")
case <-time.After(1 * time.Second):
}
}
func TestManager_MultiPeerActivity(t *testing.T) {
mocWgInterface := &MocWGIface{}
peer1 := &MocPeer{
PeerID: "examplePublicKey1",
}
mgr := NewManager(mocWgInterface)
defer mgr.Close()
peerCfg1 := lazyconn.PeerConfig{
PublicKey: peer1.PeerID,
PeerConnID: peer1.ConnID(),
Log: log.WithField("peer", "examplePublicKey1"),
}
peer2 := &MocPeer{}
peerCfg2 := lazyconn.PeerConfig{
PublicKey: peer2.PeerID,
PeerConnID: peer2.ConnID(),
Log: log.WithField("peer", "examplePublicKey2"),
}
if err := mgr.MonitorPeerActivity(peerCfg1); err != nil {
t.Fatalf("failed to monitor peer activity: %v", err)
}
if err := mgr.MonitorPeerActivity(peerCfg2); err != nil {
t.Fatalf("failed to monitor peer activity: %v", err)
}
if err := trigger(mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String()); err != nil {
t.Fatalf("failed to trigger activity: %v", err)
}
if err := trigger(mgr.peers[peerCfg2.PeerConnID].conn.LocalAddr().String()); err != nil {
t.Fatalf("failed to trigger activity: %v", err)
}
for i := 0; i < 2; i++ {
select {
case <-mgr.OnActivityChan:
case <-time.After(1 * time.Second):
t.Fatal("timed out waiting for activity")
}
}
}
func trigger(addr string) error {
// Create a connection to the destination UDP address and port
conn, err := net.Dial("udp", addr)
if err != nil {
return err
}
defer conn.Close()
// Write the bytes to the UDP connection
_, err = conn.Write([]byte{0x01, 0x02, 0x03, 0x04, 0x05})
if err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,32 @@
/*
Package lazyconn provides mechanisms for managing lazy connections, which activate on demand to optimize resource usage and establish connections efficiently.
## Overview
The package includes a `Manager` component responsible for:
- Managing lazy connections activated on-demand
- Managing inactivity monitors for lazy connections (based on peer disconnection events)
- Maintaining a list of excluded peers that should always have permanent connections
- Handling remote peer connection initiatives based on peer signaling
## Thread-Safe Operations
The `Manager` ensures thread safety across multiple operations, categorized by caller:
- **Engine (single goroutine)**:
- `AddPeer`: Adds a peer to the connection manager.
- `RemovePeer`: Removes a peer from the connection manager.
- `ActivatePeer`: Activates a lazy connection for a peer. This comes from the Signal client.
- `ExcludePeer`: Marks peers for a permanent connection, such as router peers and other peers that should always have a connection.
- **Connection Dispatcher (any peer routine)**:
- `onPeerConnected`: Suspends the inactivity monitor for an active peer connection.
- `onPeerDisconnected`: Starts the inactivity monitor for a disconnected peer.
- **Activity Manager**:
- `onPeerActivity`: Runs peer.Open(context).
- **Inactivity Monitor**:
- `onPeerInactivityTimedOut`: Closes the peer connection and restarts the activity monitor.
*/
package lazyconn
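A minimal wiring sketch of the flow described above, using the NewManager and Start signatures added elsewhere in this change; the ctx, peerStore, and wgIface values are assumed to be supplied by the caller, as the engine does in this commit:
	// Sketch only: wiring the lazy connection manager into a client.
	connDispatcher := dispatcher.NewConnectionDispatcher()
	lazyMgr := manager.NewManager(manager.Config{}, peerStore, wgIface, connDispatcher)

	// Start blocks on its event loop until ctx is cancelled, so run it in a goroutine.
	go lazyMgr.Start(ctx)

	// Peers are then added and removed by the engine as the network map changes
	// (AddPeer/RemovePeer calls not shown here).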

View File

@@ -0,0 +1,26 @@
package lazyconn
import (
"os"
"strconv"
log "github.com/sirupsen/logrus"
)
const (
EnvEnableLazyConn = "NB_ENABLE_EXPERIMENTAL_LAZY_CONN"
EnvInactivityThreshold = "NB_LAZY_CONN_INACTIVITY_THRESHOLD"
)
func IsLazyConnEnabledByEnv() bool {
val := os.Getenv(EnvEnableLazyConn)
if val == "" {
return false
}
enabled, err := strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", EnvEnableLazyConn, err)
return false
}
return enabled
}
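EnvInactivityThreshold is declared above, but its parsing is not part of this excerpt. A plausible sketch of a helper, assuming the variable holds a Go duration string such as "30m"; the function name and value format are assumptions, not the actual implementation:
	// Hypothetical helper; requires the "time" import in addition to those above.
	// Clamping against inactivity.MinimumInactivityThreshold happens in the manager.
	func inactivityThresholdFromEnv() *time.Duration {
		val := os.Getenv(EnvInactivityThreshold)
		if val == "" {
			return nil
		}
		d, err := time.ParseDuration(val)
		if err != nil {
			log.Warnf("failed to parse %s: %v", EnvInactivityThreshold, err)
			return nil
		}
		return &d
	}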

View File

@@ -0,0 +1,70 @@
package inactivity
import (
"context"
"time"
peer "github.com/netbirdio/netbird/client/internal/peer/id"
)
const (
DefaultInactivityThreshold = 60 * time.Minute // idle after 1 hour of inactivity
MinimumInactivityThreshold = 3 * time.Minute
)
type Monitor struct {
id peer.ConnID
timer *time.Timer
cancel context.CancelFunc
inactivityThreshold time.Duration
}
func NewInactivityMonitor(peerID peer.ConnID, threshold time.Duration) *Monitor {
i := &Monitor{
id: peerID,
timer: time.NewTimer(0),
inactivityThreshold: threshold,
}
i.timer.Stop()
return i
}
func (i *Monitor) Start(ctx context.Context, timeoutChan chan peer.ConnID) {
i.timer.Reset(i.inactivityThreshold)
defer i.timer.Stop()
ctx, i.cancel = context.WithCancel(ctx)
defer func() {
defer i.cancel()
select {
case <-i.timer.C:
default:
}
}()
select {
case <-i.timer.C:
select {
case timeoutChan <- i.id:
case <-ctx.Done():
return
}
case <-ctx.Done():
return
}
}
func (i *Monitor) Stop() {
if i.cancel == nil {
return
}
i.cancel()
}
func (i *Monitor) PauseTimer() {
i.timer.Stop()
}
func (i *Monitor) ResetTimer() {
i.timer.Reset(i.inactivityThreshold)
}
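The drain in Start's deferred cleanup above follows the standard Go pattern for reusing a time.Timer: after a Stop or cancellation, the timer channel may still hold an already-fired value that must be drained before the next Reset, otherwise the stale value fires immediately. A generic illustration of the pattern (not part of this change; threshold is a placeholder duration):
	t := time.NewTimer(threshold)
	if !t.Stop() {
		select {
		case <-t.C: // drain a value that fired before Stop took effect
		default:
		}
	}
	t.Reset(threshold) // now safe: the channel is guaranteed empty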

View File

@@ -0,0 +1,156 @@
package inactivity
import (
"context"
"testing"
"time"
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
)
type MocPeer struct {
}
func (m *MocPeer) ConnID() peerid.ConnID {
return peerid.ConnID(m)
}
func TestInactivityMonitor(t *testing.T) {
tCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*5)
defer testTimeoutCancel()
p := &MocPeer{}
im := NewInactivityMonitor(p.ConnID(), time.Second*2)
timeoutChan := make(chan peerid.ConnID)
exitChan := make(chan struct{})
go func() {
defer close(exitChan)
im.Start(tCtx, timeoutChan)
}()
select {
case <-timeoutChan:
case <-tCtx.Done():
t.Fatal("timeout")
}
select {
case <-exitChan:
case <-tCtx.Done():
t.Fatal("timeout")
}
}
func TestReuseInactivityMonitor(t *testing.T) {
p := &MocPeer{}
im := NewInactivityMonitor(p.ConnID(), time.Second*2)
timeoutChan := make(chan peerid.ConnID)
for i := 2; i > 0; i-- {
exitChan := make(chan struct{})
testTimeoutCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*5)
go func() {
defer close(exitChan)
im.Start(testTimeoutCtx, timeoutChan)
}()
select {
case <-timeoutChan:
case <-testTimeoutCtx.Done():
t.Fatal("timeout")
}
select {
case <-exitChan:
case <-testTimeoutCtx.Done():
t.Fatal("timeout")
}
testTimeoutCancel()
}
}
func TestStopInactivityMonitor(t *testing.T) {
tCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*5)
defer testTimeoutCancel()
p := &MocPeer{}
im := NewInactivityMonitor(p.ConnID(), DefaultInactivityThreshold)
timeoutChan := make(chan peerid.ConnID)
exitChan := make(chan struct{})
go func() {
defer close(exitChan)
im.Start(tCtx, timeoutChan)
}()
go func() {
time.Sleep(3 * time.Second)
im.Stop()
}()
select {
case <-timeoutChan:
t.Fatal("unexpected timeout")
case <-exitChan:
case <-tCtx.Done():
t.Fatal("timeout")
}
}
func TestPauseInactivityMonitor(t *testing.T) {
tCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*10)
defer testTimeoutCancel()
p := &MocPeer{}
threshold := time.Second * 3
im := NewInactivityMonitor(p.ConnID(), threshold)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
timeoutChan := make(chan peerid.ConnID)
exitChan := make(chan struct{})
go func() {
defer close(exitChan)
im.Start(ctx, timeoutChan)
}()
time.Sleep(1 * time.Second) // grant time to start the monitor
im.PauseTimer()
// check that we do not receive a timeout
thresholdCtx, thresholdCancel := context.WithTimeout(context.Background(), threshold+time.Second)
defer thresholdCancel()
select {
case <-exitChan:
t.Fatal("unexpected exit")
case <-timeoutChan:
t.Fatal("unexpected timeout")
case <-thresholdCtx.Done():
// test ok
case <-tCtx.Done():
t.Fatal("test timed out")
}
// test reset timer
im.ResetTimer()
select {
case <-tCtx.Done():
t.Fatal("test timed out")
case <-exitChan:
t.Fatal("unexpected exit")
case <-timeoutChan:
// expected timeout
}
}

View File

@@ -0,0 +1,404 @@
package manager
import (
"context"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/client/internal/lazyconn/activity"
"github.com/netbirdio/netbird/client/internal/lazyconn/inactivity"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
"github.com/netbirdio/netbird/client/internal/peerstore"
)
const (
watcherActivity watcherType = iota
watcherInactivity
)
type watcherType int
type managedPeer struct {
peerCfg *lazyconn.PeerConfig
expectedWatcher watcherType
}
type Config struct {
InactivityThreshold *time.Duration
}
// Manager manages lazy connections
// It is responsible for:
// - Managing lazy connections activated on-demand
// - Managing inactivity monitors for lazy connections (based on peer disconnection events)
// - Maintaining a list of excluded peers that should always have permanent connections
// - Handling connection establishment based on peer signaling
type Manager struct {
peerStore *peerstore.Store
connStateDispatcher *dispatcher.ConnectionDispatcher
inactivityThreshold time.Duration
connStateListener *dispatcher.ConnectionListener
managedPeers map[string]*lazyconn.PeerConfig
managedPeersByConnID map[peerid.ConnID]*managedPeer
excludes map[string]lazyconn.PeerConfig
managedPeersMu sync.Mutex
activityManager *activity.Manager
inactivityMonitors map[peerid.ConnID]*inactivity.Monitor
cancel context.CancelFunc
onInactive chan peerid.ConnID
}
func NewManager(config Config, peerStore *peerstore.Store, wgIface lazyconn.WGIface, connStateDispatcher *dispatcher.ConnectionDispatcher) *Manager {
log.Infof("setup lazy connection service")
m := &Manager{
peerStore: peerStore,
connStateDispatcher: connStateDispatcher,
inactivityThreshold: inactivity.DefaultInactivityThreshold,
managedPeers: make(map[string]*lazyconn.PeerConfig),
managedPeersByConnID: make(map[peerid.ConnID]*managedPeer),
excludes: make(map[string]lazyconn.PeerConfig),
activityManager: activity.NewManager(wgIface),
inactivityMonitors: make(map[peerid.ConnID]*inactivity.Monitor),
onInactive: make(chan peerid.ConnID),
}
if config.InactivityThreshold != nil {
if *config.InactivityThreshold >= inactivity.MinimumInactivityThreshold {
m.inactivityThreshold = *config.InactivityThreshold
} else {
log.Warnf("inactivity threshold is too low, using %v", m.inactivityThreshold)
}
}
m.connStateListener = &dispatcher.ConnectionListener{
OnConnected: m.onPeerConnected,
OnDisconnected: m.onPeerDisconnected,
}
connStateDispatcher.AddListener(m.connStateListener)
return m
}
// Start starts the manager and listens for peer activity and inactivity events
func (m *Manager) Start(ctx context.Context) {
defer m.close()
ctx, m.cancel = context.WithCancel(ctx)
for {
select {
case <-ctx.Done():
return
case peerConnID := <-m.activityManager.OnActivityChan:
m.onPeerActivity(ctx, peerConnID)
case peerConnID := <-m.onInactive:
m.onPeerInactivityTimedOut(peerConnID)
}
}
}
// ExcludePeer marks peers for a permanent connection
// It removes peers from the managed list if they are added to the exclude list
// Adds them back to the managed list and starts the inactivity listener if they are removed from the exclude list. In
// this case, we assume that the connection status is connected or connecting.
// If the peer does not exist yet in the managed list, it is the upper layer's responsibility to call the AddPeer function
func (m *Manager) ExcludePeer(ctx context.Context, peerConfigs []lazyconn.PeerConfig) []string {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
added := make([]string, 0)
excludes := make(map[string]lazyconn.PeerConfig, len(peerConfigs))
for _, peerCfg := range peerConfigs {
log.Infof("update excluded lazy connection list with peer: %s", peerCfg.PublicKey)
excludes[peerCfg.PublicKey] = peerCfg
}
// if a peer is newly added to the exclude list, remove from the managed peers list
for pubKey, peerCfg := range excludes {
if _, wasExcluded := m.excludes[pubKey]; wasExcluded {
continue
}
added = append(added, pubKey)
peerCfg.Log.Infof("peer newly added to lazy connection exclude list")
m.removePeer(pubKey)
}
// if a peer has been removed from exclude list then it should be added to the managed peers
for pubKey, peerCfg := range m.excludes {
if _, stillExcluded := excludes[pubKey]; stillExcluded {
continue
}
peerCfg.Log.Infof("peer removed from lazy connection exclude list")
if err := m.addActivePeer(ctx, peerCfg); err != nil {
log.Errorf("failed to add peer to lazy connection manager: %s", err)
continue
}
}
m.excludes = excludes
return added
}
func (m *Manager) AddPeer(peerCfg lazyconn.PeerConfig) (bool, error) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
peerCfg.Log.Debugf("adding peer to lazy connection manager")
_, exists := m.excludes[peerCfg.PublicKey]
if exists {
return true, nil
}
if _, ok := m.managedPeers[peerCfg.PublicKey]; ok {
peerCfg.Log.Warnf("peer already managed")
return false, nil
}
if err := m.activityManager.MonitorPeerActivity(peerCfg); err != nil {
return false, err
}
im := inactivity.NewInactivityMonitor(peerCfg.PeerConnID, m.inactivityThreshold)
m.inactivityMonitors[peerCfg.PeerConnID] = im
m.managedPeers[peerCfg.PublicKey] = &peerCfg
m.managedPeersByConnID[peerCfg.PeerConnID] = &managedPeer{
peerCfg: &peerCfg,
expectedWatcher: watcherActivity,
}
return false, nil
}
// AddActivePeers adds a list of peers to the lazy connection manager
// assuming these peers were in connected or connecting states
func (m *Manager) AddActivePeers(ctx context.Context, peerCfg []lazyconn.PeerConfig) error {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
for _, cfg := range peerCfg {
if _, ok := m.managedPeers[cfg.PublicKey]; ok {
cfg.Log.Errorf("peer already managed")
continue
}
if err := m.addActivePeer(ctx, cfg); err != nil {
cfg.Log.Errorf("failed to add peer to lazy connection manager: %v", err)
return err
}
}
return nil
}
func (m *Manager) RemovePeer(peerID string) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
m.removePeer(peerID)
}
// ActivatePeer activates a peer connection when a signal message is received
func (m *Manager) ActivatePeer(ctx context.Context, peerID string) (found bool) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
cfg, ok := m.managedPeers[peerID]
if !ok {
return false
}
mp, ok := m.managedPeersByConnID[cfg.PeerConnID]
if !ok {
return false
}
// signal messages keep arriving after successful activation; this check avoids repeated activation
if mp.expectedWatcher == watcherInactivity {
return false
}
mp.expectedWatcher = watcherInactivity
m.activityManager.RemovePeer(cfg.Log, cfg.PeerConnID)
im, ok := m.inactivityMonitors[cfg.PeerConnID]
if !ok {
cfg.Log.Errorf("inactivity monitor not found for peer")
return false
}
mp.peerCfg.Log.Infof("starting inactivity monitor")
go im.Start(ctx, m.onInactive)
return true
}
func (m *Manager) addActivePeer(ctx context.Context, peerCfg lazyconn.PeerConfig) error {
if _, ok := m.managedPeers[peerCfg.PublicKey]; ok {
peerCfg.Log.Warnf("peer already managed")
return nil
}
im := inactivity.NewInactivityMonitor(peerCfg.PeerConnID, m.inactivityThreshold)
m.inactivityMonitors[peerCfg.PeerConnID] = im
m.managedPeers[peerCfg.PublicKey] = &peerCfg
m.managedPeersByConnID[peerCfg.PeerConnID] = &managedPeer{
peerCfg: &peerCfg,
expectedWatcher: watcherInactivity,
}
peerCfg.Log.Infof("starting inactivity monitor on peer that has been removed from exclude list")
go im.Start(ctx, m.onInactive)
return nil
}
func (m *Manager) removePeer(peerID string) {
cfg, ok := m.managedPeers[peerID]
if !ok {
return
}
cfg.Log.Infof("removing lazy peer")
if im, ok := m.inactivityMonitors[cfg.PeerConnID]; ok {
im.Stop()
delete(m.inactivityMonitors, cfg.PeerConnID)
cfg.Log.Debugf("inactivity monitor stopped")
}
m.activityManager.RemovePeer(cfg.Log, cfg.PeerConnID)
delete(m.managedPeers, peerID)
delete(m.managedPeersByConnID, cfg.PeerConnID)
}
func (m *Manager) close() {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
m.cancel()
m.connStateDispatcher.RemoveListener(m.connStateListener)
m.activityManager.Close()
for _, iw := range m.inactivityMonitors {
iw.Stop()
}
m.inactivityMonitors = make(map[peerid.ConnID]*inactivity.Monitor)
m.managedPeers = make(map[string]*lazyconn.PeerConfig)
m.managedPeersByConnID = make(map[peerid.ConnID]*managedPeer)
log.Infof("lazy connection manager closed")
}
func (m *Manager) onPeerActivity(ctx context.Context, peerConnID peerid.ConnID) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
mp, ok := m.managedPeersByConnID[peerConnID]
if !ok {
log.Errorf("peer not found by conn id: %v", peerConnID)
return
}
if mp.expectedWatcher != watcherActivity {
mp.peerCfg.Log.Warnf("ignore activity event")
return
}
mp.peerCfg.Log.Infof("detected peer activity")
mp.expectedWatcher = watcherInactivity
mp.peerCfg.Log.Infof("starting inactivity monitor")
go m.inactivityMonitors[peerConnID].Start(ctx, m.onInactive)
m.peerStore.PeerConnOpen(ctx, mp.peerCfg.PublicKey)
}
func (m *Manager) onPeerInactivityTimedOut(peerConnID peerid.ConnID) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
mp, ok := m.managedPeersByConnID[peerConnID]
if !ok {
log.Errorf("peer not found by id: %v", peerConnID)
return
}
if mp.expectedWatcher != watcherInactivity {
mp.peerCfg.Log.Warnf("ignore inactivity event")
return
}
mp.peerCfg.Log.Infof("connection timed out")
// this is a blocking operation and could potentially be optimized
m.peerStore.PeerConnClose(mp.peerCfg.PublicKey)
mp.peerCfg.Log.Infof("start activity monitor")
mp.expectedWatcher = watcherActivity
// pause the timer just in case, to free it up
m.inactivityMonitors[peerConnID].PauseTimer()
if err := m.activityManager.MonitorPeerActivity(*mp.peerCfg); err != nil {
mp.peerCfg.Log.Errorf("failed to create activity monitor: %v", err)
return
}
}
func (m *Manager) onPeerConnected(peerConnID peerid.ConnID) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
mp, ok := m.managedPeersByConnID[peerConnID]
if !ok {
return
}
if mp.expectedWatcher != watcherInactivity {
return
}
iw, ok := m.inactivityMonitors[mp.peerCfg.PeerConnID]
if !ok {
mp.peerCfg.Log.Errorf("inactivity monitor not found for peer")
return
}
mp.peerCfg.Log.Infof("peer connected, pausing inactivity monitor while connection is not disconnected")
iw.PauseTimer()
}
func (m *Manager) onPeerDisconnected(peerConnID peerid.ConnID) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
mp, ok := m.managedPeersByConnID[peerConnID]
if !ok {
return
}
if mp.expectedWatcher != watcherInactivity {
return
}
iw, ok := m.inactivityMonitors[mp.peerCfg.PeerConnID]
if !ok {
return
}
mp.peerCfg.Log.Infof("reset inactivity monitor timer")
iw.ResetTimer()
}

View File

@@ -0,0 +1,16 @@
package lazyconn
import (
"net/netip"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer/id"
)
type PeerConfig struct {
PublicKey string
AllowedIPs []netip.Prefix
PeerConnID id.ConnID
Log *log.Entry
}

View File

@@ -0,0 +1,41 @@
package lazyconn
import (
"strings"
"github.com/hashicorp/go-version"
)
var (
minVersion = version.Must(version.NewVersion("0.45.0"))
)
func IsSupported(agentVersion string) bool {
if agentVersion == "development" {
return true
}
// filter out versions like this: a6c5960, a7d5c522, d47be154
if !strings.Contains(agentVersion, ".") {
return false
}
normalizedVersion := normalizeVersion(agentVersion)
inputVer, err := version.NewVersion(normalizedVersion)
if err != nil {
return false
}
return inputVer.GreaterThanOrEqual(minVersion)
}
func normalizeVersion(version string) string {
// Remove prefixes like 'v' or 'a'
if len(version) > 0 && (version[0] == 'v' || version[0] == 'a') {
version = version[1:]
}
// Remove any suffixes like '-dirty', '-dev', '-SNAPSHOT', etc.
parts := strings.Split(version, "-")
return parts[0]
}

View File

@@ -0,0 +1,31 @@
package lazyconn
import "testing"
func TestIsSupported(t *testing.T) {
tests := []struct {
version string
want bool
}{
{"development", true},
{"0.45.0", true},
{"v0.45.0", true},
{"0.45.1", true},
{"0.45.1-SNAPSHOT-559e6731", true},
{"v0.45.1-dev", true},
{"a7d5c522", false},
{"0.9.6", false},
{"0.9.6-SNAPSHOT", false},
{"0.9.6-SNAPSHOT-2033650", false},
{"meta_wt_version", false},
{"v0.31.1-dev", false},
{"", false},
}
for _, tt := range tests {
t.Run(tt.version, func(t *testing.T) {
if got := IsSupported(tt.version); got != tt.want {
t.Errorf("IsSupported() = %v, want %v", got, tt.want)
}
})
}
}

View File

@@ -0,0 +1,14 @@
package lazyconn
import (
"net"
"net/netip"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type WGIface interface {
RemovePeer(peerKey string) error
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
}

View File

@@ -19,7 +19,7 @@ import (
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
if err != nil {
return fmt.Errorf("open routing socket: %v", err)
}
defer func() {
err := unix.Close(fd)

View File

@@ -13,7 +13,7 @@ import (
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
routeMonitor, err := systemops.NewRouteMonitor(ctx)
if err != nil {
return fmt.Errorf("create route monitor: %w", err)
}
defer func() {
if err := routeMonitor.Stop(); err != nil {
@@ -38,35 +38,49 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) er
}
func routeChanged(route systemops.RouteUpdate, nexthopv4, nexthopv6 systemops.Nexthop) bool {
if intf := route.NextHop.Intf; intf != nil && isSoftInterface(intf.Name) {
log.Debugf("Network monitor: ignoring default route change for next hop with soft interface %s", route.NextHop)
return false
}
// TODO: for the empty nexthop ip (on-link), determine the family differently
nexthop := nexthopv4
if route.NextHop.IP.Is6() {
nexthop = nexthopv6
}
switch route.Type {
case systemops.RouteModified, systemops.RouteAdded:
return handleRouteAddedOrModified(route, nexthop)
case systemops.RouteDeleted:
return handleRouteDeleted(route, nexthop)
}
return false
}
func handleRouteAddedOrModified(route systemops.RouteUpdate, nexthop systemops.Nexthop) bool {
// For added/modified routes, we care about different next hops
if !nexthop.Equal(route.NextHop) {
action := "changed"
if route.Type == systemops.RouteAdded {
action = "added"
}
log.Infof("Network monitor: default route %s: via %s", action, route.NextHop)
return true
}
return false
}
func handleRouteDeleted(route systemops.RouteUpdate, nexthop systemops.Nexthop) bool {
// For deleted routes, we care about our tracked next hop being deleted
if nexthop.Equal(route.NextHop) {
log.Infof("Network monitor: default route removed: via %s", route.NextHop)
return true
}
return false
}
func isSoftInterface(name string) bool {
return strings.Contains(strings.ToLower(name), "isatap") || strings.Contains(strings.ToLower(name), "teredo")
}

View File

@@ -0,0 +1,404 @@
package networkmonitor
import (
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
func TestRouteChanged(t *testing.T) {
tests := []struct {
name string
route systemops.RouteUpdate
nexthopv4 systemops.Nexthop
nexthopv6 systemops.Nexthop
expected bool
}{
{
name: "soft interface should be ignored",
route: systemops.RouteUpdate{
Type: systemops.RouteModified,
Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
NextHop: systemops.Nexthop{
IP: netip.MustParseAddr("192.168.1.1"),
Intf: &net.Interface{
Name: "ISATAP-Interface", // isSoftInterface checks name
},
},
},
nexthopv4: systemops.Nexthop{
IP: netip.MustParseAddr("192.168.1.2"),
},
nexthopv6: systemops.Nexthop{
IP: netip.MustParseAddr("2001:db8::1"),
},
expected: false,
},
{
name: "modified route with different v4 nexthop IP should return true",
route: systemops.RouteUpdate{
Type: systemops.RouteModified,
Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
NextHop: systemops.Nexthop{
IP: netip.MustParseAddr("192.168.1.1"),
Intf: &net.Interface{
Index: 1, Name: "eth0",
},
},
},
nexthopv4: systemops.Nexthop{
IP: netip.MustParseAddr("192.168.1.2"),
Intf: &net.Interface{
Index: 1, Name: "eth0",
},
},
nexthopv6: systemops.Nexthop{
IP: netip.MustParseAddr("2001:db8::1"),
},
expected: true,
},
{
name: "modified route with same v4 nexthop (IP and Intf Index) should return false",
route: systemops.RouteUpdate{
Type: systemops.RouteModified,
Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
NextHop: systemops.Nexthop{
IP: netip.MustParseAddr("192.168.1.1"),
Intf: &net.Interface{
Index: 1, Name: "eth0",
},
},
},
nexthopv4: systemops.Nexthop{
IP: netip.MustParseAddr("192.168.1.1"),
Intf: &net.Interface{
Index: 1, Name: "eth0",
},
},
nexthopv6: systemops.Nexthop{
IP: netip.MustParseAddr("2001:db8::1"),
},
expected: false,
},
{
name: "added route with different v6 nexthop IP should return true",
route: systemops.RouteUpdate{
Type: systemops.RouteAdded,
Destination: netip.PrefixFrom(netip.IPv6Unspecified(), 0),
NextHop: systemops.Nexthop{
IP: netip.MustParseAddr("2001:db8::2"),
Intf: &net.Interface{
Index: 1, Name: "eth0",
},
},
},
nexthopv4: systemops.Nexthop{
IP: netip.MustParseAddr("192.168.1.1"),
},
nexthopv6: systemops.Nexthop{
IP: netip.MustParseAddr("2001:db8::1"),
Intf: &net.Interface{
Index: 1, Name: "eth0",
},
},
expected: true,
},
{
name: "added route with same v6 nexthop (IP and Intf Index) should return false",
route: systemops.RouteUpdate{
Type: systemops.RouteAdded,
Destination: netip.PrefixFrom(netip.IPv6Unspecified(), 0),
NextHop: systemops.Nexthop{
IP: netip.MustParseAddr("2001:db8::1"),
Intf: &net.Interface{
Index: 1, Name: "eth0",
},
},
},
nexthopv4: systemops.Nexthop{
IP: netip.MustParseAddr("192.168.1.1"),
},
nexthopv6: systemops.Nexthop{
IP: netip.MustParseAddr("2001:db8::1"),
Intf: &net.Interface{
Index: 1, Name: "eth0",
},
},
expected: false,
},
{
name: "deleted route matching tracked v4 nexthop (IP and Intf Index) should return true",
route: systemops.RouteUpdate{
Type: systemops.RouteDeleted,
Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
NextHop: systemops.Nexthop{
IP: netip.MustParseAddr("192.168.1.1"),
Intf: &net.Interface{
Index: 1, Name: "eth0",
},
},
},
nexthopv4: systemops.Nexthop{
IP: netip.MustParseAddr("192.168.1.1"),
Intf: &net.Interface{
Index: 1, Name: "eth0",
},
},
nexthopv6: systemops.Nexthop{
IP: netip.MustParseAddr("2001:db8::1"),
},
expected: true,
},
{
name: "deleted route not matching tracked v4 nexthop (different IP) should return false",
route: systemops.RouteUpdate{
Type: systemops.RouteDeleted,
Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
NextHop: systemops.Nexthop{
IP: netip.MustParseAddr("192.168.1.3"), // Different IP
Intf: &net.Interface{
Index: 1, Name: "eth0",
},
},
},
nexthopv4: systemops.Nexthop{
IP: netip.MustParseAddr("192.168.1.1"),
Intf: &net.Interface{
Index: 1, Name: "eth0",
},
},
nexthopv6: systemops.Nexthop{
IP: netip.MustParseAddr("2001:db8::1"),
},
expected: false,
},
{
name: "modified v4 route with same IP, different Intf Index should return true",
route: systemops.RouteUpdate{
Type: systemops.RouteModified,
Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
NextHop: systemops.Nexthop{
IP: netip.MustParseAddr("192.168.1.1"),
Intf: &net.Interface{Index: 2, Name: "eth1"}, // Different Intf Index
},
},
nexthopv4: systemops.Nexthop{
IP: netip.MustParseAddr("192.168.1.1"),
Intf: &net.Interface{Index: 1, Name: "eth0"},
},
expected: true,
},
{
name: "modified v4 route with same IP, one Intf nil, other non-nil should return true",
route: systemops.RouteUpdate{
Type: systemops.RouteModified,
Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
NextHop: systemops.Nexthop{
IP: netip.MustParseAddr("192.168.1.1"),
Intf: nil, // Intf is nil
},
},
nexthopv4: systemops.Nexthop{
IP: netip.MustParseAddr("192.168.1.1"),
Intf: &net.Interface{Index: 1, Name: "eth0"}, // Tracked Intf is not nil
},
expected: true,
},
{
name: "added v4 route with same IP, different Intf Index should return true",
route: systemops.RouteUpdate{
Type: systemops.RouteAdded,
Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
NextHop: systemops.Nexthop{
IP: netip.MustParseAddr("192.168.1.1"),
Intf: &net.Interface{Index: 2, Name: "eth1"}, // Different Intf Index
},
},
nexthopv4: systemops.Nexthop{
IP: netip.MustParseAddr("192.168.1.1"),
Intf: &net.Interface{Index: 1, Name: "eth0"},
},
expected: true,
},
{
name: "deleted v4 route with same IP, different Intf Index should return false",
route: systemops.RouteUpdate{
Type: systemops.RouteDeleted,
Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
NextHop: systemops.Nexthop{ // This is the route being deleted
IP: netip.MustParseAddr("192.168.1.1"),
Intf: &net.Interface{Index: 1, Name: "eth0"},
},
},
nexthopv4: systemops.Nexthop{ // This is our tracked nexthop
IP: netip.MustParseAddr("192.168.1.1"),
Intf: &net.Interface{Index: 2, Name: "eth1"}, // Different Intf Index
},
expected: false, // Because nexthopv4.Equal(route.NextHop) will be false
},
{
name: "modified v6 route with different IP, same Intf Index should return true",
route: systemops.RouteUpdate{
Type: systemops.RouteModified,
Destination: netip.PrefixFrom(netip.IPv6Unspecified(), 0),
NextHop: systemops.Nexthop{
IP: netip.MustParseAddr("2001:db8::3"), // Different IP
Intf: &net.Interface{Index: 1, Name: "eth0"},
},
},
nexthopv6: systemops.Nexthop{
IP: netip.MustParseAddr("2001:db8::1"),
Intf: &net.Interface{Index: 1, Name: "eth0"},
},
expected: true,
},
{
name: "modified v6 route with same IP, different Intf Index should return true",
route: systemops.RouteUpdate{
Type: systemops.RouteModified,
Destination: netip.PrefixFrom(netip.IPv6Unspecified(), 0),
NextHop: systemops.Nexthop{
IP: netip.MustParseAddr("2001:db8::1"),
Intf: &net.Interface{Index: 2, Name: "eth1"}, // Different Intf Index
},
},
nexthopv6: systemops.Nexthop{
IP: netip.MustParseAddr("2001:db8::1"),
Intf: &net.Interface{Index: 1, Name: "eth0"},
},
expected: true,
},
{
name: "modified v6 route with same IP, same Intf Index should return false",
route: systemops.RouteUpdate{
Type: systemops.RouteModified,
Destination: netip.PrefixFrom(netip.IPv6Unspecified(), 0),
NextHop: systemops.Nexthop{
IP: netip.MustParseAddr("2001:db8::1"),
Intf: &net.Interface{Index: 1, Name: "eth0"},
},
},
nexthopv6: systemops.Nexthop{
IP: netip.MustParseAddr("2001:db8::1"),
Intf: &net.Interface{Index: 1, Name: "eth0"},
},
expected: false,
},
{
name: "deleted v6 route matching tracked nexthop (IP and Intf Index) should return true",
route: systemops.RouteUpdate{
Type: systemops.RouteDeleted,
Destination: netip.PrefixFrom(netip.IPv6Unspecified(), 0),
NextHop: systemops.Nexthop{
IP: netip.MustParseAddr("2001:db8::1"),
Intf: &net.Interface{Index: 1, Name: "eth0"},
},
},
nexthopv6: systemops.Nexthop{
IP: netip.MustParseAddr("2001:db8::1"),
Intf: &net.Interface{Index: 1, Name: "eth0"},
},
expected: true,
},
{
name: "deleted v6 route not matching tracked nexthop (different IP) should return false",
route: systemops.RouteUpdate{
Type: systemops.RouteDeleted,
Destination: netip.PrefixFrom(netip.IPv6Unspecified(), 0),
NextHop: systemops.Nexthop{
IP: netip.MustParseAddr("2001:db8::3"), // Different IP
Intf: &net.Interface{Index: 1, Name: "eth0"},
},
},
nexthopv6: systemops.Nexthop{
IP: netip.MustParseAddr("2001:db8::1"),
Intf: &net.Interface{Index: 1, Name: "eth0"},
},
expected: false,
},
{
name: "deleted v6 route not matching tracked nexthop (same IP, different Intf Index) should return false",
route: systemops.RouteUpdate{
Type: systemops.RouteDeleted,
Destination: netip.PrefixFrom(netip.IPv6Unspecified(), 0),
NextHop: systemops.Nexthop{ // This is the route being deleted
IP: netip.MustParseAddr("2001:db8::1"),
Intf: &net.Interface{Index: 1, Name: "eth0"},
},
},
nexthopv6: systemops.Nexthop{ // This is our tracked nexthop
IP: netip.MustParseAddr("2001:db8::1"),
Intf: &net.Interface{Index: 2, Name: "eth1"}, // Different Intf Index
},
expected: false,
},
{
name: "unknown route type should return false",
route: systemops.RouteUpdate{
Type: systemops.RouteUpdateType(99), // Unknown type
Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
NextHop: systemops.Nexthop{
IP: netip.MustParseAddr("192.168.1.1"),
Intf: &net.Interface{Index: 1, Name: "eth0"},
},
},
nexthopv4: systemops.Nexthop{
IP: netip.MustParseAddr("192.168.1.2"), // Different from route.NextHop
Intf: &net.Interface{Index: 1, Name: "eth0"},
},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := routeChanged(tt.route, tt.nexthopv4, tt.nexthopv6)
assert.Equal(t, tt.expected, result)
})
}
}
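
Read together, these cases pin down the Nexthop.Equal semantics the monitor relies on: two nexthops compare equal only when both the IP and the interface match, and a nil Intf never matches a non-nil one. A sketch of that (assumed, inferred from the table above) comparison:

func nexthopEqualSketch(a, b systemops.Nexthop) bool {
	if a.IP != b.IP {
		return false
	}
	if (a.Intf == nil) != (b.Intf == nil) {
		return false // nil vs non-nil interface never match
	}
	return a.Intf == nil || a.Intf.Index == b.Intf.Index
}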
func TestIsSoftInterface(t *testing.T) {
tests := []struct {
name string
ifname string
expected bool
}{
{
name: "ISATAP interface should be detected",
ifname: "ISATAP tunnel adapter",
expected: true,
},
{
name: "lowercase soft interface should be detected",
ifname: "isatap.{14A5CF17-CA72-43EC-B4EA-B4B093641B7D}",
expected: true,
},
{
name: "Teredo interface should be detected",
ifname: "Teredo Tunneling Pseudo-Interface",
expected: true,
},
{
name: "regular interface should not be detected as soft",
ifname: "eth0",
expected: false,
},
{
name: "another regular interface should not be detected as soft",
ifname: "wlan0",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isSoftInterface(tt.ifname)
assert.Equal(t, tt.expected, result)
})
}
}


@ -118,9 +118,12 @@ func (nw *NetworkMonitor) Stop() {
} }
func (nw *NetworkMonitor) checkChanges(ctx context.Context, event chan struct{}, nexthop4 systemops.Nexthop, nexthop6 systemops.Nexthop) { func (nw *NetworkMonitor) checkChanges(ctx context.Context, event chan struct{}, nexthop4 systemops.Nexthop, nexthop6 systemops.Nexthop) {
defer close(event)
for { for {
if err := checkChangeFn(ctx, nexthop4, nexthop6); err != nil { if err := checkChangeFn(ctx, nexthop4, nexthop6); err != nil {
close(event) if !errors.Is(err, context.Canceled) {
log.Errorf("Network monitor: failed to check for changes: %v", err)
}
return return
} }
// prevent blocking // prevent blocking
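
The change above is the close-once pattern: closing the event channel in a defer guarantees it happens on every exit path, while plain context cancellation is no longer reported as an error. A generic sketch of the same pattern (names hypothetical):

func watchSketch(ctx context.Context, events chan struct{}, check func(context.Context) error) {
	defer close(events) // runs on every return path, so consumers always unblock
	for {
		if err := check(ctx); err != nil {
			if !errors.Is(err, context.Canceled) {
				log.Errorf("watcher failed: %v", err) // log real failures only
			}
			return // cancellation exits quietly
		}
	}
}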


@ -17,8 +17,12 @@ import (
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/client/internal/peer/conntype"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
"github.com/netbirdio/netbird/client/internal/peer/guard" "github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/peer/id"
"github.com/netbirdio/netbird/client/internal/peer/worker"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
relayClient "github.com/netbirdio/netbird/relay/client" relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
@ -26,32 +30,20 @@ import (
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
) )
type ConnPriority int
func (cp ConnPriority) String() string {
switch cp {
case connPriorityNone:
return "None"
case connPriorityRelay:
return "PriorityRelay"
case connPriorityICETurn:
return "PriorityICETurn"
case connPriorityICEP2P:
return "PriorityICEP2P"
default:
return fmt.Sprintf("ConnPriority(%d)", cp)
}
}
const ( const (
defaultWgKeepAlive = 25 * time.Second defaultWgKeepAlive = 25 * time.Second
connPriorityNone ConnPriority = 0
connPriorityRelay ConnPriority = 1
connPriorityICETurn ConnPriority = 2
connPriorityICEP2P ConnPriority = 3
) )
type ServiceDependencies struct {
StatusRecorder *Status
Signaler *Signaler
IFaceDiscover stdnet.ExternalIFaceDiscover
RelayManager *relayClient.Manager
SrWatcher *guard.SRWatcher
Semaphore *semaphoregroup.SemaphoreGroup
PeerConnDispatcher *dispatcher.ConnectionDispatcher
}
type WgConfig struct { type WgConfig struct {
WgListenPort int WgListenPort int
RemoteKey string RemoteKey string
@ -76,6 +68,8 @@ type ConnConfig struct {
// LocalKey is a public key of a local peer // LocalKey is a public key of a local peer
LocalKey string LocalKey string
AgentVersion string
Timeout time.Duration Timeout time.Duration
WgConfig WgConfig WgConfig WgConfig
@ -89,22 +83,23 @@ type ConnConfig struct {
} }
type Conn struct { type Conn struct {
log *log.Entry Log *log.Entry
mu sync.Mutex mu sync.Mutex
ctx context.Context ctx context.Context
ctxCancel context.CancelFunc ctxCancel context.CancelFunc
config ConnConfig config ConnConfig
statusRecorder *Status statusRecorder *Status
signaler *Signaler signaler *Signaler
iFaceDiscover stdnet.ExternalIFaceDiscover
relayManager *relayClient.Manager relayManager *relayClient.Manager
handshaker *Handshaker srWatcher *guard.SRWatcher
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string) onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
onDisconnected func(remotePeer string) onDisconnected func(remotePeer string)
statusRelay *AtomicConnStatus statusRelay *worker.AtomicWorkerStatus
statusICE *AtomicConnStatus statusICE *worker.AtomicWorkerStatus
currentConnPriority ConnPriority currentConnPriority conntype.ConnPriority
opened bool // prevents closing a connection that was never opened opened bool // prevents closing a connection that was never opened
workerICE *WorkerICE workerICE *WorkerICE
@ -120,9 +115,12 @@ type Conn struct {
wgProxyICE wgproxy.Proxy wgProxyICE wgproxy.Proxy
wgProxyRelay wgproxy.Proxy wgProxyRelay wgproxy.Proxy
handshaker *Handshaker
guard *guard.Guard guard *guard.Guard
semaphore *semaphoregroup.SemaphoreGroup semaphore *semaphoregroup.SemaphoreGroup
peerConnDispatcher *dispatcher.ConnectionDispatcher
wg sync.WaitGroup
// debug purpose // debug purpose
dumpState *stateDump dumpState *stateDump
@ -130,91 +128,101 @@ type Conn struct {
// NewConn creates a new, not yet opened Conn to the remote peer. // NewConn creates a new, not yet opened Conn to the remote peer.
// To establish a connection, run Conn.Open // To establish a connection, run Conn.Open
func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher, semaphore *semaphoregroup.SemaphoreGroup) (*Conn, error) { func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
if len(config.WgConfig.AllowedIps) == 0 { if len(config.WgConfig.AllowedIps) == 0 {
return nil, fmt.Errorf("allowed IPs is empty") return nil, fmt.Errorf("allowed IPs is empty")
} }
ctx, ctxCancel := context.WithCancel(engineCtx)
connLog := log.WithField("peer", config.Key) connLog := log.WithField("peer", config.Key)
var conn = &Conn{ var conn = &Conn{
log: connLog, Log: connLog,
ctx: ctx,
ctxCancel: ctxCancel,
config: config, config: config,
statusRecorder: statusRecorder, statusRecorder: services.StatusRecorder,
signaler: signaler, signaler: services.Signaler,
relayManager: relayManager, iFaceDiscover: services.IFaceDiscover,
statusRelay: NewAtomicConnStatus(), relayManager: services.RelayManager,
statusICE: NewAtomicConnStatus(), srWatcher: services.SrWatcher,
semaphore: semaphore, semaphore: services.Semaphore,
dumpState: newStateDump(config.Key, connLog, statusRecorder), peerConnDispatcher: services.PeerConnDispatcher,
statusRelay: worker.NewAtomicStatus(),
statusICE: worker.NewAtomicStatus(),
dumpState: newStateDump(config.Key, connLog, services.StatusRecorder),
} }
ctrl := isController(config)
conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, conn, relayManager, conn.dumpState)
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
workerICE, err := NewWorkerICE(ctx, connLog, config, conn, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally)
if err != nil {
return nil, err
}
conn.workerICE = workerICE
conn.handshaker = NewHandshaker(ctx, connLog, config, signaler, conn.workerICE, conn.workerRelay)
conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer)
if os.Getenv("NB_FORCE_RELAY") != "true" {
conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer)
}
conn.guard = guard.NewGuard(connLog, ctrl, conn.isConnectedOnAllWay, config.Timeout, srWatcher)
go conn.handshaker.Listen()
go conn.dumpState.Start(ctx)
return conn, nil return conn, nil
} }
// Open opens a connection to the remote peer // Open opens a connection to the remote peer
// It will try to establish a connection using ICE and, in parallel, relay. The higher-priority connection type will // It will try to establish a connection using ICE and, in parallel, relay. The higher-priority connection type will
// be used. // be used.
func (conn *Conn) Open() { func (conn *Conn) Open(engineCtx context.Context) error {
conn.semaphore.Add(conn.ctx) conn.semaphore.Add(engineCtx)
conn.log.Debugf("open connection to peer")
conn.mu.Lock() conn.mu.Lock()
defer conn.mu.Unlock() defer conn.mu.Unlock()
conn.opened = true
if conn.opened {
conn.semaphore.Done(engineCtx)
return nil
}
conn.ctx, conn.ctxCancel = context.WithCancel(engineCtx)
conn.workerRelay = NewWorkerRelay(conn.Log, isController(conn.config), conn.config, conn, conn.relayManager, conn.dumpState)
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
if err != nil {
return err
}
conn.workerICE = workerICE
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay)
conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer)
if os.Getenv("NB_FORCE_RELAY") != "true" {
conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer)
}
conn.guard = guard.NewGuard(conn.Log, conn.isConnectedOnAllWay, conn.config.Timeout, conn.srWatcher)
conn.wg.Add(1)
go func() {
defer conn.wg.Done()
conn.handshaker.Listen(conn.ctx)
}()
go conn.dumpState.Start(conn.ctx)
peerState := State{ peerState := State{
PubKey: conn.config.Key, PubKey: conn.config.Key,
IP: conn.config.WgConfig.AllowedIps[0].Addr().String(),
ConnStatusUpdate: time.Now(), ConnStatusUpdate: time.Now(),
ConnStatus: StatusDisconnected, ConnStatus: StatusConnecting,
Mux: new(sync.RWMutex), Mux: new(sync.RWMutex),
} }
err := conn.statusRecorder.UpdatePeerState(peerState) if err := conn.statusRecorder.UpdatePeerState(peerState); err != nil {
if err != nil { conn.Log.Warnf("error while updating the state err: %v", err)
conn.log.Warnf("error while updating the state err: %v", err)
} }
go conn.startHandshakeAndReconnect(conn.ctx) conn.wg.Add(1)
} go func() {
defer conn.wg.Done()
func (conn *Conn) startHandshakeAndReconnect(ctx context.Context) { conn.waitInitialRandomSleepTime(conn.ctx)
defer conn.semaphore.Done(conn.ctx) conn.semaphore.Done(conn.ctx)
conn.waitInitialRandomSleepTime(ctx)
conn.dumpState.SendOffer() conn.dumpState.SendOffer()
err := conn.handshaker.sendOffer() if err := conn.handshaker.sendOffer(); err != nil {
if err != nil { conn.Log.Errorf("failed to send initial offer: %v", err)
conn.log.Errorf("failed to send initial offer: %v", err)
} }
go conn.guard.Start(ctx) conn.wg.Add(1)
go conn.listenGuardEvent(ctx) go func() {
conn.guard.Start(conn.ctx, conn.onGuardEvent)
conn.wg.Done()
}()
}()
conn.opened = true
return nil
} }
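
With this refactor, construction and opening are separate steps: NewConn is now cheap and side-effect free, while Open(engineCtx) builds the workers, handshaker, and guard. A minimal caller sketch against the new signatures (variable values are placeholders):

deps := peer.ServiceDependencies{
	StatusRecorder:     statusRecorder,
	Signaler:           signaler,
	RelayManager:       relayManager,
	SrWatcher:          srWatcher,
	Semaphore:          semaphoregroup.NewSemaphoreGroup(1),
	PeerConnDispatcher: dispatcher.NewConnectionDispatcher(),
}
conn, err := peer.NewConn(cfg, deps) // no goroutines started yet
if err != nil {
	return err
}
if err := conn.Open(engineCtx); err != nil { // workers, handshaker, and guard start here
	return err
}
defer conn.Close()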
// Close closes this peer Conn, issuing a close event to the Conn closeCh // Close closes this peer Conn, issuing a close event to the Conn closeCh
@ -223,14 +231,14 @@ func (conn *Conn) Close() {
defer conn.wgWatcherWg.Wait() defer conn.wgWatcherWg.Wait()
defer conn.mu.Unlock() defer conn.mu.Unlock()
conn.log.Infof("close peer connection")
conn.ctxCancel()
if !conn.opened { if !conn.opened {
conn.log.Debugf("ignore close connection to peer") conn.Log.Debugf("ignore close connection to peer")
return return
} }
conn.Log.Infof("close peer connection")
conn.ctxCancel()
conn.workerRelay.DisableWgWatcher() conn.workerRelay.DisableWgWatcher()
conn.workerRelay.CloseConn() conn.workerRelay.CloseConn()
conn.workerICE.Close() conn.workerICE.Close()
@ -238,7 +246,7 @@ func (conn *Conn) Close() {
if conn.wgProxyRelay != nil { if conn.wgProxyRelay != nil {
err := conn.wgProxyRelay.CloseConn() err := conn.wgProxyRelay.CloseConn()
if err != nil { if err != nil {
conn.log.Errorf("failed to close wg proxy for relay: %v", err) conn.Log.Errorf("failed to close wg proxy for relay: %v", err)
} }
conn.wgProxyRelay = nil conn.wgProxyRelay = nil
} }
@ -246,13 +254,13 @@ func (conn *Conn) Close() {
if conn.wgProxyICE != nil { if conn.wgProxyICE != nil {
err := conn.wgProxyICE.CloseConn() err := conn.wgProxyICE.CloseConn()
if err != nil { if err != nil {
conn.log.Errorf("failed to close wg proxy for ice: %v", err) conn.Log.Errorf("failed to close wg proxy for ice: %v", err)
} }
conn.wgProxyICE = nil conn.wgProxyICE = nil
} }
if err := conn.removeWgPeer(); err != nil { if err := conn.removeWgPeer(); err != nil {
conn.log.Errorf("failed to remove wg endpoint: %v", err) conn.Log.Errorf("failed to remove wg endpoint: %v", err)
} }
conn.freeUpConnID() conn.freeUpConnID()
@ -262,14 +270,16 @@ func (conn *Conn) Close() {
} }
conn.setStatusToDisconnected() conn.setStatusToDisconnected()
conn.log.Infof("peer connection has been closed") conn.opened = false
conn.wg.Wait()
conn.Log.Infof("peer connection closed")
} }
// OnRemoteAnswer handles an answer from the remote peer and returns true if the message was accepted, false otherwise // OnRemoteAnswer handles an answer from the remote peer and returns true if the message was accepted, false otherwise
// doesn't block; discards the message if the connection isn't ready // doesn't block; discards the message if the connection isn't ready
func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) bool { func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) bool {
conn.dumpState.RemoteAnswer() conn.dumpState.RemoteAnswer()
conn.log.Infof("OnRemoteAnswer, priority: %s, status ICE: %s, status relay: %s", conn.currentConnPriority, conn.statusICE, conn.statusRelay) conn.Log.Infof("OnRemoteAnswer, priority: %s, status ICE: %s, status relay: %s", conn.currentConnPriority, conn.statusICE, conn.statusRelay)
return conn.handshaker.OnRemoteAnswer(answer) return conn.handshaker.OnRemoteAnswer(answer)
} }
@ -298,7 +308,7 @@ func (conn *Conn) SetOnDisconnected(handler func(remotePeer string)) {
func (conn *Conn) OnRemoteOffer(offer OfferAnswer) bool { func (conn *Conn) OnRemoteOffer(offer OfferAnswer) bool {
conn.dumpState.RemoteOffer() conn.dumpState.RemoteOffer()
conn.log.Infof("OnRemoteOffer, on status ICE: %s, status Relay: %s", conn.statusICE, conn.statusRelay) conn.Log.Infof("OnRemoteOffer, on status ICE: %s, status Relay: %s", conn.statusICE, conn.statusRelay)
return conn.handshaker.OnRemoteOffer(offer) return conn.handshaker.OnRemoteOffer(offer)
} }
@ -307,19 +317,24 @@ func (conn *Conn) WgConfig() WgConfig {
return conn.config.WgConfig return conn.config.WgConfig
} }
// Status returns current status of the Conn // IsConnected is used by unit tests only
func (conn *Conn) Status() ConnStatus { // TODO: refactor unit tests to use the status recorder to manage connection status in peer.Conn
func (conn *Conn) IsConnected() bool {
conn.mu.Lock() conn.mu.Lock()
defer conn.mu.Unlock() defer conn.mu.Unlock()
return conn.evalStatus() return conn.currentConnPriority != conntype.None
} }
func (conn *Conn) GetKey() string { func (conn *Conn) GetKey() string {
return conn.config.Key return conn.config.Key
} }
func (conn *Conn) ConnID() id.ConnID {
return id.ConnID(conn)
}
// configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected // configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEConnInfo) { func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConnInfo ICEConnInfo) {
conn.mu.Lock() conn.mu.Lock()
defer conn.mu.Unlock() defer conn.mu.Unlock()
@ -327,21 +342,21 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC
return return
} }
if remoteConnNil(conn.log, iceConnInfo.RemoteConn) { if remoteConnNil(conn.Log, iceConnInfo.RemoteConn) {
conn.log.Errorf("remote ICE connection is nil") conn.Log.Errorf("remote ICE connection is nil")
return return
} }
// this should never happen, because Relay has lower priority and ICE always closes the deprecated connection before upgrading // this should never happen, because Relay has lower priority and ICE always closes the deprecated connection before upgrading
// todo: consider removing this check // todo: consider removing this check
if conn.currentConnPriority > priority { if conn.currentConnPriority > priority {
conn.log.Infof("current connection priority (%s) is higher than the new one (%s), do not upgrade connection", conn.currentConnPriority, priority) conn.Log.Infof("current connection priority (%s) is higher than the new one (%s), do not upgrade connection", conn.currentConnPriority, priority)
conn.statusICE.Set(StatusConnected) conn.statusICE.SetConnected()
conn.updateIceState(iceConnInfo) conn.updateIceState(iceConnInfo)
return return
} }
conn.log.Infof("set ICE to active connection") conn.Log.Infof("set ICE to active connection")
conn.dumpState.P2PConnected() conn.dumpState.P2PConnected()
var ( var (
@ -353,7 +368,7 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC
conn.dumpState.NewLocalProxy() conn.dumpState.NewLocalProxy()
wgProxy, err = conn.newProxy(iceConnInfo.RemoteConn) wgProxy, err = conn.newProxy(iceConnInfo.RemoteConn)
if err != nil { if err != nil {
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err) conn.Log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
return return
} }
ep = wgProxy.EndpointAddr() ep = wgProxy.EndpointAddr()
@ -369,7 +384,7 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC
} }
if err := conn.runBeforeAddPeerHooks(ep.IP); err != nil { if err := conn.runBeforeAddPeerHooks(ep.IP); err != nil {
conn.log.Errorf("Before add peer hook failed: %v", err) conn.Log.Errorf("Before add peer hook failed: %v", err)
} }
conn.workerRelay.DisableWgWatcher() conn.workerRelay.DisableWgWatcher()
@ -388,10 +403,16 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC
return return
} }
wgConfigWorkaround() wgConfigWorkaround()
oldState := conn.currentConnPriority
conn.currentConnPriority = priority conn.currentConnPriority = priority
conn.statusICE.Set(StatusConnected) conn.statusICE.SetConnected()
conn.updateIceState(iceConnInfo) conn.updateIceState(iceConnInfo)
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr) conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
if oldState == conntype.None {
conn.peerConnDispatcher.NotifyConnected(conn.ConnID())
}
} }
func (conn *Conn) onICEStateDisconnected() { func (conn *Conn) onICEStateDisconnected() {
@ -402,22 +423,22 @@ func (conn *Conn) onICEStateDisconnected() {
return return
} }
conn.log.Tracef("ICE connection state changed to disconnected") conn.Log.Tracef("ICE connection state changed to disconnected")
if conn.wgProxyICE != nil { if conn.wgProxyICE != nil {
if err := conn.wgProxyICE.CloseConn(); err != nil { if err := conn.wgProxyICE.CloseConn(); err != nil {
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err) conn.Log.Warnf("failed to close deprecated wg proxy conn: %v", err)
} }
} }
// switch back to relay connection // switch back to relay connection
if conn.isReadyToUpgrade() { if conn.isReadyToUpgrade() {
conn.log.Infof("ICE disconnected, set Relay to active connection") conn.Log.Infof("ICE disconnected, set Relay to active connection")
conn.dumpState.SwitchToRelay() conn.dumpState.SwitchToRelay()
conn.wgProxyRelay.Work() conn.wgProxyRelay.Work()
if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), conn.rosenpassRemoteKey); err != nil { if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), conn.rosenpassRemoteKey); err != nil {
conn.log.Errorf("failed to switch to relay conn: %v", err) conn.Log.Errorf("failed to switch to relay conn: %v", err)
} }
conn.wgWatcherWg.Add(1) conn.wgWatcherWg.Add(1)
@ -425,17 +446,18 @@ func (conn *Conn) onICEStateDisconnected() {
defer conn.wgWatcherWg.Done() defer conn.wgWatcherWg.Done()
conn.workerRelay.EnableWgWatcher(conn.ctx) conn.workerRelay.EnableWgWatcher(conn.ctx)
}() }()
conn.currentConnPriority = connPriorityRelay conn.currentConnPriority = conntype.Relay
} else { } else {
conn.log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", connPriorityNone.String()) conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())
conn.currentConnPriority = connPriorityNone conn.currentConnPriority = conntype.None
conn.peerConnDispatcher.NotifyDisconnected(conn.ConnID())
} }
changed := conn.statusICE.Get() != StatusDisconnected changed := conn.statusICE.Get() != worker.StatusDisconnected
if changed { if changed {
conn.guard.SetICEConnDisconnected() conn.guard.SetICEConnDisconnected()
} }
conn.statusICE.Set(StatusDisconnected) conn.statusICE.SetDisconnected()
peerState := State{ peerState := State{
PubKey: conn.config.Key, PubKey: conn.config.Key,
@ -446,7 +468,7 @@ func (conn *Conn) onICEStateDisconnected() {
err := conn.statusRecorder.UpdatePeerICEStateToDisconnected(peerState) err := conn.statusRecorder.UpdatePeerICEStateToDisconnected(peerState)
if err != nil { if err != nil {
conn.log.Warnf("unable to set peer's state to disconnected ice, got error: %v", err) conn.Log.Warnf("unable to set peer's state to disconnected ice, got error: %v", err)
} }
} }
@ -456,41 +478,41 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
if conn.ctx.Err() != nil { if conn.ctx.Err() != nil {
if err := rci.relayedConn.Close(); err != nil { if err := rci.relayedConn.Close(); err != nil {
conn.log.Warnf("failed to close unnecessary relayed connection: %v", err) conn.Log.Warnf("failed to close unnecessary relayed connection: %v", err)
} }
return return
} }
conn.dumpState.RelayConnected() conn.dumpState.RelayConnected()
conn.log.Debugf("Relay connection has been established, setup the WireGuard") conn.Log.Debugf("Relay connection has been established, setup the WireGuard")
wgProxy, err := conn.newProxy(rci.relayedConn) wgProxy, err := conn.newProxy(rci.relayedConn)
if err != nil { if err != nil {
conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err) conn.Log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
return return
} }
conn.dumpState.NewLocalProxy() conn.dumpState.NewLocalProxy()
conn.log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String()) conn.Log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
if conn.isICEActive() { if conn.isICEActive() {
conn.log.Infof("do not switch to relay because current priority is: %s", conn.currentConnPriority.String()) conn.Log.Debugf("do not switch to relay because current priority is: %s", conn.currentConnPriority.String())
conn.setRelayedProxy(wgProxy) conn.setRelayedProxy(wgProxy)
conn.statusRelay.Set(StatusConnected) conn.statusRelay.SetConnected()
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
return return
} }
if err := conn.runBeforeAddPeerHooks(wgProxy.EndpointAddr().IP); err != nil { if err := conn.runBeforeAddPeerHooks(wgProxy.EndpointAddr().IP); err != nil {
conn.log.Errorf("Before add peer hook failed: %v", err) conn.Log.Errorf("Before add peer hook failed: %v", err)
} }
wgProxy.Work() wgProxy.Work()
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil { if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil {
if err := wgProxy.CloseConn(); err != nil { if err := wgProxy.CloseConn(); err != nil {
conn.log.Warnf("Failed to close relay connection: %v", err) conn.Log.Warnf("Failed to close relay connection: %v", err)
} }
conn.log.Errorf("Failed to update WireGuard peer configuration: %v", err) conn.Log.Errorf("Failed to update WireGuard peer configuration: %v", err)
return return
} }
@ -502,12 +524,13 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
wgConfigWorkaround() wgConfigWorkaround()
conn.rosenpassRemoteKey = rci.rosenpassPubKey conn.rosenpassRemoteKey = rci.rosenpassPubKey
conn.currentConnPriority = connPriorityRelay conn.currentConnPriority = conntype.Relay
conn.statusRelay.Set(StatusConnected) conn.statusRelay.SetConnected()
conn.setRelayedProxy(wgProxy) conn.setRelayedProxy(wgProxy)
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
conn.log.Infof("start to communicate with peer via relay") conn.Log.Infof("start to communicate with peer via relay")
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr) conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
conn.peerConnDispatcher.NotifyConnected(conn.ConnID())
} }
func (conn *Conn) onRelayDisconnected() { func (conn *Conn) onRelayDisconnected() {
@ -518,14 +541,15 @@ func (conn *Conn) onRelayDisconnected() {
return return
} }
conn.log.Infof("relay connection is disconnected") conn.Log.Debugf("relay connection is disconnected")
if conn.currentConnPriority == connPriorityRelay { if conn.currentConnPriority == conntype.Relay {
conn.log.Infof("clean up WireGuard config") conn.Log.Debugf("clean up WireGuard config")
if err := conn.removeWgPeer(); err != nil { if err := conn.removeWgPeer(); err != nil {
conn.log.Errorf("failed to remove wg endpoint: %v", err) conn.Log.Errorf("failed to remove wg endpoint: %v", err)
} }
conn.currentConnPriority = connPriorityNone conn.currentConnPriority = conntype.None
conn.peerConnDispatcher.NotifyDisconnected(conn.ConnID())
} }
if conn.wgProxyRelay != nil { if conn.wgProxyRelay != nil {
@ -533,11 +557,11 @@ func (conn *Conn) onRelayDisconnected() {
conn.wgProxyRelay = nil conn.wgProxyRelay = nil
} }
changed := conn.statusRelay.Get() != StatusDisconnected changed := conn.statusRelay.Get() != worker.StatusDisconnected
if changed { if changed {
conn.guard.SetRelayedConnDisconnected() conn.guard.SetRelayedConnDisconnected()
} }
conn.statusRelay.Set(StatusDisconnected) conn.statusRelay.SetDisconnected()
peerState := State{ peerState := State{
PubKey: conn.config.Key, PubKey: conn.config.Key,
@ -546,22 +570,15 @@ func (conn *Conn) onRelayDisconnected() {
ConnStatusUpdate: time.Now(), ConnStatusUpdate: time.Now(),
} }
if err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState); err != nil { if err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState); err != nil {
conn.log.Warnf("unable to save peer's state to Relay disconnected, got error: %v", err) conn.Log.Warnf("unable to save peer's state to Relay disconnected, got error: %v", err)
} }
} }
func (conn *Conn) listenGuardEvent(ctx context.Context) { func (conn *Conn) onGuardEvent() {
for { conn.Log.Debugf("send offer to peer")
select {
case <-conn.guard.Reconnect:
conn.log.Infof("send offer to peer")
conn.dumpState.SendOffer() conn.dumpState.SendOffer()
if err := conn.handshaker.SendOffer(); err != nil { if err := conn.handshaker.SendOffer(); err != nil {
conn.log.Errorf("failed to send offer: %v", err) conn.Log.Errorf("failed to send offer: %v", err)
}
case <-ctx.Done():
return
}
} }
} }
@ -588,7 +605,7 @@ func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []by
err := conn.statusRecorder.UpdatePeerRelayedState(peerState) err := conn.statusRecorder.UpdatePeerRelayedState(peerState)
if err != nil { if err != nil {
conn.log.Warnf("unable to save peer's Relay state, got error: %v", err) conn.Log.Warnf("unable to save peer's Relay state, got error: %v", err)
} }
} }
@ -607,17 +624,18 @@ func (conn *Conn) updateIceState(iceConnInfo ICEConnInfo) {
err := conn.statusRecorder.UpdatePeerICEState(peerState) err := conn.statusRecorder.UpdatePeerICEState(peerState)
if err != nil { if err != nil {
conn.log.Warnf("unable to save peer's ICE state, got error: %v", err) conn.Log.Warnf("unable to save peer's ICE state, got error: %v", err)
} }
} }
func (conn *Conn) setStatusToDisconnected() { func (conn *Conn) setStatusToDisconnected() {
conn.statusRelay.Set(StatusDisconnected) conn.statusRelay.SetDisconnected()
conn.statusICE.Set(StatusDisconnected) conn.statusICE.SetDisconnected()
conn.currentConnPriority = conntype.None
peerState := State{ peerState := State{
PubKey: conn.config.Key, PubKey: conn.config.Key,
ConnStatus: StatusDisconnected, ConnStatus: StatusIdle,
ConnStatusUpdate: time.Now(), ConnStatusUpdate: time.Now(),
Mux: new(sync.RWMutex), Mux: new(sync.RWMutex),
} }
@ -625,10 +643,10 @@ func (conn *Conn) setStatusToDisconnected() {
if err != nil { if err != nil {
// pretty common error because by that time Engine can already remove the peer and status won't be available. // pretty common error because by that time Engine can already remove the peer and status won't be available.
// todo rethink status updates // todo rethink status updates
conn.log.Debugf("error while updating peer's state, err: %v", err) conn.Log.Debugf("error while updating peer's state, err: %v", err)
} }
if err := conn.statusRecorder.UpdateWireGuardPeerState(conn.config.Key, configurer.WGStats{}); err != nil { if err := conn.statusRecorder.UpdateWireGuardPeerState(conn.config.Key, configurer.WGStats{}); err != nil {
conn.log.Debugf("failed to reset wireguard stats for peer: %s", err) conn.Log.Debugf("failed to reset wireguard stats for peer: %s", err)
} }
} }
@ -656,27 +674,20 @@ func (conn *Conn) waitInitialRandomSleepTime(ctx context.Context) {
} }
func (conn *Conn) isRelayed() bool { func (conn *Conn) isRelayed() bool {
if conn.statusRelay.Get() == StatusDisconnected && (conn.statusICE.Get() == StatusDisconnected || conn.statusICE.Get() == StatusConnecting) { switch conn.currentConnPriority {
return false case conntype.Relay, conntype.ICETurn:
}
if conn.currentConnPriority == connPriorityICEP2P {
return false
}
return true return true
default:
return false
}
} }
func (conn *Conn) evalStatus() ConnStatus { func (conn *Conn) evalStatus() ConnStatus {
if conn.statusRelay.Get() == StatusConnected || conn.statusICE.Get() == StatusConnected { if conn.statusRelay.Get() == worker.StatusConnected || conn.statusICE.Get() == worker.StatusConnected {
return StatusConnected return StatusConnected
} }
if conn.statusRelay.Get() == StatusConnecting || conn.statusICE.Get() == StatusConnecting {
return StatusConnecting return StatusConnecting
}
return StatusDisconnected
} }
func (conn *Conn) isConnectedOnAllWay() (connected bool) { func (conn *Conn) isConnectedOnAllWay() (connected bool) {
@ -689,12 +700,12 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
} }
}() }()
if conn.statusICE.Get() == StatusDisconnected { if conn.statusICE.Get() == worker.StatusDisconnected {
return false return false
} }
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() { if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
if conn.statusRelay.Get() != StatusConnected { if conn.statusRelay.Get() == worker.StatusDisconnected {
return false return false
} }
} }
@ -716,7 +727,7 @@ func (conn *Conn) freeUpConnID() {
if conn.connIDRelay != "" { if conn.connIDRelay != "" {
for _, hook := range conn.afterRemovePeerHooks { for _, hook := range conn.afterRemovePeerHooks {
if err := hook(conn.connIDRelay); err != nil { if err := hook(conn.connIDRelay); err != nil {
conn.log.Errorf("After remove peer hook failed: %v", err) conn.Log.Errorf("After remove peer hook failed: %v", err)
} }
} }
conn.connIDRelay = "" conn.connIDRelay = ""
@ -725,7 +736,7 @@ func (conn *Conn) freeUpConnID() {
if conn.connIDICE != "" { if conn.connIDICE != "" {
for _, hook := range conn.afterRemovePeerHooks { for _, hook := range conn.afterRemovePeerHooks {
if err := hook(conn.connIDICE); err != nil { if err := hook(conn.connIDICE); err != nil {
conn.log.Errorf("After remove peer hook failed: %v", err) conn.Log.Errorf("After remove peer hook failed: %v", err)
} }
} }
conn.connIDICE = "" conn.connIDICE = ""
@ -733,7 +744,7 @@ func (conn *Conn) freeUpConnID() {
} }
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) { func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
conn.log.Debugf("setup proxied WireGuard connection") conn.Log.Debugf("setup proxied WireGuard connection")
udpAddr := &net.UDPAddr{ udpAddr := &net.UDPAddr{
IP: conn.config.WgConfig.AllowedIps[0].Addr().AsSlice(), IP: conn.config.WgConfig.AllowedIps[0].Addr().AsSlice(),
Port: conn.config.WgConfig.WgListenPort, Port: conn.config.WgConfig.WgListenPort,
@ -741,18 +752,18 @@ func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
wgProxy := conn.config.WgConfig.WgInterface.GetProxy() wgProxy := conn.config.WgConfig.WgInterface.GetProxy()
if err := wgProxy.AddTurnConn(conn.ctx, udpAddr, remoteConn); err != nil { if err := wgProxy.AddTurnConn(conn.ctx, udpAddr, remoteConn); err != nil {
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err) conn.Log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
return nil, err return nil, err
} }
return wgProxy, nil return wgProxy, nil
} }
func (conn *Conn) isReadyToUpgrade() bool { func (conn *Conn) isReadyToUpgrade() bool {
return conn.wgProxyRelay != nil && conn.currentConnPriority != connPriorityRelay return conn.wgProxyRelay != nil && conn.currentConnPriority != conntype.Relay
} }
func (conn *Conn) isICEActive() bool { func (conn *Conn) isICEActive() bool {
return (conn.currentConnPriority == connPriorityICEP2P || conn.currentConnPriority == connPriorityICETurn) && conn.statusICE.Get() == StatusConnected return (conn.currentConnPriority == conntype.ICEP2P || conn.currentConnPriority == conntype.ICETurn) && conn.statusICE.Get() == worker.StatusConnected
} }
func (conn *Conn) removeWgPeer() error { func (conn *Conn) removeWgPeer() error {
@ -760,10 +771,10 @@ func (conn *Conn) removeWgPeer() error {
} }
func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) { func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
conn.log.Warnf("Failed to update wg peer configuration: %v", err) conn.Log.Warnf("Failed to update wg peer configuration: %v", err)
if wgProxy != nil { if wgProxy != nil {
if ierr := wgProxy.CloseConn(); ierr != nil { if ierr := wgProxy.CloseConn(); ierr != nil {
conn.log.Warnf("Failed to close wg proxy: %v", ierr) conn.Log.Warnf("Failed to close wg proxy: %v", ierr)
} }
} }
if conn.wgProxyRelay != nil { if conn.wgProxyRelay != nil {
@ -773,16 +784,16 @@ func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
func (conn *Conn) logTraceConnState() { func (conn *Conn) logTraceConnState() {
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() { if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
conn.log.Tracef("connectivity guard check, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE) conn.Log.Tracef("connectivity guard check, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE)
} else { } else {
conn.log.Tracef("connectivity guard check, ice state: %s", conn.statusICE) conn.Log.Tracef("connectivity guard check, ice state: %s", conn.statusICE)
} }
} }
func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) { func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) {
if conn.wgProxyRelay != nil { if conn.wgProxyRelay != nil {
if err := conn.wgProxyRelay.CloseConn(); err != nil { if err := conn.wgProxyRelay.CloseConn(); err != nil {
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err) conn.Log.Warnf("failed to close deprecated wg proxy conn: %v", err)
} }
} }
conn.wgProxyRelay = proxy conn.wgProxyRelay = proxy
@ -793,6 +804,10 @@ func (conn *Conn) AllowedIP() netip.Addr {
return conn.config.WgConfig.AllowedIps[0].Addr() return conn.config.WgConfig.AllowedIps[0].Addr()
} }
func (conn *Conn) AgentVersionString() string {
return conn.config.AgentVersion
}
func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key { func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
if conn.config.RosenpassConfig.PubKey == nil { if conn.config.RosenpassConfig.PubKey == nil {
return conn.config.WgConfig.PreSharedKey return conn.config.WgConfig.PreSharedKey
@ -804,7 +819,7 @@ func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
determKey, err := conn.rosenpassDetermKey() determKey, err := conn.rosenpassDetermKey()
if err != nil { if err != nil {
conn.log.Errorf("failed to generate Rosenpass initial key: %v", err) conn.Log.Errorf("failed to generate Rosenpass initial key: %v", err)
return conn.config.WgConfig.PreSharedKey return conn.config.WgConfig.PreSharedKey
} }


@ -1,58 +1,29 @@
package peer package peer
import ( import (
"sync/atomic"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
const ( const (
// StatusConnected indicate the peer is in connected state // StatusIdle indicates the peer is in disconnected state
StatusConnected ConnStatus = iota StatusIdle ConnStatus = iota
// StatusConnecting indicates the peer is in connecting state // StatusConnecting indicates the peer is in connecting state
StatusConnecting StatusConnecting
// StatusDisconnected indicate the peer is in disconnected state // StatusConnected indicates the peer is in connected state
StatusDisconnected StatusConnected
) )
// ConnStatus describes the status of a peer's connection // ConnStatus describes the status of a peer's connection
type ConnStatus int32 type ConnStatus int32
// AtomicConnStatus is a thread-safe wrapper for ConnStatus
type AtomicConnStatus struct {
status atomic.Int32
}
// NewAtomicConnStatus creates a new AtomicConnStatus with the given initial status
func NewAtomicConnStatus() *AtomicConnStatus {
acs := &AtomicConnStatus{}
acs.Set(StatusDisconnected)
return acs
}
// Get returns the current connection status
func (acs *AtomicConnStatus) Get() ConnStatus {
return ConnStatus(acs.status.Load())
}
// Set updates the connection status
func (acs *AtomicConnStatus) Set(status ConnStatus) {
acs.status.Store(int32(status))
}
// String returns the string representation of the current status
func (acs *AtomicConnStatus) String() string {
return acs.Get().String()
}
func (s ConnStatus) String() string { func (s ConnStatus) String() string {
switch s { switch s {
case StatusConnecting: case StatusConnecting:
return "Connecting" return "Connecting"
case StatusConnected: case StatusConnected:
return "Connected" return "Connected"
case StatusDisconnected: case StatusIdle:
return "Disconnected" return "Idle"
default: default:
log.Errorf("unknown status: %d", s) log.Errorf("unknown status: %d", s)
return "INVALID_PEER_CONNECTION_STATUS" return "INVALID_PEER_CONNECTION_STATUS"


@ -14,7 +14,7 @@ func TestConnStatus_String(t *testing.T) {
want string want string
}{ }{
{"StatusConnected", StatusConnected, "Connected"}, {"StatusConnected", StatusConnected, "Connected"},
{"StatusDisconnected", StatusDisconnected, "Disconnected"}, {"StatusIdle", StatusIdle, "Idle"},
{"StatusConnecting", StatusConnecting, "Connecting"}, {"StatusConnecting", StatusConnecting, "Connecting"},
} }
@ -24,5 +24,4 @@ func TestConnStatus_String(t *testing.T) {
assert.Equal(t, got, table.want, "they should be equal") assert.Equal(t, got, table.want, "they should be equal")
}) })
} }
} }


@ -1,7 +1,6 @@
package peer package peer
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"sync" "sync"
@ -11,6 +10,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
"github.com/netbirdio/netbird/client/internal/peer/guard" "github.com/netbirdio/netbird/client/internal/peer/guard"
"github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
@ -18,6 +18,8 @@ import (
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
) )
var testDispatcher = dispatcher.NewConnectionDispatcher()
var connConf = ConnConfig{ var connConf = ConnConfig{
Key: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=", Key: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=", LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
@ -48,7 +50,13 @@ func TestNewConn_interfaceFilter(t *testing.T) {
func TestConn_GetKey(t *testing.T) { func TestConn_GetKey(t *testing.T) {
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig) swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
conn, err := NewConn(context.Background(), connConf, nil, nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
sd := ServiceDependencies{
SrWatcher: swWatcher,
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
PeerConnDispatcher: testDispatcher,
}
conn, err := NewConn(connConf, sd)
if err != nil { if err != nil {
return return
} }
@ -60,7 +68,13 @@ func TestConn_GetKey(t *testing.T) {
func TestConn_OnRemoteOffer(t *testing.T) { func TestConn_OnRemoteOffer(t *testing.T) {
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig) swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1)) sd := ServiceDependencies{
StatusRecorder: NewRecorder("https://mgm"),
SrWatcher: swWatcher,
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
PeerConnDispatcher: testDispatcher,
}
conn, err := NewConn(connConf, sd)
if err != nil { if err != nil {
return return
} }
@ -94,7 +108,13 @@ func TestConn_OnRemoteOffer(t *testing.T) {
func TestConn_OnRemoteAnswer(t *testing.T) { func TestConn_OnRemoteAnswer(t *testing.T) {
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig) swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1)) sd := ServiceDependencies{
StatusRecorder: NewRecorder("https://mgm"),
SrWatcher: swWatcher,
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
PeerConnDispatcher: testDispatcher,
}
conn, err := NewConn(connConf, sd)
if err != nil { if err != nil {
return return
} }
@ -125,43 +145,6 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
wg.Wait() wg.Wait()
} }
func TestConn_Status(t *testing.T) {
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
if err != nil {
return
}
tables := []struct {
name string
statusIce ConnStatus
statusRelay ConnStatus
want ConnStatus
}{
{"StatusConnected", StatusConnected, StatusConnected, StatusConnected},
{"StatusDisconnected", StatusDisconnected, StatusDisconnected, StatusDisconnected},
{"StatusConnecting", StatusConnecting, StatusConnecting, StatusConnecting},
{"StatusConnectingIce", StatusConnecting, StatusDisconnected, StatusConnecting},
{"StatusConnectingIceAlternative", StatusConnecting, StatusConnected, StatusConnected},
{"StatusConnectingRelay", StatusDisconnected, StatusConnecting, StatusConnecting},
{"StatusConnectingRelayAlternative", StatusConnected, StatusConnecting, StatusConnected},
}
for _, table := range tables {
t.Run(table.name, func(t *testing.T) {
si := NewAtomicConnStatus()
si.Set(table.statusIce)
conn.statusICE = si
sr := NewAtomicConnStatus()
sr.Set(table.statusRelay)
conn.statusRelay = sr
got := conn.Status()
assert.Equal(t, got, table.want, "they should be equal")
})
}
}
func TestConn_presharedKey(t *testing.T) { func TestConn_presharedKey(t *testing.T) {
conn1 := Conn{ conn1 := Conn{


@ -0,0 +1,29 @@
package conntype
import (
"fmt"
)
const (
None ConnPriority = 0
Relay ConnPriority = 1
ICETurn ConnPriority = 2
ICEP2P ConnPriority = 3
)
type ConnPriority int
func (cp ConnPriority) String() string {
switch cp {
case None:
return "None"
case Relay:
return "PriorityRelay"
case ICETurn:
return "PriorityICETurn"
case ICEP2P:
return "PriorityICEP2P"
default:
return fmt.Sprintf("ConnPriority(%d)", cp)
}
}
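
Since ConnPriority is an ordered integer, upgrade decisions reduce to a plain comparison (as in the conn.currentConnPriority > priority check earlier); for illustration:

current := conntype.Relay  // 1
offered := conntype.ICEP2P // 3
if offered > current {
	// a direct P2P connection outranks the relayed one: upgrade
}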


@ -0,0 +1,52 @@
package dispatcher
import (
"sync"
"github.com/netbirdio/netbird/client/internal/peer/id"
)
type ConnectionListener struct {
OnConnected func(peerID id.ConnID)
OnDisconnected func(peerID id.ConnID)
}
type ConnectionDispatcher struct {
listeners map[*ConnectionListener]struct{}
mu sync.Mutex
}
func NewConnectionDispatcher() *ConnectionDispatcher {
return &ConnectionDispatcher{
listeners: make(map[*ConnectionListener]struct{}),
}
}
func (e *ConnectionDispatcher) AddListener(listener *ConnectionListener) {
e.mu.Lock()
defer e.mu.Unlock()
e.listeners[listener] = struct{}{}
}
func (e *ConnectionDispatcher) RemoveListener(listener *ConnectionListener) {
e.mu.Lock()
defer e.mu.Unlock()
delete(e.listeners, listener)
}
func (e *ConnectionDispatcher) NotifyConnected(peerConnID id.ConnID) {
e.mu.Lock()
defer e.mu.Unlock()
for listener := range e.listeners {
listener.OnConnected(peerConnID)
}
}
func (e *ConnectionDispatcher) NotifyDisconnected(peerConnID id.ConnID) {
e.mu.Lock()
defer e.mu.Unlock()
for listener := range e.listeners {
listener.OnDisconnected(peerConnID)
}
}
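
A short usage sketch of the dispatcher (handler bodies are hypothetical):

d := dispatcher.NewConnectionDispatcher()
listener := &dispatcher.ConnectionListener{
	OnConnected:    func(peerID id.ConnID) { log.Infof("peer conn %v is up", peerID) },
	OnDisconnected: func(peerID id.ConnID) { log.Infof("peer conn %v is down", peerID) },
}
d.AddListener(listener)
defer d.RemoveListener(listener)

Note that NotifyConnected and NotifyDisconnected hold the mutex while invoking callbacks, so listeners should return quickly and must not call back into the dispatcher from a callback, or they will deadlock.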


@ -8,10 +8,6 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
const (
reconnectMaxElapsedTime = 30 * time.Minute
)
type isConnectedFunc func() bool type isConnectedFunc func() bool
// Guard is responsible for the reconnection logic. // Guard is responsible for the reconnection logic.
@ -25,7 +21,6 @@ type isConnectedFunc func() bool
type Guard struct { type Guard struct {
Reconnect chan struct{} Reconnect chan struct{}
log *log.Entry log *log.Entry
isController bool
isConnectedOnAllWay isConnectedFunc isConnectedOnAllWay isConnectedFunc
timeout time.Duration timeout time.Duration
srWatcher *SRWatcher srWatcher *SRWatcher
@ -33,11 +28,10 @@ type Guard struct {
iCEConnDisconnected chan struct{} iCEConnDisconnected chan struct{}
} }
func NewGuard(log *log.Entry, isController bool, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard { func NewGuard(log *log.Entry, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
return &Guard{ return &Guard{
Reconnect: make(chan struct{}, 1), Reconnect: make(chan struct{}, 1),
log: log, log: log,
isController: isController,
isConnectedOnAllWay: isConnectedFn, isConnectedOnAllWay: isConnectedFn,
timeout: timeout, timeout: timeout,
srWatcher: srWatcher, srWatcher: srWatcher,
@ -46,12 +40,8 @@ func NewGuard(log *log.Entry, isController bool, isConnectedFn isConnectedFunc,
} }
} }
func (g *Guard) Start(ctx context.Context) { func (g *Guard) Start(ctx context.Context, eventCallback func()) {
if g.isController { g.reconnectLoopWithRetry(ctx, eventCallback)
g.reconnectLoopWithRetry(ctx)
} else {
g.listenForDisconnectEvents(ctx)
}
} }
func (g *Guard) SetRelayedConnDisconnected() { func (g *Guard) SetRelayedConnDisconnected() {
@ -68,9 +58,9 @@ func (g *Guard) SetICEConnDisconnected() {
} }
} }
// reconnectLoopWithRetry periodically check (max 30 min) the connection status. // reconnectLoopWithRetry periodically checks the connection status.
// Try to send an offer while the P2P connection is not established, or while the Relay is not connected if Relay is supported // Try to send an offer while the P2P connection is not established, or while the Relay is not connected if Relay is supported
func (g *Guard) reconnectLoopWithRetry(ctx context.Context) { func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
waitForInitialConnectionTry(ctx) waitForInitialConnectionTry(ctx)
srReconnectedChan := g.srWatcher.NewListener() srReconnectedChan := g.srWatcher.NewListener()
@ -93,7 +83,7 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context) {
} }
if !g.isConnectedOnAllWay() { if !g.isConnectedOnAllWay() {
g.triggerOfferSending() callback()
} }
case <-g.relayedConnDisconnected: case <-g.relayedConnDisconnected:
@ -121,39 +111,12 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context) {
} }
} }
// listenForDisconnectEvents is used when the peer is not a controller and it should reconnect to the peer
// when the connection is lost. It will try to establish a connection only once time if before the connection was established
// It track separately the ice and relay connection status. Just because a lower priority connection reestablished it does not
// mean that to switch to it. We always force to use the higher priority connection.
func (g *Guard) listenForDisconnectEvents(ctx context.Context) {
srReconnectedChan := g.srWatcher.NewListener()
defer g.srWatcher.RemoveListener(srReconnectedChan)
g.log.Infof("start listen for reconnect events...")
for {
select {
case <-g.relayedConnDisconnected:
g.log.Debugf("Relay connection changed, triggering reconnect")
g.triggerOfferSending()
case <-g.iCEConnDisconnected:
g.log.Debugf("ICE state changed, try to send new offer")
g.triggerOfferSending()
case <-srReconnectedChan:
g.triggerOfferSending()
case <-ctx.Done():
g.log.Debugf("context is done, stop reconnect loop")
return
}
}
}
func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker { func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker {
bo := backoff.WithContext(&backoff.ExponentialBackOff{ bo := backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond, InitialInterval: 800 * time.Millisecond,
RandomizationFactor: 0.1, RandomizationFactor: 0.1,
Multiplier: 2, Multiplier: 2,
MaxInterval: g.timeout, MaxInterval: g.timeout,
MaxElapsedTime: reconnectMaxElapsedTime,
Stop: backoff.Stop, Stop: backoff.Stop,
Clock: backoff.SystemClock, Clock: backoff.SystemClock,
}, ctx) }, ctx)
@ -164,13 +127,6 @@ func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker {
return ticker return ticker
} }
func (g *Guard) triggerOfferSending() {
select {
case g.Reconnect <- struct{}{}:
default:
}
}
// Give chance to the peer to establish the initial connection. // Give chance to the peer to establish the initial connection.
// With it, we can decrease to send necessary offer // With it, we can decrease to send necessary offer
func waitForInitialConnectionTry(ctx context.Context) { func waitForInitialConnectionTry(ctx context.Context) {
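
The refactor above inverts control: instead of Guard branching on isController and pushing into its own Reconnect channel via the removed triggerOfferSending helper, the caller now passes the reconnect action into Start. A minimal sketch of the new wiring, assuming only the signatures shown in this hunk (logEntry, isConnected and srWatcher stand in for values the peer connection already owns):

    guard := NewGuard(logEntry, isConnected, 10*time.Second, srWatcher)
    go guard.Start(ctx, func() {
        // Non-blocking notification, mirroring the removed triggerOfferSending helper.
        select {
        case guard.Reconnect <- struct{}{}:
        default:
        }
    })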

View File

@@ -43,7 +43,6 @@ type OfferAnswer struct {
 type Handshaker struct {
 	mu       sync.Mutex
-	ctx      context.Context
 	log      *log.Entry
 	config   ConnConfig
 	signaler *Signaler
@@ -57,9 +56,8 @@ type Handshaker struct {
 	remoteAnswerCh chan OfferAnswer
 }
 
-func NewHandshaker(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay) *Handshaker {
+func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay) *Handshaker {
 	return &Handshaker{
-		ctx:      ctx,
 		log:      log,
 		config:   config,
 		signaler: signaler,
@@ -74,10 +72,10 @@ func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAnswer)) {
 	h.onNewOfferListeners = append(h.onNewOfferListeners, offer)
 }
 
-func (h *Handshaker) Listen() {
+func (h *Handshaker) Listen(ctx context.Context) {
 	for {
 		h.log.Info("wait for remote offer confirmation")
-		remoteOfferAnswer, err := h.waitForRemoteOfferConfirmation()
+		remoteOfferAnswer, err := h.waitForRemoteOfferConfirmation(ctx)
 		if err != nil {
 			var connectionClosedError *ConnectionClosedError
 			if errors.As(err, &connectionClosedError) {
@@ -127,7 +125,7 @@ func (h *Handshaker) OnRemoteAnswer(answer OfferAnswer) bool {
 	}
 }
 
-func (h *Handshaker) waitForRemoteOfferConfirmation() (*OfferAnswer, error) {
+func (h *Handshaker) waitForRemoteOfferConfirmation(ctx context.Context) (*OfferAnswer, error) {
 	select {
 	case remoteOfferAnswer := <-h.remoteOffersCh:
 		// received confirmation from the remote peer -> ready to proceed
@@ -137,7 +135,7 @@ func (h *Handshaker) waitForRemoteOfferConfirmation() (*OfferAnswer, error) {
 		return &remoteOfferAnswer, nil
 	case remoteOfferAnswer := <-h.remoteAnswerCh:
 		return &remoteOfferAnswer, nil
-	case <-h.ctx.Done():
+	case <-ctx.Done():
 		// closed externally
 		return nil, NewConnectionClosedError(h.config.Key)
 	}
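
With the stored ctx field gone, the listener's lifetime is bound to whatever context the caller hands to Listen. A sketch under the new signatures (constructor arguments are placeholders):

    h := NewHandshaker(logEntry, config, signaler, iceWorker, relayWorker)
    ctx, cancel := context.WithCancel(context.Background())
    go h.Listen(ctx)
    // later:
    cancel() // waitForRemoteOfferConfirmation returns a ConnectionClosedError and the loop exits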

View File

@@ -0,0 +1,5 @@
+package id
+
+import "unsafe"
+
+type ConnID unsafe.Pointer
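
Since ConnID aliases unsafe.Pointer, a connection's identity is simply its address, with no counter or registry required. A hedged sketch (the peer.Conn type and the id package's import path are assumed from the surrounding packages):

    conn := &peer.Conn{}
    connID := id.ConnID(unsafe.Pointer(conn)) // unique for as long as conn is alive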

View File

@@ -15,7 +15,7 @@ import (
 type WGIface interface {
 	UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
 	RemovePeer(peerKey string) error
-	GetStats(peerKey string) (configurer.WGStats, error)
+	GetStats() (map[string]configurer.WGStats, error)
 	GetProxy() wgproxy.Proxy
 	Address() wgaddr.Address
 }
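
The interface change replaces one device query per peer with a single snapshot of all peers that callers index locally. A sketch against the new method (iface is a placeholder WGIface value):

    stats, err := iface.GetStats()
    if err != nil {
        return err
    }
    if s, ok := stats[peerKey]; ok {
        log.Debugf("last handshake for %s: %s", peerKey, s.LastHandshake)
    }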

View File

@@ -143,6 +143,7 @@ type FullStatus struct {
 	Relays                []relay.ProbeResult
 	NSGroupStates         []NSGroupState
 	NumOfForwardingRules  int
+	LazyConnectionEnabled bool
 }
 
 // Status holds a state of peers, signal, management connections and relays
@@ -164,6 +165,7 @@ type Status struct {
 	rosenpassPermissive   bool
 	nsGroupStates         []NSGroupState
 	resolvedDomainsStates map[domain.Domain]ResolvedDomainInfo
+	lazyConnectionEnabled bool
 
 	// To reduce the number of notification invocation this bool will be true when need to call the notification
 	// Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events
@@ -219,7 +221,7 @@ func (d *Status) ReplaceOfflinePeers(replacement []State) {
 }
 
 // AddPeer adds peer to Daemon status map
-func (d *Status) AddPeer(peerPubKey string, fqdn string) error {
+func (d *Status) AddPeer(peerPubKey string, fqdn string, ip string) error {
 	d.mux.Lock()
 	defer d.mux.Unlock()
@@ -229,7 +231,8 @@ func (d *Status) AddPeer(peerPubKey string, fqdn string) error {
 	}
 	d.peers[peerPubKey] = State{
 		PubKey:     peerPubKey,
-		ConnStatus: StatusDisconnected,
+		IP:         ip,
+		ConnStatus: StatusIdle,
 		FQDN:       fqdn,
 		Mux:        new(sync.RWMutex),
 	}
@@ -511,9 +514,9 @@ func shouldSkipNotify(receivedConnStatus ConnStatus, curr State) bool {
 	switch {
 	case receivedConnStatus == StatusConnecting:
 		return true
-	case receivedConnStatus == StatusDisconnected && curr.ConnStatus == StatusConnecting:
+	case receivedConnStatus == StatusIdle && curr.ConnStatus == StatusConnecting:
 		return true
-	case receivedConnStatus == StatusDisconnected && curr.ConnStatus == StatusDisconnected:
+	case receivedConnStatus == StatusIdle && curr.ConnStatus == StatusIdle:
 		return curr.IP != ""
 	default:
 		return false
@@ -689,6 +692,12 @@ func (d *Status) UpdateRosenpass(rosenpassEnabled, rosenpassPermissive bool) {
 	d.rosenpassEnabled = rosenpassEnabled
 }
 
+func (d *Status) UpdateLazyConnection(enabled bool) {
+	d.mux.Lock()
+	defer d.mux.Unlock()
+
+	d.lazyConnectionEnabled = enabled
+}
+
 // MarkSignalDisconnected sets SignalState to disconnected
 func (d *Status) MarkSignalDisconnected(err error) {
 	d.mux.Lock()
@@ -761,6 +770,12 @@ func (d *Status) GetRosenpassState() RosenpassState {
 	}
 }
 
+func (d *Status) GetLazyConnection() bool {
+	d.mux.Lock()
+	defer d.mux.Unlock()
+
+	return d.lazyConnectionEnabled
+}
+
 func (d *Status) GetManagementState() ManagementState {
 	d.mux.Lock()
 	defer d.mux.Unlock()
@@ -878,6 +893,7 @@ func (d *Status) GetFullStatus() FullStatus {
 		RosenpassState:        d.GetRosenpassState(),
 		NSGroupStates:         d.GetDNSStates(),
 		NumOfForwardingRules:  len(d.ForwardingRules()),
+		LazyConnectionEnabled: d.GetLazyConnection(),
 	}
 
 	d.mux.Lock()

View File

@@ -10,22 +10,24 @@ import (
 func TestAddPeer(t *testing.T) {
 	key := "abc"
+	ip := "100.108.254.1"
 	status := NewRecorder("https://mgm")
 
-	err := status.AddPeer(key, "abc.netbird")
+	err := status.AddPeer(key, "abc.netbird", ip)
 	assert.NoError(t, err, "shouldn't return error")
 
 	_, exists := status.peers[key]
 	assert.True(t, exists, "value was found")
 
-	err = status.AddPeer(key, "abc.netbird")
+	err = status.AddPeer(key, "abc.netbird", ip)
 
 	assert.Error(t, err, "should return error on duplicate")
 }
 
 func TestGetPeer(t *testing.T) {
 	key := "abc"
+	ip := "100.108.254.1"
 	status := NewRecorder("https://mgm")
 
-	err := status.AddPeer(key, "abc.netbird")
+	err := status.AddPeer(key, "abc.netbird", ip)
 	assert.NoError(t, err, "shouldn't return error")
 
 	peerStatus, err := status.GetPeer(key)

View File

@@ -2,6 +2,7 @@ package peer
 
 import (
 	"context"
+	"fmt"
 	"sync"
 	"time"
@@ -20,7 +21,7 @@ var (
 )
 
 type WGInterfaceStater interface {
-	GetStats(key string) (configurer.WGStats, error)
+	GetStats() (map[string]configurer.WGStats, error)
 }
 
 type WGWatcher struct {
@@ -146,9 +147,13 @@ func (w *WGWatcher) handshakeCheck(lastHandshake time.Time) (*time.Time, bool) {
 }
 
 func (w *WGWatcher) wgState() (time.Time, error) {
-	wgState, err := w.wgIfaceStater.GetStats(w.peerKey)
+	wgStates, err := w.wgIfaceStater.GetStats()
 	if err != nil {
 		return time.Time{}, err
 	}
+	wgState, ok := wgStates[w.peerKey]
+	if !ok {
+		return time.Time{}, fmt.Errorf("peer %s not found in WireGuard endpoints", w.peerKey)
+	}
 	return wgState.LastHandshake, nil
 }

View File

@@ -11,26 +11,11 @@ import (
 )
 
 type MocWgIface struct {
-	initial       bool
-	lastHandshake time.Time
 	stop          bool
 }
 
-func (m *MocWgIface) GetStats(key string) (configurer.WGStats, error) {
-	if !m.initial {
-		m.initial = true
-		return configurer.WGStats{}, nil
-	}
-
-	if !m.stop {
-		m.lastHandshake = time.Now()
-	}
-
-	stats := configurer.WGStats{
-		LastHandshake: m.lastHandshake,
-	}
-	return stats, nil
+func (m *MocWgIface) GetStats() (map[string]configurer.WGStats, error) {
+	return map[string]configurer.WGStats{}, nil
 }
 
 func (m *MocWgIface) disconnect() {

View File

@@ -0,0 +1,55 @@
+package worker
+
+import (
+	"sync/atomic"
+
+	log "github.com/sirupsen/logrus"
+)
+
+const (
+	StatusDisconnected Status = iota
+	StatusConnected
+)
+
+type Status int32
+
+func (s Status) String() string {
+	switch s {
+	case StatusDisconnected:
+		return "Disconnected"
+	case StatusConnected:
+		return "Connected"
+	default:
+		log.Errorf("unknown status: %d", s)
+		return "unknown"
+	}
+}
+
+// AtomicWorkerStatus is a thread-safe wrapper for worker status
+type AtomicWorkerStatus struct {
+	status atomic.Int32
+}
+
+func NewAtomicStatus() *AtomicWorkerStatus {
+	acs := &AtomicWorkerStatus{}
+	acs.SetDisconnected()
+	return acs
+}
+
+// Get returns the current connection status
+func (acs *AtomicWorkerStatus) Get() Status {
+	return Status(acs.status.Load())
+}
+
+func (acs *AtomicWorkerStatus) SetConnected() {
+	acs.status.Store(int32(StatusConnected))
+}
+
+func (acs *AtomicWorkerStatus) SetDisconnected() {
+	acs.status.Store(int32(StatusDisconnected))
+}
+
+// String returns the string representation of the current status
+func (acs *AtomicWorkerStatus) String() string {
+	return acs.Get().String()
+}
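
Both reads and writes go through atomic.Int32, so the status can be flipped from any goroutine without a mutex. Usage, based on the file added above (the worker import path is assumed):

    st := worker.NewAtomicStatus() // starts as StatusDisconnected
    st.SetConnected()
    if st.Get() == worker.StatusConnected {
        log.Infof("worker is %s", st) // "worker is Connected"
    }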

View File

@@ -14,6 +14,7 @@ import (
 	"github.com/netbirdio/netbird/client/iface"
 	"github.com/netbirdio/netbird/client/iface/bind"
+	"github.com/netbirdio/netbird/client/internal/peer/conntype"
 	icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
 	"github.com/netbirdio/netbird/client/internal/stdnet"
 	"github.com/netbirdio/netbird/route"
@@ -397,10 +398,10 @@ func isRelayed(pair *ice.CandidatePair) bool {
 	return false
 }
 
-func selectedPriority(pair *ice.CandidatePair) ConnPriority {
+func selectedPriority(pair *ice.CandidatePair) conntype.ConnPriority {
 	if isRelayed(pair) {
-		return connPriorityICETurn
+		return conntype.ICETurn
 	} else {
-		return connPriorityICEP2P
+		return conntype.ICEP2P
 	}
 }

View File

@@ -1,6 +1,7 @@
 package peerstore
 
 import (
+	"context"
 	"net/netip"
 	"sync"
@@ -79,6 +80,32 @@ func (s *Store) PeerConn(pubKey string) (*peer.Conn, bool) {
 	return p, true
 }
 
+func (s *Store) PeerConnOpen(ctx context.Context, pubKey string) {
+	s.peerConnsMu.RLock()
+	defer s.peerConnsMu.RUnlock()
+
+	p, ok := s.peerConns[pubKey]
+	if !ok {
+		return
+	}
+
+	// this can be blocked because of the connect open limiter semaphore
+	if err := p.Open(ctx); err != nil {
+		p.Log.Errorf("failed to open peer connection: %v", err)
+	}
+}
+
+func (s *Store) PeerConnClose(pubKey string) {
+	s.peerConnsMu.RLock()
+	defer s.peerConnsMu.RUnlock()
+
+	p, ok := s.peerConns[pubKey]
+	if !ok {
+		return
+	}
+
+	p.Close()
+}
+
 func (s *Store) PeersPubKey() []string {
 	s.peerConnsMu.RLock()
 	defer s.peerConnsMu.RUnlock()
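
Both helpers take only the read lock because they mutate the connection object, not the map itself. A hedged usage sketch (store and pubKey are placeholders):

    // Open may block on the connect open limiter semaphore, so run it off the hot path.
    go store.PeerConnOpen(ctx, pubKey)

    // Later, tear the connection down without removing it from the store.
    store.PeerConnClose(pubKey)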

View File

@@ -12,6 +12,7 @@ import (
 	"google.golang.org/grpc/status"
 
 	mgm "github.com/netbirdio/netbird/management/client"
+	"github.com/netbirdio/netbird/management/client/common"
 )
 
 // PKCEAuthorizationFlow represents PKCE Authorization Flow information
@@ -41,6 +42,8 @@ type PKCEAuthProviderConfig struct {
 	ClientCertPair *tls.Certificate
 	// DisablePromptLogin makes the PKCE flow to not prompt the user for login
 	DisablePromptLogin bool
+	// LoginFlag is used to configure the PKCE flow login behavior
+	LoginFlag common.LoginFlag
 }
 
 // GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it
@@ -100,6 +103,7 @@ func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL
 		UseIDToken:         protoPKCEAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
 		ClientCertPair:     clientCert,
 		DisablePromptLogin: protoPKCEAuthorizationFlow.GetProviderConfig().GetDisablePromptLogin(),
+		LoginFlag:          common.LoginFlag(protoPKCEAuthorizationFlow.GetProviderConfig().GetLoginFlag()),
 	},
 }

View File

@@ -3,7 +3,6 @@ package iface
 import (
 	"net"
 
-	"github.com/netbirdio/netbird/client/iface/configurer"
 	"github.com/netbirdio/netbird/client/iface/device"
 	"github.com/netbirdio/netbird/client/iface/wgaddr"
 )
@@ -18,5 +17,4 @@ type wgIfaceBase interface {
 	IsUserspaceBind() bool
 	GetFilter() device.PacketFilter
 	GetDevice() *device.FilteredDevice
-	GetStats(peerKey string) (configurer.WGStats, error)
 }

View File

@@ -32,6 +32,10 @@ func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
 func (n *Notifier) SetInitialClientRoutes(clientRoutes []*route.Route) {
 	nets := make([]string, 0)
 	for _, r := range clientRoutes {
+		// filter out domain routes
+		if r.IsDynamic() {
+			continue
+		}
 		nets = append(nets, r.Network.String())
 	}
 	sort.Strings(nets)

View File

@@ -1,6 +1,7 @@
 package systemops
 
 import (
+	"fmt"
 	"net"
 	"net/netip"
 	"sync"
@@ -15,6 +16,20 @@ type Nexthop struct {
 	Intf *net.Interface
 }
 
+// Equal checks if two nexthops are equal.
+func (n Nexthop) Equal(other Nexthop) bool {
+	return n.IP == other.IP && (n.Intf == nil && other.Intf == nil ||
+		n.Intf != nil && other.Intf != nil && n.Intf.Index == other.Intf.Index)
+}
+
+// String returns a string representation of the nexthop.
+func (n Nexthop) String() string {
+	if n.Intf == nil {
+		return n.IP.String()
+	}
+	return fmt.Sprintf("%s @ %d (%s)", n.IP.String(), n.Intf.Index, n.Intf.Name)
+}
+
 type ExclusionCounter = refcounter.Counter[netip.Prefix, struct{}, Nexthop]
 
 type SysOps struct {
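
Equal deliberately compares interfaces by index only, and String folds IP and interface into one log-friendly token. A small illustration (imports elided):

    a := Nexthop{IP: netip.MustParseAddr("10.0.0.1"), Intf: &net.Interface{Index: 4, Name: "eth0"}}
    b := Nexthop{IP: netip.MustParseAddr("10.0.0.1"), Intf: &net.Interface{Index: 4}}
    fmt.Println(a.Equal(b)) // true: same IP and interface index, the name is ignored
    fmt.Println(a)          // "10.0.0.1 @ 4 (eth0)"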

View File

@@ -33,8 +33,7 @@ type RouteUpdateType int
 type RouteUpdate struct {
 	Type        RouteUpdateType
 	Destination netip.Prefix
-	NextHop     netip.Addr
-	Interface   *net.Interface
+	NextHop     Nexthop
 }
 
 // RouteMonitor provides a way to monitor changes in the routing table.
@@ -231,15 +230,15 @@ func (rm *RouteMonitor) parseUpdate(row *MIB_IPFORWARD_ROW2, notificationType MI
 		intf, err := net.InterfaceByIndex(idx)
 		if err != nil {
 			log.Warnf("failed to get interface name for index %d: %v", idx, err)
-			update.Interface = &net.Interface{
+			update.NextHop.Intf = &net.Interface{
 				Index: idx,
 			}
 		} else {
-			update.Interface = intf
+			update.NextHop.Intf = intf
 		}
 	}
 
-	log.Tracef("Received route update with destination %v, next hop %v, interface %v", row.DestinationPrefix, row.NextHop, update.Interface)
+	log.Tracef("Received route update with destination %v, next hop %v, interface %v", row.DestinationPrefix, row.NextHop, update.NextHop.Intf)
 
 	dest := parseIPPrefix(row.DestinationPrefix, idx)
 	if !dest.Addr().IsValid() {
 		return RouteUpdate{}, fmt.Errorf("invalid destination: %v", row)
@@ -262,7 +261,7 @@ func (rm *RouteMonitor) parseUpdate(row *MIB_IPFORWARD_ROW2, notificationType MI
 	update.Type = updateType
 	update.Destination = dest
-	update.NextHop = nexthop
+	update.NextHop.IP = nexthop
 
 	return update, nil
 }

File diff suppressed because it is too large

View File

@@ -94,7 +94,7 @@ message LoginRequest {
 	bytes customDNSAddress = 7;
 
-	bool isLinuxDesktopClient = 8;
+	bool isUnixDesktopClient = 8;
 
 	string hostname = 9;
@@ -134,6 +134,7 @@ message LoginRequest {
 	// omits initialized empty slices due to omitempty tags
 	bool cleanDNSLabels = 27;
+	optional bool lazyConnectionEnabled = 28;
 }
 
 message LoginResponse {
@@ -274,6 +275,8 @@ message FullStatus {
 	int32 NumberOfForwardingRules = 8;
 
 	repeated SystemEvent events = 7;
+
+	bool lazyConnectionEnabled = 9;
 }
 
 // Networks
// Networks // Networks

View File

@@ -139,6 +139,7 @@ func (s *Server) Start() error {
 	s.statusRecorder.UpdateManagementAddress(config.ManagementURL.String())
 	s.statusRecorder.UpdateRosenpass(config.RosenpassEnabled, config.RosenpassPermissive)
+	s.statusRecorder.UpdateLazyConnection(config.LazyConnectionEnabled)
 
 	if s.sessionWatcher == nil {
 		s.sessionWatcher = internal.NewSessionWatcher(s.rootCtx, s.statusRecorder)
@@ -417,6 +418,11 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
 		s.latestConfigInput.DisableNotifications = msg.DisableNotifications
 	}
 
+	if msg.LazyConnectionEnabled != nil {
+		inputConfig.LazyConnectionEnabled = msg.LazyConnectionEnabled
+		s.latestConfigInput.LazyConnectionEnabled = msg.LazyConnectionEnabled
+	}
+
 	s.mutex.Unlock()
 
 	if msg.OptionalPreSharedKey != nil {
@@ -446,7 +452,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
 	state.Set(internal.StatusConnecting)
 
 	if msg.SetupKey == "" {
-		oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsLinuxDesktopClient)
+		oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient)
 		if err != nil {
 			state.Set(internal.StatusLoginFailed)
 			return nil, err
@@ -804,6 +810,7 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
 	pbFullStatus.LocalPeerState.RosenpassEnabled = fullStatus.RosenpassState.Enabled
 	pbFullStatus.LocalPeerState.Networks = maps.Keys(fullStatus.LocalPeerState.Routes)
 	pbFullStatus.NumberOfForwardingRules = int32(fullStatus.NumOfForwardingRules)
+	pbFullStatus.LazyConnectionEnabled = fullStatus.LazyConnectionEnabled
 
 	for _, peerState := range fullStatus.Peers {
 		pbPeerState := &proto.PeerState{

View File

@@ -97,6 +97,7 @@ type OutputOverview struct {
 	NumberOfForwardingRules int                        `json:"forwardingRules" yaml:"forwardingRules"`
 	NSServerGroups          []NsServerGroupStateOutput `json:"dnsServers" yaml:"dnsServers"`
 	Events                  []SystemEventOutput        `json:"events" yaml:"events"`
+	LazyConnectionEnabled   bool                       `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"`
 }
 
 func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}) OutputOverview {
@@ -136,6 +137,7 @@ func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, status
 		NumberOfForwardingRules: int(pbFullStatus.GetNumberOfForwardingRules()),
 		NSServerGroups:          mapNSGroups(pbFullStatus.GetDnsServers()),
 		Events:                  mapEvents(pbFullStatus.GetEvents()),
+		LazyConnectionEnabled:   pbFullStatus.GetLazyConnectionEnabled(),
 	}
 
 	if anon {
@@ -206,7 +208,7 @@ func mapPeers(
 		transferSent := int64(0)
 
 		isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
-		if skipDetailByFilters(pbPeerState, isPeerConnected, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter) {
+		if skipDetailByFilters(pbPeerState, pbPeerState.ConnStatus, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter) {
 			continue
 		}
 		if isPeerConnected {
@@ -384,6 +386,11 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
 		}
 	}
 
+	lazyConnectionEnabledStatus := "false"
+	if overview.LazyConnectionEnabled {
+		lazyConnectionEnabledStatus = "true"
+	}
+
 	peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
 
 	goos := runtime.GOOS
@@ -405,6 +412,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
 		"NetBird IP: %s\n"+
 		"Interface type: %s\n"+
 		"Quantum resistance: %s\n"+
+		"Lazy connection: %s\n"+
 		"Networks: %s\n"+
 		"Forwarding rules: %d\n"+
 		"Peers count: %s\n",
@@ -419,6 +427,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
 		interfaceIP,
 		interfaceTypeString,
 		rosenpassEnabledStatus,
+		lazyConnectionEnabledStatus,
 		networks,
 		overview.NumberOfForwardingRules,
 		peersCountString,
@@ -533,23 +542,13 @@ func parsePeers(peers PeersStateOutput, rosenpassEnabled, rosenpassPermissive bo
 	return peersString
 }
 
-func skipDetailByFilters(
-	peerState *proto.PeerState,
-	isConnected bool,
-	statusFilter string,
-	prefixNamesFilter []string,
-	prefixNamesFilterMap map[string]struct{},
-	ipsFilter map[string]struct{},
-) bool {
+func skipDetailByFilters(peerState *proto.PeerState, peerStatus string, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}) bool {
 	statusEval := false
 	ipEval := false
 	nameEval := true
 
 	if statusFilter != "" {
-		lowerStatusFilter := strings.ToLower(statusFilter)
-		if lowerStatusFilter == "disconnected" && isConnected {
-			statusEval = true
-		} else if lowerStatusFilter == "connected" && !isConnected {
+		if !strings.EqualFold(peerStatus, statusFilter) {
 			statusEval = true
 		}
 	}
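
The rewritten filter compares the peer's status string directly and case-insensitively against the filter, so new states such as Idle work without dedicated branches. For instance:

    strings.EqualFold("Idle", "idle")      // true  -> detail is kept
    strings.EqualFold("Connected", "idle") // false -> statusEval = true, detail is skipped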

View File

@@ -383,7 +383,8 @@ func TestParsingToJSON(t *testing.T) {
       "error": "timeout"
     }
   ],
-  "events": []
+  "events": [],
+  "lazyConnectionEnabled": false
 }`
 	// @formatter:on
@@ -484,6 +485,7 @@ dnsServers:
     enabled: false
     error: timeout
 events: []
+lazyConnectionEnabled: false
 `
 	assert.Equal(t, expectedYAML, yaml)
@@ -548,6 +550,7 @@ FQDN: some-localhost.awesome-domain.com
 NetBird IP: 192.168.178.100/16
 Interface type: Kernel
 Quantum resistance: false
+Lazy connection: false
 Networks: 10.10.0.0/24
 Forwarding rules: 0
 Peers count: 2/2 Connected
@@ -570,6 +573,7 @@ FQDN: some-localhost.awesome-domain.com
 NetBird IP: 192.168.178.100/16
 Interface type: Kernel
 Quantum resistance: false
+Lazy connection: false
 Networks: 10.10.0.0/24
 Forwarding rules: 0
 Peers count: 2/2 Connected

View File

@@ -62,6 +62,8 @@ func main() {
 			return
 		}
 		logFile = file
+	} else {
+		_ = util.InitLog("trace", "console")
 	}
 
 	// Create the Fyne application.
@@ -192,6 +194,7 @@ type serviceClient struct {
 	mAllowSSH          *systray.MenuItem
 	mAutoConnect       *systray.MenuItem
 	mEnableRosenpass   *systray.MenuItem
+	mLazyConnEnabled   *systray.MenuItem
 	mNotifications     *systray.MenuItem
 	mAdvancedSettings  *systray.MenuItem
 	mCreateDebugBundle *systray.MenuItem
@@ -385,7 +388,7 @@ func (s *serviceClient) getSettingsForm() *widget.Form {
 			loginRequest := proto.LoginRequest{
 				ManagementUrl:        iMngURL,
 				AdminURL:             iAdminURL,
-				IsLinuxDesktopClient: runtime.GOOS == "linux",
+				IsUnixDesktopClient:  runtime.GOOS == "linux" || runtime.GOOS == "freebsd",
 				RosenpassPermissive:  &s.sRosenpassPermissive.Checked,
 				InterfaceName:        &s.iInterfaceName.Text,
 				WireguardPort:        &port,
@@ -415,7 +418,7 @@ func (s *serviceClient) login() error {
 	}
 
 	loginResp, err := conn.Login(s.ctx, &proto.LoginRequest{
-		IsLinuxDesktopClient: runtime.GOOS == "linux",
+		IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd",
 	})
 	if err != nil {
 		log.Errorf("login to management URL with: %v", err)
@@ -631,6 +634,7 @@ func (s *serviceClient) onTrayReady() {
 	s.mAllowSSH = s.mSettings.AddSubMenuItemCheckbox("Allow SSH", allowSSHMenuDescr, false)
 	s.mAutoConnect = s.mSettings.AddSubMenuItemCheckbox("Connect on Startup", autoConnectMenuDescr, false)
 	s.mEnableRosenpass = s.mSettings.AddSubMenuItemCheckbox("Enable Quantum-Resistance", quantumResistanceMenuDescr, false)
+	s.mLazyConnEnabled = s.mSettings.AddSubMenuItemCheckbox("Enable lazy connection", lazyConnMenuDescr, false)
 	s.mNotifications = s.mSettings.AddSubMenuItemCheckbox("Notifications", notificationsMenuDescr, false)
 	s.mAdvancedSettings = s.mSettings.AddSubMenuItem("Advanced Settings", advancedSettingsMenuDescr)
 	s.mCreateDebugBundle = s.mSettings.AddSubMenuItem("Create Debug Bundle", debugBundleMenuDescr)
@@ -693,7 +697,10 @@ func (s *serviceClient) onTrayReady() {
 	go s.eventManager.Start(s.ctx)
 
-	go func() {
+	go s.listenEvents()
+}
+
+func (s *serviceClient) listenEvents() {
 	for {
 		select {
 		case <-s.mUp.ClickedCh:
@@ -743,6 +750,15 @@ func (s *serviceClient) onTrayReady() {
 			if err := s.updateConfig(); err != nil {
 				log.Errorf("failed to update config: %v", err)
 			}
+		case <-s.mLazyConnEnabled.ClickedCh:
+			if s.mLazyConnEnabled.Checked() {
+				s.mLazyConnEnabled.Uncheck()
+			} else {
+				s.mLazyConnEnabled.Check()
+			}
+			if err := s.updateConfig(); err != nil {
+				log.Errorf("failed to update config: %v", err)
+			}
 		case <-s.mAdvancedSettings.ClickedCh:
 			s.mAdvancedSettings.Disable()
 			go func() {
@@ -788,9 +804,7 @@ func (s *serviceClient) onTrayReady() {
 				log.Errorf("failed to update config: %v", err)
 			}
 		}
 	}
-	}()
 }
 
 func (s *serviceClient) runSelfCommand(command, arg string) {
@@ -1022,13 +1036,15 @@ func (s *serviceClient) updateConfig() error {
 	sshAllowed := s.mAllowSSH.Checked()
 	rosenpassEnabled := s.mEnableRosenpass.Checked()
 	notificationsDisabled := !s.mNotifications.Checked()
+	lazyConnectionEnabled := s.mLazyConnEnabled.Checked()
 
 	loginRequest := proto.LoginRequest{
-		IsLinuxDesktopClient: runtime.GOOS == "linux",
+		IsUnixDesktopClient:   runtime.GOOS == "linux" || runtime.GOOS == "freebsd",
 		ServerSSHAllowed:      &sshAllowed,
 		RosenpassEnabled:      &rosenpassEnabled,
 		DisableAutoConnect:    &disableAutoStart,
 		DisableNotifications:  &notificationsDisabled,
+		LazyConnectionEnabled: &lazyConnectionEnabled,
 	}
 
 	if err := s.restartClient(&loginRequest); err != nil {

View File

@@ -6,6 +6,7 @@ const (
 	allowSSHMenuDescr          = "Allow SSH connections"
 	autoConnectMenuDescr       = "Connect automatically when the service starts"
 	quantumResistanceMenuDescr = "Enable post-quantum security via Rosenpass"
+	lazyConnMenuDescr          = "[Experimental] Enable lazy connect"
 	notificationsMenuDescr     = "Enable notifications"
 	advancedSettingsMenuDescr  = "Advanced settings of the application"
 	debugBundleMenuDescr       = "Create and open debug information bundle"

View File

@@ -66,17 +66,17 @@ func (s SimpleRecord) String() string {
 func (s SimpleRecord) Len() uint16 {
 	emptyString := s.RData == ""
 	switch s.Type {
-	case 1:
+	case int(dns.TypeA):
 		if emptyString {
 			return 0
 		}
 		return net.IPv4len
-	case 5:
+	case int(dns.TypeCNAME):
 		if emptyString || s.RData == "." {
 			return 1
 		}
 		return uint16(len(s.RData) + 1)
-	case 28:
+	case int(dns.TypeAAAA):
 		if emptyString {
 			return 0
 		}
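
The named constants from miekg/dns carry the same numeric values as the magic numbers they replace, which can be checked directly:

    fmt.Println(int(dns.TypeA), int(dns.TypeCNAME), int(dns.TypeAAAA)) // 1 5 28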

go.mod
View File

@@ -59,13 +59,12 @@ require (
 	github.com/hashicorp/go-version v1.6.0
 	github.com/libdns/route53 v1.5.0
 	github.com/libp2p/go-netroute v0.2.1
-	github.com/mattn/go-sqlite3 v1.14.22
 	github.com/mdlayher/socket v0.5.1
 	github.com/miekg/dns v1.1.59
 	github.com/mitchellh/hashstructure/v2 v2.0.2
 	github.com/nadoo/ipset v0.5.0
 	github.com/netbirdio/management-integrations/integrations v0.0.0-20250330143713-7901e0a82203
-	github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d
+	github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb
 	github.com/okta/okta-sdk-golang/v2 v2.18.0
 	github.com/oschwald/maxminddb-golang v1.12.0
 	github.com/patrickmn/go-cache v2.1.0+incompatible
@@ -195,6 +194,7 @@ require (
 	github.com/libdns/libdns v0.2.2 // indirect
 	github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae // indirect
 	github.com/magiconair/properties v1.8.7 // indirect
+	github.com/mattn/go-sqlite3 v1.14.22 // indirect
 	github.com/mdlayher/genetlink v1.3.2 // indirect
 	github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 // indirect
 	github.com/mholt/acmez/v2 v2.0.1 // indirect

go.sum
View File

@@ -507,8 +507,8 @@ github.com/netbirdio/management-integrations/integrations v0.0.0-20250330143713-
 github.com/netbirdio/management-integrations/integrations v0.0.0-20250330143713-7901e0a82203/go.mod h1:2ZE6/tBBCKHQggPfO2UOQjyjXI7k+JDVl2ymorTOVQs=
 github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
 github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
-github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=
-github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
+github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb h1:Cr6age+ePALqlSvtp7wc6lYY97XN7rkD1K4XEDmY+TU=
+github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
 github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6 h1:X5h5QgP7uHAv78FWgHV8+WYLjHxK9v3ilkVXT1cpCrQ=
 github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
 github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
View File

@@ -59,6 +59,7 @@ NETBIRD_TOKEN_SOURCE=${NETBIRD_TOKEN_SOURCE:-accessToken}
 NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS=${NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS:-"53000"}
 NETBIRD_AUTH_PKCE_USE_ID_TOKEN=${NETBIRD_AUTH_PKCE_USE_ID_TOKEN:-false}
 NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN=${NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN:-false}
+NETBIRD_AUTH_PKCE_LOGIN_FLAG=${NETBIRD_AUTH_PKCE_LOGIN_FLAG:-1}
NETBIRD_AUTH_PKCE_AUDIENCE=$NETBIRD_AUTH_AUDIENCE
 
 # Dashboard
@@ -122,6 +123,7 @@ export NETBIRD_AUTH_DEVICE_AUTH_USE_ID_TOKEN
 export NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT
 export NETBIRD_AUTH_PKCE_USE_ID_TOKEN
 export NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN
+export NETBIRD_AUTH_PKCE_LOGIN_FLAG
 export NETBIRD_AUTH_PKCE_AUDIENCE
 export NETBIRD_DASH_AUTH_USE_AUDIENCE
 export NETBIRD_DASH_AUTH_AUDIENCE

View File

@@ -95,7 +95,8 @@
         "Scope": "$NETBIRD_AUTH_SUPPORTED_SCOPES",
         "RedirectURLs": [$NETBIRD_AUTH_PKCE_REDIRECT_URLS],
         "UseIDToken": $NETBIRD_AUTH_PKCE_USE_ID_TOKEN,
-        "DisablePromptLogin": $NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN
+        "DisablePromptLogin": $NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN,
+        "LoginFlag": $NETBIRD_AUTH_PKCE_LOGIN_FLAG
       }
     }
   }

View File

@@ -28,3 +28,4 @@ NETBIRD_MGMT_IDP_SIGNKEY_REFRESH=$CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH
 NETBIRD_TURN_EXTERNAL_IP=1.2.3.4
 NETBIRD_RELAY_PORT=33445
 NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN=true
+NETBIRD_AUTH_PKCE_LOGIN_FLAG=0

View File

@@ -0,0 +1,19 @@
+package common
+
+// LoginFlag introduces additional login flags to the PKCE authorization request
+type LoginFlag uint8
+
+const (
+	// LoginFlagPrompt adds prompt=login to the authorization request
+	LoginFlagPrompt LoginFlag = iota
+	// LoginFlagMaxAge0 adds max_age=0 to the authorization request
+	LoginFlagMaxAge0
+)
+
+func (l LoginFlag) IsPromptLogin() bool {
+	return l == LoginFlagPrompt
+}
+
+func (l LoginFlag) IsMaxAge0Login() bool {
+	return l == LoginFlagMaxAge0
+}
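
An authorization flow can branch on the flag to pick the re-authentication parameter it appends to the request URL; a hedged sketch, not the client's actual wiring:

    params := url.Values{}
    switch {
    case loginFlag.IsPromptLogin():
        params.Set("prompt", "login") // always show the login prompt
    case loginFlag.IsMaxAge0Login():
        params.Set("max_age", "0") // force re-authentication via max_age
    }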

View File

@@ -16,11 +16,13 @@ type AccountsAPI struct {
 // List list all accounts, only returns one account always
 // See more: https://docs.netbird.io/api/resources/accounts#list-all-accounts
 func (a *AccountsAPI) List(ctx context.Context) ([]api.Account, error) {
-	resp, err := a.c.newRequest(ctx, "GET", "/api/accounts", nil)
+	resp, err := a.c.NewRequest(ctx, "GET", "/api/accounts", nil)
 	if err != nil {
 		return nil, err
 	}
-	defer resp.Body.Close()
+	if resp.Body != nil {
+		defer resp.Body.Close()
+	}
 	ret, err := parseResponse[[]api.Account](resp)
 	return ret, err
 }
@@ -32,11 +34,13 @@ func (a *AccountsAPI) Update(ctx context.Context, accountID string, request api.
 	if err != nil {
 		return nil, err
 	}
-	resp, err := a.c.newRequest(ctx, "PUT", "/api/accounts/"+accountID, bytes.NewReader(requestBytes))
+	resp, err := a.c.NewRequest(ctx, "PUT", "/api/accounts/"+accountID, bytes.NewReader(requestBytes))
 	if err != nil {
 		return nil, err
 	}
-	defer resp.Body.Close()
+	if resp.Body != nil {
+		defer resp.Body.Close()
+	}
 	ret, err := parseResponse[api.Account](resp)
 	return &ret, err
 }
@@ -44,11 +48,13 @@ func (a *AccountsAPI) Update(ctx context.Context, accountID string, request api.
 // Delete delete account
 // See more: https://docs.netbird.io/api/resources/accounts#delete-an-account
 func (a *AccountsAPI) Delete(ctx context.Context, accountID string) error {
-	resp, err := a.c.newRequest(ctx, "DELETE", "/api/accounts/"+accountID, nil)
+	resp, err := a.c.NewRequest(ctx, "DELETE", "/api/accounts/"+accountID, nil)
 	if err != nil {
 		return err
 	}
-	defer resp.Body.Close()
+	if resp.Body != nil {
+		defer resp.Body.Close()
+	}
 	return nil
 }

View File

@@ -14,6 +14,7 @@ import (
 type Client struct {
 	managementURL string
 	authHeader    string
+	httpClient    HttpClient
 
 	// Accounts NetBird account APIs
 	// see more: https://docs.netbird.io/api/resources/accounts
@@ -70,20 +71,29 @@ type Client struct {
 
 // New initialize new Client instance using PAT token
 func New(managementURL, token string) *Client {
-	client := &Client{
-		managementURL: managementURL,
-		authHeader:    "Token " + token,
-	}
-	client.initialize()
-	return client
+	return NewWithOptions(
+		WithManagementURL(managementURL),
+		WithPAT(token),
+	)
 }
 
 // NewWithBearerToken initialize new Client instance using Bearer token type
 func NewWithBearerToken(managementURL, token string) *Client {
+	return NewWithOptions(
+		WithManagementURL(managementURL),
+		WithBearerToken(token),
+	)
+}
+
+func NewWithOptions(opts ...option) *Client {
 	client := &Client{
-		managementURL: managementURL,
-		authHeader:    "Bearer " + token,
+		httpClient: http.DefaultClient,
 	}
+
+	for _, option := range opts {
+		option(client)
+	}
+
 	client.initialize()
 	return client
 }
@@ -104,7 +114,7 @@ func (c *Client) initialize() {
 	c.Events = &EventsAPI{c}
 }
 
-func (c *Client) newRequest(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) {
+func (c *Client) NewRequest(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) {
 	req, err := http.NewRequestWithContext(ctx, method, c.managementURL+path, body)
 	if err != nil {
 		return nil, err
@@ -116,7 +126,7 @@ func (c *Client) NewRequest(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) {
 		req.Header.Add("Content-Type", "application/json")
 	}
 
-	resp, err := http.DefaultClient.Do(req)
+	resp, err := c.httpClient.Do(req)
 	if err != nil {
 		return nil, err
 	}
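
New and NewWithBearerToken are now thin wrappers over a functional-options constructor, which keeps both signatures stable while letting callers swap in their own HTTP client via the httpClient field. A sketch using only the options referenced above (the rest package alias is assumed):

    client := rest.NewWithOptions(
        rest.WithManagementURL("https://api.netbird.io"),
        rest.WithBearerToken(token),
    )
    accounts, err := client.Accounts.List(ctx)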

View File

@@ -16,11 +16,13 @@ type DNSAPI struct {
 // ListNameserverGroups list all nameserver groups
 // See more: https://docs.netbird.io/api/resources/dns#list-all-nameserver-groups
 func (a *DNSAPI) ListNameserverGroups(ctx context.Context) ([]api.NameserverGroup, error) {
-	resp, err := a.c.newRequest(ctx, "GET", "/api/dns/nameservers", nil)
+	resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/nameservers", nil)
 	if err != nil {
 		return nil, err
 	}
-	defer resp.Body.Close()
+	if resp.Body != nil {
+		defer resp.Body.Close()
+	}
 	ret, err := parseResponse[[]api.NameserverGroup](resp)
 	return ret, err
 }
@@ -28,11 +30,13 @@ func (a *DNSAPI) ListNameserverGroups(ctx context.Context) ([]api.NameserverGroup, error) {
 // GetNameserverGroup get nameserver group info
 // See more: https://docs.netbird.io/api/resources/dns#retrieve-a-nameserver-group
 func (a *DNSAPI) GetNameserverGroup(ctx context.Context, nameserverGroupID string) (*api.NameserverGroup, error) {
-	resp, err := a.c.newRequest(ctx, "GET", "/api/dns/nameservers/"+nameserverGroupID, nil)
+	resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/nameservers/"+nameserverGroupID, nil)
 	if err != nil {
 		return nil, err
 	}
-	defer resp.Body.Close()
+	if resp.Body != nil {
+		defer resp.Body.Close()
+	}
 	ret, err := parseResponse[api.NameserverGroup](resp)
 	return &ret, err
 }
@@ -44,11 +48,13 @@ func (a *DNSAPI) CreateNameserverGroup(ctx context.Context, request api.PostApiD
 	if err != nil {
 		return nil, err
 	}
-	resp, err := a.c.newRequest(ctx, "POST", "/api/dns/nameservers", bytes.NewReader(requestBytes))
+	resp, err := a.c.NewRequest(ctx, "POST", "/api/dns/nameservers", bytes.NewReader(requestBytes))
 	if err != nil {
 		return nil, err
 	}
-	defer resp.Body.Close()
+	if resp.Body != nil {
+		defer resp.Body.Close()
+	}
 	ret, err := parseResponse[api.NameserverGroup](resp)
 	return &ret, err
 }
@@ -60,11 +66,13 @@ func (a *DNSAPI) UpdateNameserverGroup(ctx context.Context, nameserverGroupID st
 	if err != nil {
 		return nil, err
 	}
-	resp, err := a.c.newRequest(ctx, "PUT", "/api/dns/nameservers/"+nameserverGroupID, bytes.NewReader(requestBytes))
+	resp, err := a.c.NewRequest(ctx, "PUT", "/api/dns/nameservers/"+nameserverGroupID, bytes.NewReader(requestBytes))
 	if err != nil {
 		return nil, err
 	}
-	defer resp.Body.Close()
+	if resp.Body != nil {
+		defer resp.Body.Close()
+	}
 	ret, err := parseResponse[api.NameserverGroup](resp)
 	return &ret, err
 }
@@ -72,11 +80,13 @@ func (a *DNSAPI) UpdateNameserverGroup(ctx context.Context, nameserverGroupID st
 // DeleteNameserverGroup delete nameserver group
 // See more: https://docs.netbird.io/api/resources/dns#delete-a-nameserver-group
 func (a *DNSAPI) DeleteNameserverGroup(ctx context.Context, nameserverGroupID string) error {
-	resp, err := a.c.newRequest(ctx, "DELETE", "/api/dns/nameservers/"+nameserverGroupID, nil)
+	resp, err := a.c.NewRequest(ctx, "DELETE", "/api/dns/nameservers/"+nameserverGroupID, nil)
 	if err != nil {
 		return err
 	}
-	defer resp.Body.Close()
+	if resp.Body != nil {
+		defer resp.Body.Close()
+	}
 	return nil
 }
@@ -84,11 +94,13 @@ func (a *DNSAPI) DeleteNameserverGroup(ctx context.Context, nameserverGroupID st
 // GetSettings get DNS settings
 // See more: https://docs.netbird.io/api/resources/dns#retrieve-dns-settings
 func (a *DNSAPI) GetSettings(ctx context.Context) (*api.DNSSettings, error) {
-	resp, err := a.c.newRequest(ctx, "GET", "/api/dns/settings", nil)
+	resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/settings", nil)
 	if err != nil {
 		return nil, err
 	}
-	defer resp.Body.Close()
+	if resp.Body != nil {
+		defer resp.Body.Close()
+	}
 	ret, err := parseResponse[api.DNSSettings](resp)
 	return &ret, err
 }
@@ -100,6 +112,13 @@ func (a *DNSAPI) UpdateSettings(ctx context.Context, request api.PutApiDnsSettin
 	if err != nil {
 		return nil, err
 	}
-	resp, err := a.c.newRequest(ctx, "PUT", "/api/dns/settings", bytes.NewReader(requestBytes))
+	resp, err := a.c.NewRequest(ctx, "PUT", "/api/dns/settings", bytes.NewReader(requestBytes))
 	if err != nil {
 		return nil, err
 	}
-	defer resp.Body.Close()
+	if resp.Body != nil {
+		defer resp.Body.Close()
+	}
 	ret, err := parseResponse[api.DNSSettings](resp)
 	return &ret, err
 }

View File

@@ -14,11 +14,13 @@ type EventsAPI struct {
 // List list all events
 // See more: https://docs.netbird.io/api/resources/events#list-all-events
 func (a *EventsAPI) List(ctx context.Context) ([]api.Event, error) {
-	resp, err := a.c.newRequest(ctx, "GET", "/api/events", nil)
+	resp, err := a.c.NewRequest(ctx, "GET", "/api/events", nil)
 	if err != nil {
 		return nil, err
 	}
-	defer resp.Body.Close()
+	if resp.Body != nil {
+		defer resp.Body.Close()
+	}
 	ret, err := parseResponse[[]api.Event](resp)
 	return ret, err
 }

View File

@@ -14,11 +14,13 @@ type GeoLocationAPI struct {
 // ListCountries list all country codes
 // See more: https://docs.netbird.io/api/resources/geo-locations#list-all-country-codes
 func (a *GeoLocationAPI) ListCountries(ctx context.Context) ([]api.Country, error) {
-	resp, err := a.c.newRequest(ctx, "GET", "/api/locations/countries", nil)
+	resp, err := a.c.NewRequest(ctx, "GET", "/api/locations/countries", nil)
 	if err != nil {
 		return nil, err
 	}
-	defer resp.Body.Close()
+	if resp.Body != nil {
+		defer resp.Body.Close()
+	}
 	ret, err := parseResponse[[]api.Country](resp)
 	return ret, err
 }
@@ -26,11 +28,13 @@ func (a *GeoLocationAPI) ListCountries(ctx context.Context) ([]api.Country, error) {
 // ListCountryCities Get a list of all English city names for a given country code
 // See more: https://docs.netbird.io/api/resources/geo-locations#list-all-city-names-by-country
 func (a *GeoLocationAPI) ListCountryCities(ctx context.Context, countryCode string) ([]api.City, error) {
-	resp, err := a.c.newRequest(ctx, "GET", "/api/locations/countries/"+countryCode+"/cities", nil)
+	resp, err := a.c.NewRequest(ctx, "GET", "/api/locations/countries/"+countryCode+"/cities", nil)
 	if err != nil {
 		return nil, err
 	}
-	defer resp.Body.Close()
+	if resp.Body != nil {
+		defer resp.Body.Close()
+	}
 	ret, err := parseResponse[[]api.City](resp)
 	return ret, err
 }

Some files were not shown because too many files have changed in this diff