mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-26 12:42:32 +02:00
Merge remote-tracking branch 'origin/main' into feat/multiple-profile
This commit is contained in:
commit
7fae260faa
15
.github/ISSUE_TEMPLATE/bug-issue-report.md
vendored
15
.github/ISSUE_TEMPLATE/bug-issue-report.md
vendored
@ -37,17 +37,22 @@ If yes, which one?
|
|||||||
|
|
||||||
**Debug output**
|
**Debug output**
|
||||||
|
|
||||||
To help us resolve the problem, please attach the following debug output
|
To help us resolve the problem, please attach the following anonymized status output
|
||||||
|
|
||||||
netbird status -dA
|
netbird status -dA
|
||||||
|
|
||||||
As well as the file created by
|
Create and upload a debug bundle, and share the returned file key:
|
||||||
|
|
||||||
|
netbird debug for 1m -AS -U
|
||||||
|
|
||||||
|
*Uploaded files are automatically deleted after 30 days.*
|
||||||
|
|
||||||
|
|
||||||
|
Alternatively, create the file only and attach it here manually:
|
||||||
|
|
||||||
netbird debug for 1m -AS
|
netbird debug for 1m -AS
|
||||||
|
|
||||||
|
|
||||||
We advise reviewing the anonymized output for any remaining personal information.
|
|
||||||
|
|
||||||
**Screenshots**
|
**Screenshots**
|
||||||
|
|
||||||
If applicable, add screenshots to help explain your problem.
|
If applicable, add screenshots to help explain your problem.
|
||||||
@ -57,8 +62,10 @@ If applicable, add screenshots to help explain your problem.
|
|||||||
Add any other context about the problem here.
|
Add any other context about the problem here.
|
||||||
|
|
||||||
**Have you tried these troubleshooting steps?**
|
**Have you tried these troubleshooting steps?**
|
||||||
|
- [ ] Reviewed [client troubleshooting](https://docs.netbird.io/how-to/troubleshooting-client) (if applicable)
|
||||||
- [ ] Checked for newer NetBird versions
|
- [ ] Checked for newer NetBird versions
|
||||||
- [ ] Searched for similar issues on GitHub (including closed ones)
|
- [ ] Searched for similar issues on GitHub (including closed ones)
|
||||||
- [ ] Restarted the NetBird client
|
- [ ] Restarted the NetBird client
|
||||||
- [ ] Disabled other VPN software
|
- [ ] Disabled other VPN software
|
||||||
- [ ] Checked firewall settings
|
- [ ] Checked firewall settings
|
||||||
|
|
||||||
|
@ -179,6 +179,7 @@ jobs:
|
|||||||
grep -A 7 Relay management.json | grep "rel://$CI_NETBIRD_DOMAIN:33445"
|
grep -A 7 Relay management.json | grep "rel://$CI_NETBIRD_DOMAIN:33445"
|
||||||
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
|
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
|
||||||
grep DisablePromptLogin management.json | grep 'true'
|
grep DisablePromptLogin management.json | grep 'true'
|
||||||
|
grep LoginFlag management.json | grep 0
|
||||||
|
|
||||||
- name: Install modules
|
- name: Install modules
|
||||||
run: go mod tidy
|
run: go mod tidy
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -98,11 +99,11 @@ var loginCmd = &cobra.Command{
|
|||||||
}
|
}
|
||||||
|
|
||||||
loginRequest := proto.LoginRequest{
|
loginRequest := proto.LoginRequest{
|
||||||
SetupKey: providedSetupKey,
|
SetupKey: providedSetupKey,
|
||||||
ManagementUrl: managementURL,
|
ManagementUrl: managementURL,
|
||||||
IsLinuxDesktopClient: isLinuxRunningDesktop(),
|
IsUnixDesktopClient: isUnixRunningDesktop(),
|
||||||
Hostname: hostName,
|
Hostname: hostName,
|
||||||
DnsLabels: dnsLabelsReq,
|
DnsLabels: dnsLabelsReq,
|
||||||
}
|
}
|
||||||
|
|
||||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||||
@ -195,7 +196,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C
|
|||||||
}
|
}
|
||||||
|
|
||||||
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*auth.TokenInfo, error) {
|
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*auth.TokenInfo, error) {
|
||||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isLinuxRunningDesktop())
|
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -243,7 +244,10 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// isLinuxRunningDesktop checks if a Linux OS is running desktop environment
|
// isUnixRunningDesktop checks if a Linux OS is running desktop environment
|
||||||
func isLinuxRunningDesktop() bool {
|
func isUnixRunningDesktop() bool {
|
||||||
|
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
|
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
|
||||||
}
|
}
|
||||||
|
@ -26,22 +26,23 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
externalIPMapFlag = "external-ip-map"
|
externalIPMapFlag = "external-ip-map"
|
||||||
dnsResolverAddress = "dns-resolver-address"
|
dnsResolverAddress = "dns-resolver-address"
|
||||||
enableRosenpassFlag = "enable-rosenpass"
|
enableRosenpassFlag = "enable-rosenpass"
|
||||||
rosenpassPermissiveFlag = "rosenpass-permissive"
|
rosenpassPermissiveFlag = "rosenpass-permissive"
|
||||||
preSharedKeyFlag = "preshared-key"
|
preSharedKeyFlag = "preshared-key"
|
||||||
interfaceNameFlag = "interface-name"
|
interfaceNameFlag = "interface-name"
|
||||||
wireguardPortFlag = "wireguard-port"
|
wireguardPortFlag = "wireguard-port"
|
||||||
networkMonitorFlag = "network-monitor"
|
networkMonitorFlag = "network-monitor"
|
||||||
disableAutoConnectFlag = "disable-auto-connect"
|
disableAutoConnectFlag = "disable-auto-connect"
|
||||||
serverSSHAllowedFlag = "allow-server-ssh"
|
serverSSHAllowedFlag = "allow-server-ssh"
|
||||||
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
||||||
dnsRouteIntervalFlag = "dns-router-interval"
|
dnsRouteIntervalFlag = "dns-router-interval"
|
||||||
systemInfoFlag = "system-info"
|
systemInfoFlag = "system-info"
|
||||||
blockLANAccessFlag = "block-lan-access"
|
blockLANAccessFlag = "block-lan-access"
|
||||||
uploadBundle = "upload-bundle"
|
enableLazyConnectionFlag = "enable-lazy-connection"
|
||||||
uploadBundleURL = "upload-bundle-url"
|
uploadBundle = "upload-bundle"
|
||||||
|
uploadBundleURL = "upload-bundle-url"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -80,6 +81,7 @@ var (
|
|||||||
blockLANAccess bool
|
blockLANAccess bool
|
||||||
debugUploadBundle bool
|
debugUploadBundle bool
|
||||||
debugUploadBundleURL string
|
debugUploadBundleURL string
|
||||||
|
lazyConnEnabled bool
|
||||||
|
|
||||||
rootCmd = &cobra.Command{
|
rootCmd = &cobra.Command{
|
||||||
Use: "netbird",
|
Use: "netbird",
|
||||||
@ -184,6 +186,7 @@ func init() {
|
|||||||
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
|
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
|
||||||
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
|
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
|
||||||
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
|
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
|
||||||
|
upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand.")
|
||||||
|
|
||||||
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle")
|
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle")
|
||||||
debugCmd.PersistentFlags().BoolVarP(&debugUploadBundle, uploadBundle, "U", false, fmt.Sprintf("Uploads the debug bundle to a server from URL defined by %s", uploadBundleURL))
|
debugCmd.PersistentFlags().BoolVarP(&debugUploadBundle, uploadBundle, "U", false, fmt.Sprintf("Uploads the debug bundle to a server from URL defined by %s", uploadBundleURL))
|
||||||
|
@ -44,7 +44,7 @@ func init() {
|
|||||||
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4")
|
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4")
|
||||||
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
|
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
|
||||||
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
|
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
|
||||||
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(connected|disconnected), e.g., --filter-by-status connected")
|
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
|
||||||
}
|
}
|
||||||
|
|
||||||
func statusFunc(cmd *cobra.Command, args []string) error {
|
func statusFunc(cmd *cobra.Command, args []string) error {
|
||||||
@ -127,12 +127,12 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
|
|||||||
|
|
||||||
func parseFilters() error {
|
func parseFilters() error {
|
||||||
switch strings.ToLower(statusFilter) {
|
switch strings.ToLower(statusFilter) {
|
||||||
case "", "disconnected", "connected":
|
case "", "idle", "connecting", "connected":
|
||||||
if strings.ToLower(statusFilter) != "" {
|
if strings.ToLower(statusFilter) != "" {
|
||||||
enableDetailFlagWhenFilterFlag()
|
enableDetailFlagWhenFilterFlag()
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("wrong status filter, should be one of connected|disconnected, got: %s", statusFilter)
|
return fmt.Errorf("wrong status filter, should be one of connected|connecting|idle, got: %s", statusFilter)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(ipsFilter) > 0 {
|
if len(ipsFilter) > 0 {
|
||||||
|
@ -194,6 +194,10 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
ic.BlockLANAccess = &blockLANAccess
|
ic.BlockLANAccess = &blockLANAccess
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
||||||
|
ic.LazyConnectionEnabled = &lazyConnEnabled
|
||||||
|
}
|
||||||
|
|
||||||
providedSetupKey, err := getSetupKey()
|
providedSetupKey, err := getSetupKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -262,17 +266,17 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
loginRequest := proto.LoginRequest{
|
loginRequest := proto.LoginRequest{
|
||||||
SetupKey: providedSetupKey,
|
SetupKey: providedSetupKey,
|
||||||
ManagementUrl: managementURL,
|
ManagementUrl: managementURL,
|
||||||
AdminURL: adminURL,
|
AdminURL: adminURL,
|
||||||
NatExternalIPs: natExternalIPs,
|
NatExternalIPs: natExternalIPs,
|
||||||
CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0,
|
CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0,
|
||||||
CustomDNSAddress: customDNSAddressConverted,
|
CustomDNSAddress: customDNSAddressConverted,
|
||||||
IsLinuxDesktopClient: isLinuxRunningDesktop(),
|
IsUnixDesktopClient: isUnixRunningDesktop(),
|
||||||
Hostname: hostName,
|
Hostname: hostName,
|
||||||
ExtraIFaceBlacklist: extraIFaceBlackList,
|
ExtraIFaceBlacklist: extraIFaceBlackList,
|
||||||
DnsLabels: dnsLabels,
|
DnsLabels: dnsLabels,
|
||||||
CleanDNSLabels: dnsLabels != nil && len(dnsLabels) == 0,
|
CleanDNSLabels: dnsLabels != nil && len(dnsLabels) == 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||||
@ -332,6 +336,10 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
loginRequest.BlockLanAccess = &blockLANAccess
|
loginRequest.BlockLanAccess = &blockLANAccess
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
||||||
|
loginRequest.LazyConnectionEnabled = &lazyConnEnabled
|
||||||
|
}
|
||||||
|
|
||||||
var loginErr error
|
var loginErr error
|
||||||
|
|
||||||
var loginResp *proto.LoginResponse
|
var loginResp *proto.LoginResponse
|
||||||
|
@ -201,14 +201,30 @@ func (c *KernelConfigurer) configure(config wgtypes.Config) error {
|
|||||||
func (c *KernelConfigurer) Close() {
|
func (c *KernelConfigurer) Close() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *KernelConfigurer) GetStats(peerKey string) (WGStats, error) {
|
func (c *KernelConfigurer) GetStats() (map[string]WGStats, error) {
|
||||||
peer, err := c.getPeer(c.deviceName, peerKey)
|
stats := make(map[string]WGStats)
|
||||||
|
wg, err := wgctrl.New()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return WGStats{}, fmt.Errorf("get wireguard stats: %w", err)
|
return nil, fmt.Errorf("wgctl: %w", err)
|
||||||
}
|
}
|
||||||
return WGStats{
|
defer func() {
|
||||||
LastHandshake: peer.LastHandshakeTime,
|
err = wg.Close()
|
||||||
TxBytes: peer.TransmitBytes,
|
if err != nil {
|
||||||
RxBytes: peer.ReceiveBytes,
|
log.Errorf("Got error while closing wgctl: %v", err)
|
||||||
}, nil
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
wgDevice, err := wg.Device(c.deviceName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get device %s: %w", c.deviceName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, peer := range wgDevice.Peers {
|
||||||
|
stats[peer.PublicKey.String()] = WGStats{
|
||||||
|
LastHandshake: peer.LastHandshakeTime,
|
||||||
|
TxBytes: peer.TransmitBytes,
|
||||||
|
RxBytes: peer.ReceiveBytes,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return stats, nil
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package configurer
|
package configurer
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/base64"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
@ -17,6 +18,13 @@ import (
|
|||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ipcKeyLastHandshakeTimeSec = "last_handshake_time_sec"
|
||||||
|
ipcKeyLastHandshakeTimeNsec = "last_handshake_time_nsec"
|
||||||
|
ipcKeyTxBytes = "tx_bytes"
|
||||||
|
ipcKeyRxBytes = "rx_bytes"
|
||||||
|
)
|
||||||
|
|
||||||
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
|
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
|
||||||
|
|
||||||
type WGUSPConfigurer struct {
|
type WGUSPConfigurer struct {
|
||||||
@ -217,91 +225,75 @@ func (t *WGUSPConfigurer) Close() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *WGUSPConfigurer) GetStats(peerKey string) (WGStats, error) {
|
func (t *WGUSPConfigurer) GetStats() (map[string]WGStats, error) {
|
||||||
ipc, err := t.device.IpcGet()
|
ipc, err := t.device.IpcGet()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return WGStats{}, fmt.Errorf("ipc get: %w", err)
|
return nil, fmt.Errorf("ipc get: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
stats, err := findPeerInfo(ipc, peerKey, []string{
|
return parseTransfers(ipc)
|
||||||
"last_handshake_time_sec",
|
|
||||||
"last_handshake_time_nsec",
|
|
||||||
"tx_bytes",
|
|
||||||
"rx_bytes",
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return WGStats{}, fmt.Errorf("find peer info: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
sec, err := strconv.ParseInt(stats["last_handshake_time_sec"], 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
return WGStats{}, fmt.Errorf("parse handshake sec: %w", err)
|
|
||||||
}
|
|
||||||
nsec, err := strconv.ParseInt(stats["last_handshake_time_nsec"], 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
return WGStats{}, fmt.Errorf("parse handshake nsec: %w", err)
|
|
||||||
}
|
|
||||||
txBytes, err := strconv.ParseInt(stats["tx_bytes"], 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
return WGStats{}, fmt.Errorf("parse tx_bytes: %w", err)
|
|
||||||
}
|
|
||||||
rxBytes, err := strconv.ParseInt(stats["rx_bytes"], 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
return WGStats{}, fmt.Errorf("parse rx_bytes: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return WGStats{
|
|
||||||
LastHandshake: time.Unix(sec, nsec),
|
|
||||||
TxBytes: txBytes,
|
|
||||||
RxBytes: rxBytes,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func findPeerInfo(ipcInput string, peerKey string, searchConfigKeys []string) (map[string]string, error) {
|
func parseTransfers(ipc string) (map[string]WGStats, error) {
|
||||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
stats := make(map[string]WGStats)
|
||||||
if err != nil {
|
var (
|
||||||
return nil, fmt.Errorf("parse key: %w", err)
|
currentKey string
|
||||||
}
|
currentStats WGStats
|
||||||
|
hasPeer bool
|
||||||
hexKey := hex.EncodeToString(peerKeyParsed[:])
|
)
|
||||||
|
lines := strings.Split(ipc, "\n")
|
||||||
lines := strings.Split(ipcInput, "\n")
|
|
||||||
|
|
||||||
configFound := map[string]string{}
|
|
||||||
foundPeer := false
|
|
||||||
for _, line := range lines {
|
for _, line := range lines {
|
||||||
line = strings.TrimSpace(line)
|
line = strings.TrimSpace(line)
|
||||||
|
|
||||||
// If we're within the details of the found peer and encounter another public key,
|
// If we're within the details of the found peer and encounter another public key,
|
||||||
// this means we're starting another peer's details. So, stop.
|
// this means we're starting another peer's details. So, stop.
|
||||||
if strings.HasPrefix(line, "public_key=") && foundPeer {
|
if strings.HasPrefix(line, "public_key=") {
|
||||||
break
|
peerID := strings.TrimPrefix(line, "public_key=")
|
||||||
}
|
h, err := hex.DecodeString(peerID)
|
||||||
|
if err != nil {
|
||||||
// Identify the peer with the specific public key
|
return nil, fmt.Errorf("decode peerID: %w", err)
|
||||||
if line == fmt.Sprintf("public_key=%s", hexKey) {
|
|
||||||
foundPeer = true
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, key := range searchConfigKeys {
|
|
||||||
if foundPeer && strings.HasPrefix(line, key+"=") {
|
|
||||||
v := strings.SplitN(line, "=", 2)
|
|
||||||
configFound[v[0]] = v[1]
|
|
||||||
}
|
}
|
||||||
|
currentKey = base64.StdEncoding.EncodeToString(h)
|
||||||
|
currentStats = WGStats{} // Reset stats for the new peer
|
||||||
|
hasPeer = true
|
||||||
|
stats[currentKey] = currentStats
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !hasPeer {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
key := strings.SplitN(line, "=", 2)
|
||||||
|
if len(key) != 2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch key[0] {
|
||||||
|
case ipcKeyLastHandshakeTimeSec:
|
||||||
|
hs, err := toLastHandshake(key[1])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
currentStats.LastHandshake = hs
|
||||||
|
stats[currentKey] = currentStats
|
||||||
|
case ipcKeyRxBytes:
|
||||||
|
rxBytes, err := toBytes(key[1])
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse rx_bytes: %w", err)
|
||||||
|
}
|
||||||
|
currentStats.RxBytes = rxBytes
|
||||||
|
stats[currentKey] = currentStats
|
||||||
|
case ipcKeyTxBytes:
|
||||||
|
TxBytes, err := toBytes(key[1])
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse tx_bytes: %w", err)
|
||||||
|
}
|
||||||
|
currentStats.TxBytes = TxBytes
|
||||||
|
stats[currentKey] = currentStats
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// todo: use multierr
|
return stats, nil
|
||||||
for _, key := range searchConfigKeys {
|
|
||||||
if _, ok := configFound[key]; !ok {
|
|
||||||
return configFound, fmt.Errorf("config key not found: %s", key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !foundPeer {
|
|
||||||
return nil, fmt.Errorf("%w: %s", ErrPeerNotFound, peerKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
return configFound, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
||||||
@ -355,6 +347,18 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
|||||||
return sb.String()
|
return sb.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func toLastHandshake(stringVar string) (time.Time, error) {
|
||||||
|
sec, err := strconv.ParseInt(stringVar, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return time.Time{}, fmt.Errorf("parse handshake sec: %w", err)
|
||||||
|
}
|
||||||
|
return time.Unix(sec, 0), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func toBytes(s string) (int64, error) {
|
||||||
|
return strconv.ParseInt(s, 10, 64)
|
||||||
|
}
|
||||||
|
|
||||||
func getFwmark() int {
|
func getFwmark() int {
|
||||||
if nbnet.AdvancedRouting() {
|
if nbnet.AdvancedRouting() {
|
||||||
return nbnet.ControlPlaneMark
|
return nbnet.ControlPlaneMark
|
||||||
|
@ -2,10 +2,8 @@ package configurer
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
)
|
)
|
||||||
@ -34,58 +32,35 @@ errno=0
|
|||||||
|
|
||||||
`
|
`
|
||||||
|
|
||||||
func Test_findPeerInfo(t *testing.T) {
|
func Test_parseTransfers(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
peerKey string
|
peerKey string
|
||||||
searchKeys []string
|
want WGStats
|
||||||
want map[string]string
|
|
||||||
wantErr bool
|
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "single",
|
name: "single",
|
||||||
peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376",
|
peerKey: "b85996fecc9c7f1fc6d2572a76eda11d59bcd20be8e543b15ce4bd85a8e75a33",
|
||||||
searchKeys: []string{"tx_bytes"},
|
want: WGStats{
|
||||||
want: map[string]string{
|
TxBytes: 0,
|
||||||
"tx_bytes": "38333",
|
RxBytes: 0,
|
||||||
},
|
},
|
||||||
wantErr: false,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "multiple",
|
name: "multiple",
|
||||||
peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376",
|
peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376",
|
||||||
searchKeys: []string{"tx_bytes", "rx_bytes"},
|
want: WGStats{
|
||||||
want: map[string]string{
|
TxBytes: 38333,
|
||||||
"tx_bytes": "38333",
|
RxBytes: 2224,
|
||||||
"rx_bytes": "2224",
|
|
||||||
},
|
},
|
||||||
wantErr: false,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "lastpeer",
|
name: "lastpeer",
|
||||||
peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
|
peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
|
||||||
searchKeys: []string{"tx_bytes", "rx_bytes"},
|
want: WGStats{
|
||||||
want: map[string]string{
|
TxBytes: 1212111,
|
||||||
"tx_bytes": "1212111",
|
RxBytes: 1929999999,
|
||||||
"rx_bytes": "1929999999",
|
|
||||||
},
|
},
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "peer not found",
|
|
||||||
peerKey: "1111111111111111111111111111111111111111111111111111111111111111",
|
|
||||||
searchKeys: nil,
|
|
||||||
want: nil,
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "key not found",
|
|
||||||
peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
|
|
||||||
searchKeys: []string{"tx_bytes", "unknown_key"},
|
|
||||||
want: map[string]string{
|
|
||||||
"tx_bytes": "1212111",
|
|
||||||
},
|
|
||||||
wantErr: true,
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@ -96,9 +71,19 @@ func Test_findPeerInfo(t *testing.T) {
|
|||||||
key, err := wgtypes.NewKey(res)
|
key, err := wgtypes.NewKey(res)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
got, err := findPeerInfo(ipcFixture, key.String(), tt.searchKeys)
|
stats, err := parseTransfers(ipcFixture)
|
||||||
assert.Equalf(t, tt.wantErr, err != nil, fmt.Sprintf("findPeerInfo(%v, %v, %v)", ipcFixture, key.String(), tt.searchKeys))
|
if err != nil {
|
||||||
assert.Equalf(t, tt.want, got, "findPeerInfo(%v, %v, %v)", ipcFixture, key.String(), tt.searchKeys)
|
require.NoError(t, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
stat, ok := stats[key.String()]
|
||||||
|
if !ok {
|
||||||
|
require.True(t, ok)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, tt.want, stat)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -16,5 +16,5 @@ type WGConfigurer interface {
|
|||||||
AddAllowedIP(peerKey string, allowedIP string) error
|
AddAllowedIP(peerKey string, allowedIP string) error
|
||||||
RemoveAllowedIP(peerKey string, allowedIP string) error
|
RemoveAllowedIP(peerKey string, allowedIP string) error
|
||||||
Close()
|
Close()
|
||||||
GetStats(peerKey string) (configurer.WGStats, error)
|
GetStats() (map[string]configurer.WGStats, error)
|
||||||
}
|
}
|
||||||
|
@ -212,9 +212,9 @@ func (w *WGIface) GetWGDevice() *wgdevice.Device {
|
|||||||
return w.tun.Device()
|
return w.tun.Device()
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetStats returns the last handshake time, rx and tx bytes for the given peer
|
// GetStats returns the last handshake time, rx and tx bytes
|
||||||
func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) {
|
func (w *WGIface) GetStats() (map[string]configurer.WGStats, error) {
|
||||||
return w.configurer.GetStats(peerKey)
|
return w.configurer.GetStats()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *WGIface) waitUntilRemoved() error {
|
func (w *WGIface) waitUntilRemoved() error {
|
||||||
|
@ -24,6 +24,8 @@
|
|||||||
|
|
||||||
!define AUTOSTART_REG_KEY "Software\Microsoft\Windows\CurrentVersion\Run"
|
!define AUTOSTART_REG_KEY "Software\Microsoft\Windows\CurrentVersion\Run"
|
||||||
|
|
||||||
|
!define NETBIRD_DATA_DIR "$COMMONPROGRAMDATA\Netbird"
|
||||||
|
|
||||||
Unicode True
|
Unicode True
|
||||||
|
|
||||||
######################################################################
|
######################################################################
|
||||||
@ -49,6 +51,10 @@ ShowInstDetails Show
|
|||||||
|
|
||||||
######################################################################
|
######################################################################
|
||||||
|
|
||||||
|
!include "MUI2.nsh"
|
||||||
|
!include LogicLib.nsh
|
||||||
|
!include "nsDialogs.nsh"
|
||||||
|
|
||||||
!define MUI_ICON "${ICON}"
|
!define MUI_ICON "${ICON}"
|
||||||
!define MUI_UNICON "${ICON}"
|
!define MUI_UNICON "${ICON}"
|
||||||
!define MUI_WELCOMEFINISHPAGE_BITMAP "${BANNER}"
|
!define MUI_WELCOMEFINISHPAGE_BITMAP "${BANNER}"
|
||||||
@ -58,9 +64,6 @@ ShowInstDetails Show
|
|||||||
!define MUI_FINISHPAGE_RUN_FUNCTION "LaunchLink"
|
!define MUI_FINISHPAGE_RUN_FUNCTION "LaunchLink"
|
||||||
######################################################################
|
######################################################################
|
||||||
|
|
||||||
!include "MUI2.nsh"
|
|
||||||
!include LogicLib.nsh
|
|
||||||
|
|
||||||
!define MUI_ABORTWARNING
|
!define MUI_ABORTWARNING
|
||||||
!define MUI_UNABORTWARNING
|
!define MUI_UNABORTWARNING
|
||||||
|
|
||||||
@ -70,13 +73,16 @@ ShowInstDetails Show
|
|||||||
|
|
||||||
!insertmacro MUI_PAGE_DIRECTORY
|
!insertmacro MUI_PAGE_DIRECTORY
|
||||||
|
|
||||||
; Custom page for autostart checkbox
|
|
||||||
Page custom AutostartPage AutostartPageLeave
|
Page custom AutostartPage AutostartPageLeave
|
||||||
|
|
||||||
!insertmacro MUI_PAGE_INSTFILES
|
!insertmacro MUI_PAGE_INSTFILES
|
||||||
|
|
||||||
!insertmacro MUI_PAGE_FINISH
|
!insertmacro MUI_PAGE_FINISH
|
||||||
|
|
||||||
|
!insertmacro MUI_UNPAGE_WELCOME
|
||||||
|
|
||||||
|
UninstPage custom un.DeleteDataPage un.DeleteDataPageLeave
|
||||||
|
|
||||||
!insertmacro MUI_UNPAGE_CONFIRM
|
!insertmacro MUI_UNPAGE_CONFIRM
|
||||||
|
|
||||||
!insertmacro MUI_UNPAGE_INSTFILES
|
!insertmacro MUI_UNPAGE_INSTFILES
|
||||||
@ -89,6 +95,10 @@ Page custom AutostartPage AutostartPageLeave
|
|||||||
Var AutostartCheckbox
|
Var AutostartCheckbox
|
||||||
Var AutostartEnabled
|
Var AutostartEnabled
|
||||||
|
|
||||||
|
; Variables for uninstall data deletion option
|
||||||
|
Var DeleteDataCheckbox
|
||||||
|
Var DeleteDataEnabled
|
||||||
|
|
||||||
######################################################################
|
######################################################################
|
||||||
|
|
||||||
; Function to create the autostart options page
|
; Function to create the autostart options page
|
||||||
@ -104,8 +114,8 @@ Function AutostartPage
|
|||||||
|
|
||||||
${NSD_CreateCheckbox} 0 20u 100% 10u "Start ${APP_NAME} UI automatically when Windows starts"
|
${NSD_CreateCheckbox} 0 20u 100% 10u "Start ${APP_NAME} UI automatically when Windows starts"
|
||||||
Pop $AutostartCheckbox
|
Pop $AutostartCheckbox
|
||||||
${NSD_Check} $AutostartCheckbox ; Default to checked
|
${NSD_Check} $AutostartCheckbox
|
||||||
StrCpy $AutostartEnabled "1" ; Default to enabled
|
StrCpy $AutostartEnabled "1"
|
||||||
|
|
||||||
nsDialogs::Show
|
nsDialogs::Show
|
||||||
FunctionEnd
|
FunctionEnd
|
||||||
@ -115,6 +125,30 @@ Function AutostartPageLeave
|
|||||||
${NSD_GetState} $AutostartCheckbox $AutostartEnabled
|
${NSD_GetState} $AutostartCheckbox $AutostartEnabled
|
||||||
FunctionEnd
|
FunctionEnd
|
||||||
|
|
||||||
|
; Function to create the uninstall data deletion page
|
||||||
|
Function un.DeleteDataPage
|
||||||
|
!insertmacro MUI_HEADER_TEXT "Uninstall Options" "Choose whether to delete ${APP_NAME} data."
|
||||||
|
|
||||||
|
nsDialogs::Create 1018
|
||||||
|
Pop $0
|
||||||
|
|
||||||
|
${If} $0 == error
|
||||||
|
Abort
|
||||||
|
${EndIf}
|
||||||
|
|
||||||
|
${NSD_CreateCheckbox} 0 20u 100% 10u "Delete all ${APP_NAME} configuration and state data (${NETBIRD_DATA_DIR})"
|
||||||
|
Pop $DeleteDataCheckbox
|
||||||
|
${NSD_Uncheck} $DeleteDataCheckbox
|
||||||
|
StrCpy $DeleteDataEnabled "0"
|
||||||
|
|
||||||
|
nsDialogs::Show
|
||||||
|
FunctionEnd
|
||||||
|
|
||||||
|
; Function to handle leaving the data deletion page
|
||||||
|
Function un.DeleteDataPageLeave
|
||||||
|
${NSD_GetState} $DeleteDataCheckbox $DeleteDataEnabled
|
||||||
|
FunctionEnd
|
||||||
|
|
||||||
Function GetAppFromCommand
|
Function GetAppFromCommand
|
||||||
Exch $1
|
Exch $1
|
||||||
Push $2
|
Push $2
|
||||||
@ -176,10 +210,10 @@ ${EndIf}
|
|||||||
FunctionEnd
|
FunctionEnd
|
||||||
######################################################################
|
######################################################################
|
||||||
Section -MainProgram
|
Section -MainProgram
|
||||||
${INSTALL_TYPE}
|
${INSTALL_TYPE}
|
||||||
# SetOverwrite ifnewer
|
# SetOverwrite ifnewer
|
||||||
SetOutPath "$INSTDIR"
|
SetOutPath "$INSTDIR"
|
||||||
File /r "..\\dist\\netbird_windows_amd64\\"
|
File /r "..\\dist\\netbird_windows_amd64\\"
|
||||||
SectionEnd
|
SectionEnd
|
||||||
######################################################################
|
######################################################################
|
||||||
|
|
||||||
@ -225,31 +259,58 @@ SectionEnd
|
|||||||
Section Uninstall
|
Section Uninstall
|
||||||
${INSTALL_TYPE}
|
${INSTALL_TYPE}
|
||||||
|
|
||||||
|
DetailPrint "Stopping Netbird service..."
|
||||||
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service stop'
|
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service stop'
|
||||||
|
DetailPrint "Uninstalling Netbird service..."
|
||||||
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
|
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
|
||||||
|
|
||||||
# kill ui client
|
DetailPrint "Terminating Netbird UI process..."
|
||||||
ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
|
ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
|
||||||
|
|
||||||
; Remove autostart registry entry
|
; Remove autostart registry entry
|
||||||
|
DetailPrint "Removing autostart registry entry if exists..."
|
||||||
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||||
|
|
||||||
|
; Handle data deletion based on checkbox
|
||||||
|
DetailPrint "Checking if user requested data deletion..."
|
||||||
|
${If} $DeleteDataEnabled == "1"
|
||||||
|
DetailPrint "User opted to delete Netbird data. Removing ${NETBIRD_DATA_DIR}..."
|
||||||
|
ClearErrors
|
||||||
|
RMDir /r "${NETBIRD_DATA_DIR}"
|
||||||
|
IfErrors 0 +2 ; If no errors, jump over the message
|
||||||
|
DetailPrint "Error deleting Netbird data directory. It might be in use or already removed."
|
||||||
|
DetailPrint "Netbird data directory removal complete."
|
||||||
|
${Else}
|
||||||
|
DetailPrint "User did not opt to delete Netbird data."
|
||||||
|
${EndIf}
|
||||||
|
|
||||||
# wait the service uninstall take unblock the executable
|
# wait the service uninstall take unblock the executable
|
||||||
|
DetailPrint "Waiting for service handle to be released..."
|
||||||
Sleep 3000
|
Sleep 3000
|
||||||
|
|
||||||
|
DetailPrint "Deleting application files..."
|
||||||
Delete "$INSTDIR\${UI_APP_EXE}"
|
Delete "$INSTDIR\${UI_APP_EXE}"
|
||||||
Delete "$INSTDIR\${MAIN_APP_EXE}"
|
Delete "$INSTDIR\${MAIN_APP_EXE}"
|
||||||
Delete "$INSTDIR\wintun.dll"
|
Delete "$INSTDIR\wintun.dll"
|
||||||
Delete "$INSTDIR\opengl32.dll"
|
Delete "$INSTDIR\opengl32.dll"
|
||||||
|
DetailPrint "Removing application directory..."
|
||||||
RmDir /r "$INSTDIR"
|
RmDir /r "$INSTDIR"
|
||||||
|
|
||||||
|
DetailPrint "Removing shortcuts..."
|
||||||
SetShellVarContext all
|
SetShellVarContext all
|
||||||
Delete "$DESKTOP\${APP_NAME}.lnk"
|
Delete "$DESKTOP\${APP_NAME}.lnk"
|
||||||
Delete "$SMPROGRAMS\${APP_NAME}.lnk"
|
Delete "$SMPROGRAMS\${APP_NAME}.lnk"
|
||||||
|
|
||||||
|
DetailPrint "Removing registry keys..."
|
||||||
DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}"
|
DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}"
|
||||||
DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}"
|
DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}"
|
||||||
|
DeleteRegKey ${REG_ROOT} "${UI_REG_APP_PATH}"
|
||||||
|
|
||||||
|
DetailPrint "Removing application directory from PATH..."
|
||||||
EnVar::SetHKLM
|
EnVar::SetHKLM
|
||||||
EnVar::DeleteValue "path" "$INSTDIR"
|
EnVar::DeleteValue "path" "$INSTDIR"
|
||||||
|
|
||||||
|
DetailPrint "Uninstallation finished."
|
||||||
SectionEnd
|
SectionEnd
|
||||||
|
|
||||||
|
|
||||||
|
@ -76,12 +76,6 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
|
|||||||
|
|
||||||
d.applyPeerACLs(networkMap)
|
d.applyPeerACLs(networkMap)
|
||||||
|
|
||||||
// If we got empty rules list but management did not set the networkMap.FirewallRulesIsEmpty flag,
|
|
||||||
// then the mgmt server is older than the client, and we need to allow all traffic for routes
|
|
||||||
isLegacy := len(networkMap.RoutesFirewallRules) == 0 && !networkMap.RoutesFirewallRulesIsEmpty
|
|
||||||
if err := d.firewall.SetLegacyManagement(isLegacy); err != nil {
|
|
||||||
log.Errorf("failed to set legacy management flag: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil {
|
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil {
|
||||||
log.Errorf("Failed to apply route ACLs: %v", err)
|
log.Errorf("Failed to apply route ACLs: %v", err)
|
||||||
|
@ -64,13 +64,8 @@ func (t TokenInfo) GetTokenToUse() string {
|
|||||||
// and if that also fails, the authentication process is deemed unsuccessful
|
// and if that also fails, the authentication process is deemed unsuccessful
|
||||||
//
|
//
|
||||||
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
|
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
|
||||||
func NewOAuthFlow(ctx context.Context, config *internal.Config, isLinuxDesktopClient bool) (OAuthFlow, error) {
|
func NewOAuthFlow(ctx context.Context, config *internal.Config, isUnixDesktopClient bool) (OAuthFlow, error) {
|
||||||
if runtime.GOOS == "linux" && !isLinuxDesktopClient {
|
if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient {
|
||||||
return authenticateWithDeviceCodeFlow(ctx, config)
|
|
||||||
}
|
|
||||||
|
|
||||||
// On FreeBSD we currently do not support desktop environments and offer only Device Code Flow (#2384)
|
|
||||||
if runtime.GOOS == "freebsd" {
|
|
||||||
return authenticateWithDeviceCodeFlow(ctx, config)
|
return authenticateWithDeviceCodeFlow(ctx, config)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -101,7 +101,12 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
|
|||||||
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
|
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
|
||||||
}
|
}
|
||||||
if !p.providerConfig.DisablePromptLogin {
|
if !p.providerConfig.DisablePromptLogin {
|
||||||
params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
|
if p.providerConfig.LoginFlag.IsPromptLogin() {
|
||||||
|
params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
|
||||||
|
}
|
||||||
|
if p.providerConfig.LoginFlag.IsMaxAge0Login() {
|
||||||
|
params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
authURL := p.oAuthConfig.AuthCodeURL(state, params...)
|
authURL := p.oAuthConfig.AuthCodeURL(state, params...)
|
||||||
|
@ -7,15 +7,36 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
mgm "github.com/netbirdio/netbird/management/client/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestPromptLogin(t *testing.T) {
|
func TestPromptLogin(t *testing.T) {
|
||||||
|
const (
|
||||||
|
promptLogin = "prompt=login"
|
||||||
|
maxAge0 = "max_age=0"
|
||||||
|
)
|
||||||
|
|
||||||
tt := []struct {
|
tt := []struct {
|
||||||
name string
|
name string
|
||||||
prompt bool
|
loginFlag mgm.LoginFlag
|
||||||
|
disablePromptLogin bool
|
||||||
|
expect string
|
||||||
}{
|
}{
|
||||||
{"PromptLogin", true},
|
{
|
||||||
{"NoPromptLogin", false},
|
name: "Prompt login",
|
||||||
|
loginFlag: mgm.LoginFlagPrompt,
|
||||||
|
expect: promptLogin,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Max age 0 login",
|
||||||
|
loginFlag: mgm.LoginFlagMaxAge0,
|
||||||
|
expect: maxAge0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Disable prompt login",
|
||||||
|
loginFlag: mgm.LoginFlagPrompt,
|
||||||
|
disablePromptLogin: true,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range tt {
|
for _, tc := range tt {
|
||||||
@ -28,7 +49,7 @@ func TestPromptLogin(t *testing.T) {
|
|||||||
AuthorizationEndpoint: "https://test-auth-endpoint.com/authorize",
|
AuthorizationEndpoint: "https://test-auth-endpoint.com/authorize",
|
||||||
RedirectURLs: []string{"http://127.0.0.1:33992/"},
|
RedirectURLs: []string{"http://127.0.0.1:33992/"},
|
||||||
UseIDToken: true,
|
UseIDToken: true,
|
||||||
DisablePromptLogin: !tc.prompt,
|
LoginFlag: tc.loginFlag,
|
||||||
}
|
}
|
||||||
pkce, err := NewPKCEAuthorizationFlow(config)
|
pkce, err := NewPKCEAuthorizationFlow(config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -38,11 +59,12 @@ func TestPromptLogin(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to request auth info: %v", err)
|
t.Fatalf("Failed to request auth info: %v", err)
|
||||||
}
|
}
|
||||||
pattern := "prompt=login"
|
|
||||||
if tc.prompt {
|
if !tc.disablePromptLogin {
|
||||||
require.Contains(t, authInfo.VerificationURIComplete, pattern)
|
require.Contains(t, authInfo.VerificationURIComplete, tc.expect)
|
||||||
} else {
|
} else {
|
||||||
require.NotContains(t, authInfo.VerificationURIComplete, pattern)
|
require.Contains(t, authInfo.VerificationURIComplete, promptLogin)
|
||||||
|
require.NotContains(t, authInfo.VerificationURIComplete, maxAge0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -74,6 +74,8 @@ type ConfigInput struct {
|
|||||||
DisableNotifications *bool
|
DisableNotifications *bool
|
||||||
|
|
||||||
DNSLabels domain.List
|
DNSLabels domain.List
|
||||||
|
|
||||||
|
LazyConnectionEnabled *bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// Config Configuration type
|
// Config Configuration type
|
||||||
@ -138,6 +140,8 @@ type Config struct {
|
|||||||
ClientCertKeyPath string
|
ClientCertKeyPath string
|
||||||
|
|
||||||
ClientCertKeyPair *tls.Certificate `json:"-"`
|
ClientCertKeyPair *tls.Certificate `json:"-"`
|
||||||
|
|
||||||
|
LazyConnectionEnabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
||||||
@ -524,6 +528,12 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
|||||||
updated = true
|
updated = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if input.LazyConnectionEnabled != nil && *input.LazyConnectionEnabled != config.LazyConnectionEnabled {
|
||||||
|
log.Infof("switching lazy connection to %t", *input.LazyConnectionEnabled)
|
||||||
|
config.LazyConnectionEnabled = *input.LazyConnectionEnabled
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
return updated, nil
|
return updated, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
303
client/internal/conn_mgr.go
Normal file
303
client/internal/conn_mgr.go
Normal file
@ -0,0 +1,303 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/lazyconn/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConnMgr coordinates both lazy connections (established on-demand) and permanent peer connections.
|
||||||
|
//
|
||||||
|
// The connection manager is responsible for:
|
||||||
|
// - Managing lazy connections via the lazyConnManager
|
||||||
|
// - Maintaining a list of excluded peers that should always have permanent connections
|
||||||
|
// - Handling connection establishment based on peer signaling
|
||||||
|
//
|
||||||
|
// The implementation is not thread-safe; it is protected by engine.syncMsgMux.
|
||||||
|
type ConnMgr struct {
|
||||||
|
peerStore *peerstore.Store
|
||||||
|
statusRecorder *peer.Status
|
||||||
|
iface lazyconn.WGIface
|
||||||
|
dispatcher *dispatcher.ConnectionDispatcher
|
||||||
|
enabledLocally bool
|
||||||
|
|
||||||
|
lazyConnMgr *manager.Manager
|
||||||
|
|
||||||
|
wg sync.WaitGroup
|
||||||
|
ctx context.Context
|
||||||
|
ctxCancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConnMgr(engineConfig *EngineConfig, statusRecorder *peer.Status, peerStore *peerstore.Store, iface lazyconn.WGIface, dispatcher *dispatcher.ConnectionDispatcher) *ConnMgr {
|
||||||
|
e := &ConnMgr{
|
||||||
|
peerStore: peerStore,
|
||||||
|
statusRecorder: statusRecorder,
|
||||||
|
iface: iface,
|
||||||
|
dispatcher: dispatcher,
|
||||||
|
}
|
||||||
|
if engineConfig.LazyConnectionEnabled || lazyconn.IsLazyConnEnabledByEnv() {
|
||||||
|
e.enabledLocally = true
|
||||||
|
}
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start initializes the connection manager and starts the lazy connection manager if enabled by env var or cmd line option.
|
||||||
|
func (e *ConnMgr) Start(ctx context.Context) {
|
||||||
|
if e.lazyConnMgr != nil {
|
||||||
|
log.Errorf("lazy connection manager is already started")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !e.enabledLocally {
|
||||||
|
log.Infof("lazy connection manager is disabled")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
e.initLazyManager(ctx)
|
||||||
|
e.statusRecorder.UpdateLazyConnection(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatedRemoteFeatureFlag is called when the remote feature flag is updated.
|
||||||
|
// If enabled, it initializes the lazy connection manager and start it. Do not need to call Start() again.
|
||||||
|
// If disabled, then it closes the lazy connection manager and open the connections to all peers.
|
||||||
|
func (e *ConnMgr) UpdatedRemoteFeatureFlag(ctx context.Context, enabled bool) error {
|
||||||
|
// do not disable lazy connection manager if it was enabled by env var
|
||||||
|
if e.enabledLocally {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if enabled {
|
||||||
|
// if the lazy connection manager is already started, do not start it again
|
||||||
|
if e.lazyConnMgr != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("lazy connection manager is enabled by management feature flag")
|
||||||
|
e.initLazyManager(ctx)
|
||||||
|
e.statusRecorder.UpdateLazyConnection(true)
|
||||||
|
return e.addPeersToLazyConnManager(ctx)
|
||||||
|
} else {
|
||||||
|
if e.lazyConnMgr == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
log.Infof("lazy connection manager is disabled by management feature flag")
|
||||||
|
e.closeManager(ctx)
|
||||||
|
e.statusRecorder.UpdateLazyConnection(false)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetExcludeList sets the list of peer IDs that should always have permanent connections.
|
||||||
|
func (e *ConnMgr) SetExcludeList(peerIDs []string) {
|
||||||
|
if e.lazyConnMgr == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
excludedPeers := make([]lazyconn.PeerConfig, 0, len(peerIDs))
|
||||||
|
|
||||||
|
for _, peerID := range peerIDs {
|
||||||
|
var peerConn *peer.Conn
|
||||||
|
var exists bool
|
||||||
|
if peerConn, exists = e.peerStore.PeerConn(peerID); !exists {
|
||||||
|
log.Warnf("failed to find peer conn for peerID: %s", peerID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
lazyPeerCfg := lazyconn.PeerConfig{
|
||||||
|
PublicKey: peerID,
|
||||||
|
AllowedIPs: peerConn.WgConfig().AllowedIps,
|
||||||
|
PeerConnID: peerConn.ConnID(),
|
||||||
|
Log: peerConn.Log,
|
||||||
|
}
|
||||||
|
excludedPeers = append(excludedPeers, lazyPeerCfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
added := e.lazyConnMgr.ExcludePeer(e.ctx, excludedPeers)
|
||||||
|
for _, peerID := range added {
|
||||||
|
var peerConn *peer.Conn
|
||||||
|
var exists bool
|
||||||
|
if peerConn, exists = e.peerStore.PeerConn(peerID); !exists {
|
||||||
|
// if the peer not exist in the store, it means that the engine will call the AddPeerConn in next step
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
peerConn.Log.Infof("peer has been added to lazy connection exclude list, opening permanent connection")
|
||||||
|
if err := peerConn.Open(e.ctx); err != nil {
|
||||||
|
peerConn.Log.Errorf("failed to open connection: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ConnMgr) AddPeerConn(ctx context.Context, peerKey string, conn *peer.Conn) (exists bool) {
|
||||||
|
if success := e.peerStore.AddPeerConn(peerKey, conn); !success {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if !e.isStartedWithLazyMgr() {
|
||||||
|
if err := conn.Open(ctx); err != nil {
|
||||||
|
conn.Log.Errorf("failed to open connection: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !lazyconn.IsSupported(conn.AgentVersionString()) {
|
||||||
|
conn.Log.Warnf("peer does not support lazy connection (%s), open permanent connection", conn.AgentVersionString())
|
||||||
|
if err := conn.Open(ctx); err != nil {
|
||||||
|
conn.Log.Errorf("failed to open connection: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
lazyPeerCfg := lazyconn.PeerConfig{
|
||||||
|
PublicKey: peerKey,
|
||||||
|
AllowedIPs: conn.WgConfig().AllowedIps,
|
||||||
|
PeerConnID: conn.ConnID(),
|
||||||
|
Log: conn.Log,
|
||||||
|
}
|
||||||
|
excluded, err := e.lazyConnMgr.AddPeer(lazyPeerCfg)
|
||||||
|
if err != nil {
|
||||||
|
conn.Log.Errorf("failed to add peer to lazyconn manager: %v", err)
|
||||||
|
if err := conn.Open(ctx); err != nil {
|
||||||
|
conn.Log.Errorf("failed to open connection: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if excluded {
|
||||||
|
conn.Log.Infof("peer is on lazy conn manager exclude list, opening connection")
|
||||||
|
if err := conn.Open(ctx); err != nil {
|
||||||
|
conn.Log.Errorf("failed to open connection: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.Log.Infof("peer added to lazy conn manager")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ConnMgr) RemovePeerConn(peerKey string) {
|
||||||
|
conn, ok := e.peerStore.Remove(peerKey)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
if !e.isStartedWithLazyMgr() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
e.lazyConnMgr.RemovePeer(peerKey)
|
||||||
|
conn.Log.Infof("removed peer from lazy conn manager")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ConnMgr) OnSignalMsg(ctx context.Context, peerKey string) (*peer.Conn, bool) {
|
||||||
|
conn, ok := e.peerStore.PeerConn(peerKey)
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !e.isStartedWithLazyMgr() {
|
||||||
|
return conn, true
|
||||||
|
}
|
||||||
|
|
||||||
|
if found := e.lazyConnMgr.ActivatePeer(ctx, peerKey); found {
|
||||||
|
conn.Log.Infof("activated peer from inactive state")
|
||||||
|
if err := conn.Open(e.ctx); err != nil {
|
||||||
|
conn.Log.Errorf("failed to open connection: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return conn, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ConnMgr) Close() {
|
||||||
|
if !e.isStartedWithLazyMgr() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
e.ctxCancel()
|
||||||
|
e.wg.Wait()
|
||||||
|
e.lazyConnMgr = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ConnMgr) initLazyManager(parentCtx context.Context) {
|
||||||
|
cfg := manager.Config{
|
||||||
|
InactivityThreshold: inactivityThresholdEnv(),
|
||||||
|
}
|
||||||
|
e.lazyConnMgr = manager.NewManager(cfg, e.peerStore, e.iface, e.dispatcher)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(parentCtx)
|
||||||
|
e.ctx = ctx
|
||||||
|
e.ctxCancel = cancel
|
||||||
|
|
||||||
|
e.wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer e.wg.Done()
|
||||||
|
e.lazyConnMgr.Start(ctx)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ConnMgr) addPeersToLazyConnManager(ctx context.Context) error {
|
||||||
|
peers := e.peerStore.PeersPubKey()
|
||||||
|
lazyPeerCfgs := make([]lazyconn.PeerConfig, 0, len(peers))
|
||||||
|
for _, peerID := range peers {
|
||||||
|
var peerConn *peer.Conn
|
||||||
|
var exists bool
|
||||||
|
if peerConn, exists = e.peerStore.PeerConn(peerID); !exists {
|
||||||
|
log.Warnf("failed to find peer conn for peerID: %s", peerID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
lazyPeerCfg := lazyconn.PeerConfig{
|
||||||
|
PublicKey: peerID,
|
||||||
|
AllowedIPs: peerConn.WgConfig().AllowedIps,
|
||||||
|
PeerConnID: peerConn.ConnID(),
|
||||||
|
Log: peerConn.Log,
|
||||||
|
}
|
||||||
|
lazyPeerCfgs = append(lazyPeerCfgs, lazyPeerCfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
return e.lazyConnMgr.AddActivePeers(ctx, lazyPeerCfgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ConnMgr) closeManager(ctx context.Context) {
|
||||||
|
if e.lazyConnMgr == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
e.ctxCancel()
|
||||||
|
e.wg.Wait()
|
||||||
|
e.lazyConnMgr = nil
|
||||||
|
|
||||||
|
for _, peerID := range e.peerStore.PeersPubKey() {
|
||||||
|
e.peerStore.PeerConnOpen(ctx, peerID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ConnMgr) isStartedWithLazyMgr() bool {
|
||||||
|
return e.lazyConnMgr != nil && e.ctxCancel != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func inactivityThresholdEnv() *time.Duration {
|
||||||
|
envValue := os.Getenv(lazyconn.EnvInactivityThreshold)
|
||||||
|
if envValue == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
parsedMinutes, err := strconv.Atoi(envValue)
|
||||||
|
if err != nil || parsedMinutes <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
d := time.Duration(parsedMinutes) * time.Minute
|
||||||
|
return &d
|
||||||
|
}
|
@ -440,7 +440,8 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
|
|||||||
DisableDNS: config.DisableDNS,
|
DisableDNS: config.DisableDNS,
|
||||||
DisableFirewall: config.DisableFirewall,
|
DisableFirewall: config.DisableFirewall,
|
||||||
|
|
||||||
BlockLANAccess: config.BlockLANAccess,
|
BlockLANAccess: config.BlockLANAccess,
|
||||||
|
LazyConnectionEnabled: config.LazyConnectionEnabled,
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.PreSharedKey != "" {
|
if config.PreSharedKey != "" {
|
||||||
@ -481,7 +482,7 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourP
|
|||||||
return signalClient, nil
|
return signalClient, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// loginToManagement creates Management Services client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
|
// loginToManagement creates Management ServiceDependencies client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
|
||||||
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *Config) (*mgmProto.LoginResponse, error) {
|
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *Config) (*mgmProto.LoginResponse, error) {
|
||||||
|
|
||||||
serverPublicKey, err := client.GetServerPublicKey()
|
serverPublicKey, err := client.GetServerPublicKey()
|
||||||
|
@ -376,6 +376,7 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
|
|||||||
configContent.WriteString(fmt.Sprintf("DisableFirewall: %v\n", g.internalConfig.DisableFirewall))
|
configContent.WriteString(fmt.Sprintf("DisableFirewall: %v\n", g.internalConfig.DisableFirewall))
|
||||||
|
|
||||||
configContent.WriteString(fmt.Sprintf("BlockLANAccess: %v\n", g.internalConfig.BlockLANAccess))
|
configContent.WriteString(fmt.Sprintf("BlockLANAccess: %v\n", g.internalConfig.BlockLANAccess))
|
||||||
|
configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *BundleGenerator) addProf() (err error) {
|
func (g *BundleGenerator) addProf() (err error) {
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package dns_test
|
package dns_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
@ -9,6 +8,7 @@ import (
|
|||||||
"github.com/stretchr/testify/mock"
|
"github.com/stretchr/testify/mock"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order
|
// TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order
|
||||||
@ -30,7 +30,7 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
|
|||||||
r.SetQuestion("example.com.", dns.TypeA)
|
r.SetQuestion("example.com.", dns.TypeA)
|
||||||
|
|
||||||
// Create test writer
|
// Create test writer
|
||||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||||
|
|
||||||
// Setup expectations - only highest priority handler should be called
|
// Setup expectations - only highest priority handler should be called
|
||||||
dnsRouteHandler.On("ServeDNS", mock.Anything, r).Once()
|
dnsRouteHandler.On("ServeDNS", mock.Anything, r).Once()
|
||||||
@ -142,7 +142,7 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
|
|||||||
|
|
||||||
r := new(dns.Msg)
|
r := new(dns.Msg)
|
||||||
r.SetQuestion(tt.queryDomain, dns.TypeA)
|
r.SetQuestion(tt.queryDomain, dns.TypeA)
|
||||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||||
|
|
||||||
chain.ServeDNS(w, r)
|
chain.ServeDNS(w, r)
|
||||||
|
|
||||||
@ -259,7 +259,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
|||||||
// Create and execute request
|
// Create and execute request
|
||||||
r := new(dns.Msg)
|
r := new(dns.Msg)
|
||||||
r.SetQuestion(tt.queryDomain, dns.TypeA)
|
r.SetQuestion(tt.queryDomain, dns.TypeA)
|
||||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||||
chain.ServeDNS(w, r)
|
chain.ServeDNS(w, r)
|
||||||
|
|
||||||
// Verify expectations
|
// Verify expectations
|
||||||
@ -316,7 +316,7 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
|
|||||||
}).Once()
|
}).Once()
|
||||||
|
|
||||||
// Execute
|
// Execute
|
||||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||||
chain.ServeDNS(w, r)
|
chain.ServeDNS(w, r)
|
||||||
|
|
||||||
// Verify all handlers were called in order
|
// Verify all handlers were called in order
|
||||||
@ -325,20 +325,6 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
|
|||||||
handler3.AssertExpectations(t)
|
handler3.AssertExpectations(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
// mockResponseWriter implements dns.ResponseWriter for testing
|
|
||||||
type mockResponseWriter struct {
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockResponseWriter) LocalAddr() net.Addr { return nil }
|
|
||||||
func (m *mockResponseWriter) RemoteAddr() net.Addr { return nil }
|
|
||||||
func (m *mockResponseWriter) WriteMsg(*dns.Msg) error { return nil }
|
|
||||||
func (m *mockResponseWriter) Write([]byte) (int, error) { return 0, nil }
|
|
||||||
func (m *mockResponseWriter) Close() error { return nil }
|
|
||||||
func (m *mockResponseWriter) TsigStatus() error { return nil }
|
|
||||||
func (m *mockResponseWriter) TsigTimersOnly(bool) {}
|
|
||||||
func (m *mockResponseWriter) Hijack() {}
|
|
||||||
|
|
||||||
func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@ -425,7 +411,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
|||||||
// Create test request
|
// Create test request
|
||||||
r := new(dns.Msg)
|
r := new(dns.Msg)
|
||||||
r.SetQuestion(tt.query, dns.TypeA)
|
r.SetQuestion(tt.query, dns.TypeA)
|
||||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||||
|
|
||||||
// Setup expectations
|
// Setup expectations
|
||||||
for priority, handler := range handlers {
|
for priority, handler := range handlers {
|
||||||
@ -471,7 +457,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
|||||||
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain)
|
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain)
|
||||||
|
|
||||||
// Test 1: Initial state
|
// Test 1: Initial state
|
||||||
w1 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w1 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||||
// Highest priority handler (routeHandler) should be called
|
// Highest priority handler (routeHandler) should be called
|
||||||
routeHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
routeHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
||||||
matchHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure others are not expected yet
|
matchHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure others are not expected yet
|
||||||
@ -490,7 +476,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
|||||||
// Test 2: Remove highest priority handler
|
// Test 2: Remove highest priority handler
|
||||||
chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute)
|
chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute)
|
||||||
|
|
||||||
w2 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w2 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||||
// Now middle priority handler (matchHandler) should be called
|
// Now middle priority handler (matchHandler) should be called
|
||||||
matchHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
matchHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
||||||
defaultHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure default is not expected yet
|
defaultHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure default is not expected yet
|
||||||
@ -506,7 +492,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
|||||||
// Test 3: Remove middle priority handler
|
// Test 3: Remove middle priority handler
|
||||||
chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain)
|
chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain)
|
||||||
|
|
||||||
w3 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w3 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||||
// Now lowest priority handler (defaultHandler) should be called
|
// Now lowest priority handler (defaultHandler) should be called
|
||||||
defaultHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
defaultHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
||||||
|
|
||||||
@ -519,7 +505,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
|||||||
// Test 4: Remove last handler
|
// Test 4: Remove last handler
|
||||||
chain.RemoveHandler(testDomain, nbdns.PriorityDefault)
|
chain.RemoveHandler(testDomain, nbdns.PriorityDefault)
|
||||||
|
|
||||||
w4 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w4 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||||
chain.ServeDNS(w4, r) // Call ServeDNS on the now empty chain for this domain
|
chain.ServeDNS(w4, r) // Call ServeDNS on the now empty chain for this domain
|
||||||
|
|
||||||
for _, m := range mocks {
|
for _, m := range mocks {
|
||||||
@ -675,7 +661,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
|||||||
// Execute request
|
// Execute request
|
||||||
r := new(dns.Msg)
|
r := new(dns.Msg)
|
||||||
r.SetQuestion(tt.query, dns.TypeA)
|
r.SetQuestion(tt.query, dns.TypeA)
|
||||||
chain.ServeDNS(&mockResponseWriter{}, r)
|
chain.ServeDNS(&test.MockResponseWriter{}, r)
|
||||||
|
|
||||||
// Verify each handler was called exactly as expected
|
// Verify each handler was called exactly as expected
|
||||||
for _, h := range tt.addHandlers {
|
for _, h := range tt.addHandlers {
|
||||||
@ -819,7 +805,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
|||||||
|
|
||||||
r := new(dns.Msg)
|
r := new(dns.Msg)
|
||||||
r.SetQuestion(tt.query, dns.TypeA)
|
r.SetQuestion(tt.query, dns.TypeA)
|
||||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||||
|
|
||||||
// Setup handler expectations
|
// Setup handler expectations
|
||||||
for pattern, handler := range handlers {
|
for pattern, handler := range handlers {
|
||||||
@ -969,7 +955,7 @@ func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
|
|||||||
handler := &nbdns.MockHandler{}
|
handler := &nbdns.MockHandler{}
|
||||||
r := new(dns.Msg)
|
r := new(dns.Msg)
|
||||||
r.SetQuestion(tt.queryPattern, dns.TypeA)
|
r.SetQuestion(tt.queryPattern, dns.TypeA)
|
||||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||||
|
|
||||||
// First verify no handler is called before adding any
|
// First verify no handler is called before adding any
|
||||||
chain.ServeDNS(w, r)
|
chain.ServeDNS(w, r)
|
||||||
|
@ -1,130 +0,0 @@
|
|||||||
package dns
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
|
||||||
)
|
|
||||||
|
|
||||||
type registrationMap map[string]struct{}
|
|
||||||
|
|
||||||
type localResolver struct {
|
|
||||||
registeredMap registrationMap
|
|
||||||
records sync.Map // key: string (domain_class_type), value: []dns.RR
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *localResolver) MatchSubdomains() bool {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *localResolver) stop() {
|
|
||||||
}
|
|
||||||
|
|
||||||
// String returns a string representation of the local resolver
|
|
||||||
func (d *localResolver) String() string {
|
|
||||||
return fmt.Sprintf("local resolver [%d records]", len(d.registeredMap))
|
|
||||||
}
|
|
||||||
|
|
||||||
// ID returns the unique handler ID
|
|
||||||
func (d *localResolver) id() handlerID {
|
|
||||||
return "local-resolver"
|
|
||||||
}
|
|
||||||
|
|
||||||
// ServeDNS handles a DNS request
|
|
||||||
func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|
||||||
if len(r.Question) > 0 {
|
|
||||||
log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
|
||||||
}
|
|
||||||
|
|
||||||
replyMessage := &dns.Msg{}
|
|
||||||
replyMessage.SetReply(r)
|
|
||||||
replyMessage.RecursionAvailable = true
|
|
||||||
|
|
||||||
// lookup all records matching the question
|
|
||||||
records := d.lookupRecords(r)
|
|
||||||
if len(records) > 0 {
|
|
||||||
replyMessage.Rcode = dns.RcodeSuccess
|
|
||||||
replyMessage.Answer = append(replyMessage.Answer, records...)
|
|
||||||
} else {
|
|
||||||
replyMessage.Rcode = dns.RcodeNameError
|
|
||||||
}
|
|
||||||
|
|
||||||
err := w.WriteMsg(replyMessage)
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("got an error while writing the local resolver response, error: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// lookupRecords fetches *all* DNS records matching the first question in r.
|
|
||||||
func (d *localResolver) lookupRecords(r *dns.Msg) []dns.RR {
|
|
||||||
if len(r.Question) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
question := r.Question[0]
|
|
||||||
question.Name = strings.ToLower(question.Name)
|
|
||||||
key := buildRecordKey(question.Name, question.Qclass, question.Qtype)
|
|
||||||
|
|
||||||
value, found := d.records.Load(key)
|
|
||||||
if !found {
|
|
||||||
// alternatively check if we have a cname
|
|
||||||
if question.Qtype != dns.TypeCNAME {
|
|
||||||
r.Question[0].Qtype = dns.TypeCNAME
|
|
||||||
return d.lookupRecords(r)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
records, ok := value.([]dns.RR)
|
|
||||||
if !ok {
|
|
||||||
log.Errorf("failed to cast records to []dns.RR, records: %v", value)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// if there's more than one record, rotate them (round-robin)
|
|
||||||
if len(records) > 1 {
|
|
||||||
first := records[0]
|
|
||||||
records = append(records[1:], first)
|
|
||||||
d.records.Store(key, records)
|
|
||||||
}
|
|
||||||
|
|
||||||
return records
|
|
||||||
}
|
|
||||||
|
|
||||||
// registerRecord stores a new record by appending it to any existing list
|
|
||||||
func (d *localResolver) registerRecord(record nbdns.SimpleRecord) (string, error) {
|
|
||||||
rr, err := dns.NewRR(record.String())
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("register record: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
rr.Header().Rdlength = record.Len()
|
|
||||||
header := rr.Header()
|
|
||||||
key := buildRecordKey(header.Name, header.Class, header.Rrtype)
|
|
||||||
|
|
||||||
// load any existing slice of records, then append
|
|
||||||
existing, _ := d.records.LoadOrStore(key, []dns.RR{})
|
|
||||||
records := existing.([]dns.RR)
|
|
||||||
records = append(records, rr)
|
|
||||||
|
|
||||||
// store updated slice
|
|
||||||
d.records.Store(key, records)
|
|
||||||
return key, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// deleteRecord removes *all* records under the recordKey.
|
|
||||||
func (d *localResolver) deleteRecord(recordKey string) {
|
|
||||||
d.records.Delete(dns.Fqdn(recordKey))
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildRecordKey consistently generates a key: name_class_type
|
|
||||||
func buildRecordKey(name string, class, qType uint16) string {
|
|
||||||
return fmt.Sprintf("%s_%d_%d", dns.Fqdn(name), class, qType)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *localResolver) probeAvailability() {}
|
|
149
client/internal/dns/local/local.go
Normal file
149
client/internal/dns/local/local.go
Normal file
@ -0,0 +1,149 @@
|
|||||||
|
package local
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns/types"
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Resolver struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
records map[dns.Question][]dns.RR
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewResolver() *Resolver {
|
||||||
|
return &Resolver{
|
||||||
|
records: make(map[dns.Question][]dns.RR),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Resolver) MatchSubdomains() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns a string representation of the local resolver
|
||||||
|
func (d *Resolver) String() string {
|
||||||
|
return fmt.Sprintf("local resolver [%d records]", len(d.records))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Resolver) Stop() {}
|
||||||
|
|
||||||
|
// ID returns the unique handler ID
|
||||||
|
func (d *Resolver) ID() types.HandlerID {
|
||||||
|
return "local-resolver"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Resolver) ProbeAvailability() {}
|
||||||
|
|
||||||
|
// ServeDNS handles a DNS request
|
||||||
|
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
|
if len(r.Question) == 0 {
|
||||||
|
log.Debugf("received local resolver request with no question")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
question := r.Question[0]
|
||||||
|
question.Name = strings.ToLower(dns.Fqdn(question.Name))
|
||||||
|
|
||||||
|
log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, question.Qtype, question.Qclass)
|
||||||
|
|
||||||
|
replyMessage := &dns.Msg{}
|
||||||
|
replyMessage.SetReply(r)
|
||||||
|
replyMessage.RecursionAvailable = true
|
||||||
|
|
||||||
|
// lookup all records matching the question
|
||||||
|
records := d.lookupRecords(question)
|
||||||
|
if len(records) > 0 {
|
||||||
|
replyMessage.Rcode = dns.RcodeSuccess
|
||||||
|
replyMessage.Answer = append(replyMessage.Answer, records...)
|
||||||
|
} else {
|
||||||
|
// TODO: return success if we have a different record type for the same name, relevant for search domains
|
||||||
|
replyMessage.Rcode = dns.RcodeNameError
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := w.WriteMsg(replyMessage); err != nil {
|
||||||
|
log.Warnf("failed to write the local resolver response: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// lookupRecords fetches *all* DNS records matching the first question in r.
|
||||||
|
func (d *Resolver) lookupRecords(question dns.Question) []dns.RR {
|
||||||
|
d.mu.RLock()
|
||||||
|
records, found := d.records[question]
|
||||||
|
|
||||||
|
if !found {
|
||||||
|
d.mu.RUnlock()
|
||||||
|
// alternatively check if we have a cname
|
||||||
|
if question.Qtype != dns.TypeCNAME {
|
||||||
|
question.Qtype = dns.TypeCNAME
|
||||||
|
return d.lookupRecords(question)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
recordsCopy := slices.Clone(records)
|
||||||
|
d.mu.RUnlock()
|
||||||
|
|
||||||
|
// if there's more than one record, rotate them (round-robin)
|
||||||
|
if len(recordsCopy) > 1 {
|
||||||
|
d.mu.Lock()
|
||||||
|
records = d.records[question]
|
||||||
|
if len(records) > 1 {
|
||||||
|
first := records[0]
|
||||||
|
records = append(records[1:], first)
|
||||||
|
d.records[question] = records
|
||||||
|
}
|
||||||
|
d.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
return recordsCopy
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Resolver) Update(update []nbdns.SimpleRecord) {
|
||||||
|
d.mu.Lock()
|
||||||
|
defer d.mu.Unlock()
|
||||||
|
|
||||||
|
maps.Clear(d.records)
|
||||||
|
|
||||||
|
for _, rec := range update {
|
||||||
|
if err := d.registerRecord(rec); err != nil {
|
||||||
|
log.Warnf("failed to register the record (%s): %v", rec, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterRecord stores a new record by appending it to any existing list
|
||||||
|
func (d *Resolver) RegisterRecord(record nbdns.SimpleRecord) error {
|
||||||
|
d.mu.Lock()
|
||||||
|
defer d.mu.Unlock()
|
||||||
|
|
||||||
|
return d.registerRecord(record)
|
||||||
|
}
|
||||||
|
|
||||||
|
// registerRecord performs the registration with the lock already held
|
||||||
|
func (d *Resolver) registerRecord(record nbdns.SimpleRecord) error {
|
||||||
|
rr, err := dns.NewRR(record.String())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("register record: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rr.Header().Rdlength = record.Len()
|
||||||
|
header := rr.Header()
|
||||||
|
q := dns.Question{
|
||||||
|
Name: strings.ToLower(dns.Fqdn(header.Name)),
|
||||||
|
Qtype: header.Rrtype,
|
||||||
|
Qclass: header.Class,
|
||||||
|
}
|
||||||
|
|
||||||
|
d.records[q] = append(d.records[q], rr)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
472
client/internal/dns/local/local_test.go
Normal file
472
client/internal/dns/local/local_test.go
Normal file
@ -0,0 +1,472 @@
|
|||||||
|
package local
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLocalResolver_ServeDNS(t *testing.T) {
|
||||||
|
recordA := nbdns.SimpleRecord{
|
||||||
|
Name: "peera.netbird.cloud.",
|
||||||
|
Type: 1,
|
||||||
|
Class: nbdns.DefaultClass,
|
||||||
|
TTL: 300,
|
||||||
|
RData: "1.2.3.4",
|
||||||
|
}
|
||||||
|
|
||||||
|
recordCNAME := nbdns.SimpleRecord{
|
||||||
|
Name: "peerb.netbird.cloud.",
|
||||||
|
Type: 5,
|
||||||
|
Class: nbdns.DefaultClass,
|
||||||
|
TTL: 300,
|
||||||
|
RData: "www.netbird.io",
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
inputRecord nbdns.SimpleRecord
|
||||||
|
inputMSG *dns.Msg
|
||||||
|
responseShouldBeNil bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Should Resolve A Record",
|
||||||
|
inputRecord: recordA,
|
||||||
|
inputMSG: new(dns.Msg).SetQuestion(recordA.Name, dns.TypeA),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Should Resolve CNAME Record",
|
||||||
|
inputRecord: recordCNAME,
|
||||||
|
inputMSG: new(dns.Msg).SetQuestion(recordCNAME.Name, dns.TypeCNAME),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Should Not Write When Not Found A Record",
|
||||||
|
inputRecord: recordA,
|
||||||
|
inputMSG: new(dns.Msg).SetQuestion("not.found.com", dns.TypeA),
|
||||||
|
responseShouldBeNil: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range testCases {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
resolver := NewResolver()
|
||||||
|
_ = resolver.RegisterRecord(testCase.inputRecord)
|
||||||
|
var responseMSG *dns.Msg
|
||||||
|
responseWriter := &test.MockResponseWriter{
|
||||||
|
WriteMsgFunc: func(m *dns.Msg) error {
|
||||||
|
responseMSG = m
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resolver.ServeDNS(responseWriter, testCase.inputMSG)
|
||||||
|
|
||||||
|
if responseMSG == nil || len(responseMSG.Answer) == 0 {
|
||||||
|
if testCase.responseShouldBeNil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.Fatalf("should write a response message")
|
||||||
|
}
|
||||||
|
|
||||||
|
answerString := responseMSG.Answer[0].String()
|
||||||
|
if !strings.Contains(answerString, testCase.inputRecord.Name) {
|
||||||
|
t.Fatalf("answer doesn't contain the same domain name: \nWant: %s\nGot:%s", testCase.name, answerString)
|
||||||
|
}
|
||||||
|
if !strings.Contains(answerString, dns.Type(testCase.inputRecord.Type).String()) {
|
||||||
|
t.Fatalf("answer doesn't contain the correct type: \nWant: %s\nGot:%s", dns.Type(testCase.inputRecord.Type).String(), answerString)
|
||||||
|
}
|
||||||
|
if !strings.Contains(answerString, testCase.inputRecord.RData) {
|
||||||
|
t.Fatalf("answer doesn't contain the same address: \nWant: %s\nGot:%s", testCase.inputRecord.RData, answerString)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLocalResolver_Update_StaleRecord verifies that updating
|
||||||
|
// a record correctly replaces the old one, preventing stale entries.
|
||||||
|
func TestLocalResolver_Update_StaleRecord(t *testing.T) {
|
||||||
|
recordName := "host.example.com."
|
||||||
|
recordType := dns.TypeA
|
||||||
|
recordClass := dns.ClassINET
|
||||||
|
|
||||||
|
record1 := nbdns.SimpleRecord{
|
||||||
|
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "1.1.1.1",
|
||||||
|
}
|
||||||
|
record2 := nbdns.SimpleRecord{
|
||||||
|
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "2.2.2.2",
|
||||||
|
}
|
||||||
|
|
||||||
|
recordKey := dns.Question{Name: recordName, Qtype: uint16(recordClass), Qclass: recordType}
|
||||||
|
|
||||||
|
resolver := NewResolver()
|
||||||
|
|
||||||
|
update1 := []nbdns.SimpleRecord{record1}
|
||||||
|
update2 := []nbdns.SimpleRecord{record2}
|
||||||
|
|
||||||
|
// Apply first update
|
||||||
|
resolver.Update(update1)
|
||||||
|
|
||||||
|
// Verify first update
|
||||||
|
resolver.mu.RLock()
|
||||||
|
rrSlice1, found1 := resolver.records[recordKey]
|
||||||
|
resolver.mu.RUnlock()
|
||||||
|
|
||||||
|
require.True(t, found1, "Record key %s not found after first update", recordKey)
|
||||||
|
require.Len(t, rrSlice1, 1, "Should have exactly 1 record after first update")
|
||||||
|
assert.Contains(t, rrSlice1[0].String(), record1.RData, "Record after first update should be %s", record1.RData)
|
||||||
|
|
||||||
|
// Apply second update
|
||||||
|
resolver.Update(update2)
|
||||||
|
|
||||||
|
// Verify second update
|
||||||
|
resolver.mu.RLock()
|
||||||
|
rrSlice2, found2 := resolver.records[recordKey]
|
||||||
|
resolver.mu.RUnlock()
|
||||||
|
|
||||||
|
require.True(t, found2, "Record key %s not found after second update", recordKey)
|
||||||
|
require.Len(t, rrSlice2, 1, "Should have exactly 1 record after update overwriting the key")
|
||||||
|
assert.Contains(t, rrSlice2[0].String(), record2.RData, "The single record should be the updated one (%s)", record2.RData)
|
||||||
|
assert.NotContains(t, rrSlice2[0].String(), record1.RData, "The stale record (%s) should not be present", record1.RData)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLocalResolver_MultipleRecords_SameQuestion verifies that multiple records
|
||||||
|
// with the same question are stored properly
|
||||||
|
func TestLocalResolver_MultipleRecords_SameQuestion(t *testing.T) {
|
||||||
|
resolver := NewResolver()
|
||||||
|
|
||||||
|
recordName := "multi.example.com."
|
||||||
|
recordType := dns.TypeA
|
||||||
|
|
||||||
|
// Create two records with the same name and type but different IPs
|
||||||
|
record1 := nbdns.SimpleRecord{
|
||||||
|
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1",
|
||||||
|
}
|
||||||
|
record2 := nbdns.SimpleRecord{
|
||||||
|
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2",
|
||||||
|
}
|
||||||
|
|
||||||
|
update := []nbdns.SimpleRecord{record1, record2}
|
||||||
|
|
||||||
|
// Apply update with both records
|
||||||
|
resolver.Update(update)
|
||||||
|
|
||||||
|
// Create question that matches both records
|
||||||
|
question := dns.Question{
|
||||||
|
Name: recordName,
|
||||||
|
Qtype: recordType,
|
||||||
|
Qclass: dns.ClassINET,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify both records are stored
|
||||||
|
resolver.mu.RLock()
|
||||||
|
records, found := resolver.records[question]
|
||||||
|
resolver.mu.RUnlock()
|
||||||
|
|
||||||
|
require.True(t, found, "Records for question %v not found", question)
|
||||||
|
require.Len(t, records, 2, "Should have exactly 2 records for the same question")
|
||||||
|
|
||||||
|
// Verify both record data values are present
|
||||||
|
recordStrings := []string{records[0].String(), records[1].String()}
|
||||||
|
assert.Contains(t, recordStrings[0]+recordStrings[1], record1.RData, "First record data should be present")
|
||||||
|
assert.Contains(t, recordStrings[0]+recordStrings[1], record2.RData, "Second record data should be present")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLocalResolver_RecordRotation verifies that records are rotated in a round-robin fashion
|
||||||
|
func TestLocalResolver_RecordRotation(t *testing.T) {
|
||||||
|
resolver := NewResolver()
|
||||||
|
|
||||||
|
recordName := "rotation.example.com."
|
||||||
|
recordType := dns.TypeA
|
||||||
|
|
||||||
|
// Create three records with the same name and type but different IPs
|
||||||
|
record1 := nbdns.SimpleRecord{
|
||||||
|
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.1",
|
||||||
|
}
|
||||||
|
record2 := nbdns.SimpleRecord{
|
||||||
|
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.2",
|
||||||
|
}
|
||||||
|
record3 := nbdns.SimpleRecord{
|
||||||
|
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.3",
|
||||||
|
}
|
||||||
|
|
||||||
|
update := []nbdns.SimpleRecord{record1, record2, record3}
|
||||||
|
|
||||||
|
// Apply update with all three records
|
||||||
|
resolver.Update(update)
|
||||||
|
|
||||||
|
msg := new(dns.Msg).SetQuestion(recordName, recordType)
|
||||||
|
|
||||||
|
// First lookup - should return the records in original order
|
||||||
|
var responses [3]*dns.Msg
|
||||||
|
|
||||||
|
// Perform three lookups to verify rotation
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
responseWriter := &test.MockResponseWriter{
|
||||||
|
WriteMsgFunc: func(m *dns.Msg) error {
|
||||||
|
responses[i] = m
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resolver.ServeDNS(responseWriter, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify all three responses contain answers
|
||||||
|
for i, resp := range responses {
|
||||||
|
require.NotNil(t, resp, "Response %d should not be nil", i)
|
||||||
|
require.Len(t, resp.Answer, 3, "Response %d should have 3 answers", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the first record in each response is different due to rotation
|
||||||
|
firstRecordIPs := []string{
|
||||||
|
responses[0].Answer[0].String(),
|
||||||
|
responses[1].Answer[0].String(),
|
||||||
|
responses[2].Answer[0].String(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Each record should be different (rotated)
|
||||||
|
assert.NotEqual(t, firstRecordIPs[0], firstRecordIPs[1], "First lookup should differ from second lookup due to rotation")
|
||||||
|
assert.NotEqual(t, firstRecordIPs[1], firstRecordIPs[2], "Second lookup should differ from third lookup due to rotation")
|
||||||
|
assert.NotEqual(t, firstRecordIPs[0], firstRecordIPs[2], "First lookup should differ from third lookup due to rotation")
|
||||||
|
|
||||||
|
// After three rotations, we should have cycled through all records
|
||||||
|
assert.Contains(t, firstRecordIPs[0]+firstRecordIPs[1]+firstRecordIPs[2], record1.RData)
|
||||||
|
assert.Contains(t, firstRecordIPs[0]+firstRecordIPs[1]+firstRecordIPs[2], record2.RData)
|
||||||
|
assert.Contains(t, firstRecordIPs[0]+firstRecordIPs[1]+firstRecordIPs[2], record3.RData)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLocalResolver_CaseInsensitiveMatching verifies that DNS record lookups are case-insensitive
|
||||||
|
func TestLocalResolver_CaseInsensitiveMatching(t *testing.T) {
|
||||||
|
resolver := NewResolver()
|
||||||
|
|
||||||
|
// Create record with lowercase name
|
||||||
|
lowerCaseRecord := nbdns.SimpleRecord{
|
||||||
|
Name: "lower.example.com.",
|
||||||
|
Type: int(dns.TypeA),
|
||||||
|
Class: nbdns.DefaultClass,
|
||||||
|
TTL: 300,
|
||||||
|
RData: "10.10.10.10",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create record with mixed case name
|
||||||
|
mixedCaseRecord := nbdns.SimpleRecord{
|
||||||
|
Name: "MiXeD.ExAmPlE.CoM.",
|
||||||
|
Type: int(dns.TypeA),
|
||||||
|
Class: nbdns.DefaultClass,
|
||||||
|
TTL: 300,
|
||||||
|
RData: "20.20.20.20",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update resolver with the records
|
||||||
|
resolver.Update([]nbdns.SimpleRecord{lowerCaseRecord, mixedCaseRecord})
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
queryName string
|
||||||
|
expectedRData string
|
||||||
|
shouldResolve bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Query lowercase with lowercase record",
|
||||||
|
queryName: "lower.example.com.",
|
||||||
|
expectedRData: "10.10.10.10",
|
||||||
|
shouldResolve: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Query uppercase with lowercase record",
|
||||||
|
queryName: "LOWER.EXAMPLE.COM.",
|
||||||
|
expectedRData: "10.10.10.10",
|
||||||
|
shouldResolve: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Query mixed case with lowercase record",
|
||||||
|
queryName: "LoWeR.eXaMpLe.CoM.",
|
||||||
|
expectedRData: "10.10.10.10",
|
||||||
|
shouldResolve: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Query lowercase with mixed case record",
|
||||||
|
queryName: "mixed.example.com.",
|
||||||
|
expectedRData: "20.20.20.20",
|
||||||
|
shouldResolve: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Query uppercase with mixed case record",
|
||||||
|
queryName: "MIXED.EXAMPLE.COM.",
|
||||||
|
expectedRData: "20.20.20.20",
|
||||||
|
shouldResolve: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Query with different casing pattern",
|
||||||
|
queryName: "mIxEd.ExaMpLe.cOm.",
|
||||||
|
expectedRData: "20.20.20.20",
|
||||||
|
shouldResolve: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Query non-existent domain",
|
||||||
|
queryName: "nonexistent.example.com.",
|
||||||
|
shouldResolve: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
var responseMSG *dns.Msg
|
||||||
|
|
||||||
|
// Create DNS query with the test case name
|
||||||
|
msg := new(dns.Msg).SetQuestion(tc.queryName, dns.TypeA)
|
||||||
|
|
||||||
|
// Create mock response writer to capture the response
|
||||||
|
responseWriter := &test.MockResponseWriter{
|
||||||
|
WriteMsgFunc: func(m *dns.Msg) error {
|
||||||
|
responseMSG = m
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform DNS query
|
||||||
|
resolver.ServeDNS(responseWriter, msg)
|
||||||
|
|
||||||
|
// Check if we expect a successful resolution
|
||||||
|
if !tc.shouldResolve {
|
||||||
|
if responseMSG == nil || len(responseMSG.Answer) == 0 {
|
||||||
|
// Expected no answer, test passes
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.Fatalf("Expected no resolution for %s, but got answer: %v", tc.queryName, responseMSG.Answer)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify we got a response
|
||||||
|
require.NotNil(t, responseMSG, "Should have received a response message")
|
||||||
|
require.Greater(t, len(responseMSG.Answer), 0, "Response should contain at least one answer")
|
||||||
|
|
||||||
|
// Verify the response contains the expected data
|
||||||
|
answerString := responseMSG.Answer[0].String()
|
||||||
|
assert.Contains(t, answerString, tc.expectedRData,
|
||||||
|
"Answer should contain the expected IP address %s, got: %s",
|
||||||
|
tc.expectedRData, answerString)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLocalResolver_CNAMEFallback verifies that the resolver correctly falls back
|
||||||
|
// to checking for CNAME records when the requested record type isn't found
|
||||||
|
func TestLocalResolver_CNAMEFallback(t *testing.T) {
|
||||||
|
resolver := NewResolver()
|
||||||
|
|
||||||
|
// Create a CNAME record (but no A record for this name)
|
||||||
|
cnameRecord := nbdns.SimpleRecord{
|
||||||
|
Name: "alias.example.com.",
|
||||||
|
Type: int(dns.TypeCNAME),
|
||||||
|
Class: nbdns.DefaultClass,
|
||||||
|
TTL: 300,
|
||||||
|
RData: "target.example.com.",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create an A record for the CNAME target
|
||||||
|
targetRecord := nbdns.SimpleRecord{
|
||||||
|
Name: "target.example.com.",
|
||||||
|
Type: int(dns.TypeA),
|
||||||
|
Class: nbdns.DefaultClass,
|
||||||
|
TTL: 300,
|
||||||
|
RData: "192.168.100.100",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update resolver with both records
|
||||||
|
resolver.Update([]nbdns.SimpleRecord{cnameRecord, targetRecord})
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
queryName string
|
||||||
|
queryType uint16
|
||||||
|
expectedType string
|
||||||
|
expectedRData string
|
||||||
|
shouldResolve bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Directly query CNAME record",
|
||||||
|
queryName: "alias.example.com.",
|
||||||
|
queryType: dns.TypeCNAME,
|
||||||
|
expectedType: "CNAME",
|
||||||
|
expectedRData: "target.example.com.",
|
||||||
|
shouldResolve: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Query A record but get CNAME fallback",
|
||||||
|
queryName: "alias.example.com.",
|
||||||
|
queryType: dns.TypeA,
|
||||||
|
expectedType: "CNAME",
|
||||||
|
expectedRData: "target.example.com.",
|
||||||
|
shouldResolve: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Query AAAA record but get CNAME fallback",
|
||||||
|
queryName: "alias.example.com.",
|
||||||
|
queryType: dns.TypeAAAA,
|
||||||
|
expectedType: "CNAME",
|
||||||
|
expectedRData: "target.example.com.",
|
||||||
|
shouldResolve: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Query direct A record",
|
||||||
|
queryName: "target.example.com.",
|
||||||
|
queryType: dns.TypeA,
|
||||||
|
expectedType: "A",
|
||||||
|
expectedRData: "192.168.100.100",
|
||||||
|
shouldResolve: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Query non-existent name",
|
||||||
|
queryName: "nonexistent.example.com.",
|
||||||
|
queryType: dns.TypeA,
|
||||||
|
shouldResolve: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
var responseMSG *dns.Msg
|
||||||
|
|
||||||
|
// Create DNS query with the test case parameters
|
||||||
|
msg := new(dns.Msg).SetQuestion(tc.queryName, tc.queryType)
|
||||||
|
|
||||||
|
// Create mock response writer to capture the response
|
||||||
|
responseWriter := &test.MockResponseWriter{
|
||||||
|
WriteMsgFunc: func(m *dns.Msg) error {
|
||||||
|
responseMSG = m
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform DNS query
|
||||||
|
resolver.ServeDNS(responseWriter, msg)
|
||||||
|
|
||||||
|
// Check if we expect a successful resolution
|
||||||
|
if !tc.shouldResolve {
|
||||||
|
if responseMSG == nil || len(responseMSG.Answer) == 0 || responseMSG.Rcode != dns.RcodeSuccess {
|
||||||
|
// Expected no resolution, test passes
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.Fatalf("Expected no resolution for %s, but got answer: %v", tc.queryName, responseMSG.Answer)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify we got a successful response
|
||||||
|
require.NotNil(t, responseMSG, "Should have received a response message")
|
||||||
|
require.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "Response should have success status code")
|
||||||
|
require.Greater(t, len(responseMSG.Answer), 0, "Response should contain at least one answer")
|
||||||
|
|
||||||
|
// Verify the response contains the expected data
|
||||||
|
answerString := responseMSG.Answer[0].String()
|
||||||
|
assert.Contains(t, answerString, tc.expectedType,
|
||||||
|
"Answer should be of type %s, got: %s", tc.expectedType, answerString)
|
||||||
|
assert.Contains(t, answerString, tc.expectedRData,
|
||||||
|
"Answer should contain the expected data %s, got: %s", tc.expectedRData, answerString)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
@ -1,88 +0,0 @@
|
|||||||
package dns
|
|
||||||
|
|
||||||
import (
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestLocalResolver_ServeDNS(t *testing.T) {
|
|
||||||
recordA := nbdns.SimpleRecord{
|
|
||||||
Name: "peera.netbird.cloud.",
|
|
||||||
Type: 1,
|
|
||||||
Class: nbdns.DefaultClass,
|
|
||||||
TTL: 300,
|
|
||||||
RData: "1.2.3.4",
|
|
||||||
}
|
|
||||||
|
|
||||||
recordCNAME := nbdns.SimpleRecord{
|
|
||||||
Name: "peerb.netbird.cloud.",
|
|
||||||
Type: 5,
|
|
||||||
Class: nbdns.DefaultClass,
|
|
||||||
TTL: 300,
|
|
||||||
RData: "www.netbird.io",
|
|
||||||
}
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
name string
|
|
||||||
inputRecord nbdns.SimpleRecord
|
|
||||||
inputMSG *dns.Msg
|
|
||||||
responseShouldBeNil bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Should Resolve A Record",
|
|
||||||
inputRecord: recordA,
|
|
||||||
inputMSG: new(dns.Msg).SetQuestion(recordA.Name, dns.TypeA),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Should Resolve CNAME Record",
|
|
||||||
inputRecord: recordCNAME,
|
|
||||||
inputMSG: new(dns.Msg).SetQuestion(recordCNAME.Name, dns.TypeCNAME),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Should Not Write When Not Found A Record",
|
|
||||||
inputRecord: recordA,
|
|
||||||
inputMSG: new(dns.Msg).SetQuestion("not.found.com", dns.TypeA),
|
|
||||||
responseShouldBeNil: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, testCase := range testCases {
|
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
|
||||||
resolver := &localResolver{
|
|
||||||
registeredMap: make(registrationMap),
|
|
||||||
}
|
|
||||||
_, _ = resolver.registerRecord(testCase.inputRecord)
|
|
||||||
var responseMSG *dns.Msg
|
|
||||||
responseWriter := &mockResponseWriter{
|
|
||||||
WriteMsgFunc: func(m *dns.Msg) error {
|
|
||||||
responseMSG = m
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
resolver.ServeDNS(responseWriter, testCase.inputMSG)
|
|
||||||
|
|
||||||
if responseMSG == nil || len(responseMSG.Answer) == 0 {
|
|
||||||
if testCase.responseShouldBeNil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
t.Fatalf("should write a response message")
|
|
||||||
}
|
|
||||||
|
|
||||||
answerString := responseMSG.Answer[0].String()
|
|
||||||
if !strings.Contains(answerString, testCase.inputRecord.Name) {
|
|
||||||
t.Fatalf("answer doesn't contain the same domain name: \nWant: %s\nGot:%s", testCase.name, answerString)
|
|
||||||
}
|
|
||||||
if !strings.Contains(answerString, dns.Type(testCase.inputRecord.Type).String()) {
|
|
||||||
t.Fatalf("answer doesn't contain the correct type: \nWant: %s\nGot:%s", dns.Type(testCase.inputRecord.Type).String(), answerString)
|
|
||||||
}
|
|
||||||
if !strings.Contains(answerString, testCase.inputRecord.RData) {
|
|
||||||
t.Fatalf("answer doesn't contain the same address: \nWant: %s\nGot:%s", testCase.inputRecord.RData, answerString)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,26 +0,0 @@
|
|||||||
package dns
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
)
|
|
||||||
|
|
||||||
type mockResponseWriter struct {
|
|
||||||
WriteMsgFunc func(m *dns.Msg) error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rw *mockResponseWriter) WriteMsg(m *dns.Msg) error {
|
|
||||||
if rw.WriteMsgFunc != nil {
|
|
||||||
return rw.WriteMsgFunc(m)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rw *mockResponseWriter) LocalAddr() net.Addr { return nil }
|
|
||||||
func (rw *mockResponseWriter) RemoteAddr() net.Addr { return nil }
|
|
||||||
func (rw *mockResponseWriter) Write([]byte) (int, error) { return 0, nil }
|
|
||||||
func (rw *mockResponseWriter) Close() error { return nil }
|
|
||||||
func (rw *mockResponseWriter) TsigStatus() error { return nil }
|
|
||||||
func (rw *mockResponseWriter) TsigTimersOnly(bool) {}
|
|
||||||
func (rw *mockResponseWriter) Hijack() {}
|
|
@ -15,6 +15,8 @@ import (
|
|||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns/local"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
@ -46,8 +48,6 @@ type Server interface {
|
|||||||
ProbeAvailability()
|
ProbeAvailability()
|
||||||
}
|
}
|
||||||
|
|
||||||
type handlerID string
|
|
||||||
|
|
||||||
type nsGroupsByDomain struct {
|
type nsGroupsByDomain struct {
|
||||||
domain string
|
domain string
|
||||||
groups []*nbdns.NameServerGroup
|
groups []*nbdns.NameServerGroup
|
||||||
@ -61,7 +61,7 @@ type DefaultServer struct {
|
|||||||
mux sync.Mutex
|
mux sync.Mutex
|
||||||
service service
|
service service
|
||||||
dnsMuxMap registeredHandlerMap
|
dnsMuxMap registeredHandlerMap
|
||||||
localResolver *localResolver
|
localResolver *local.Resolver
|
||||||
wgInterface WGIface
|
wgInterface WGIface
|
||||||
hostManager hostManager
|
hostManager hostManager
|
||||||
updateSerial uint64
|
updateSerial uint64
|
||||||
@ -84,9 +84,9 @@ type DefaultServer struct {
|
|||||||
|
|
||||||
type handlerWithStop interface {
|
type handlerWithStop interface {
|
||||||
dns.Handler
|
dns.Handler
|
||||||
stop()
|
Stop()
|
||||||
probeAvailability()
|
ProbeAvailability()
|
||||||
id() handlerID
|
ID() types.HandlerID
|
||||||
}
|
}
|
||||||
|
|
||||||
type handlerWrapper struct {
|
type handlerWrapper struct {
|
||||||
@ -95,7 +95,7 @@ type handlerWrapper struct {
|
|||||||
priority int
|
priority int
|
||||||
}
|
}
|
||||||
|
|
||||||
type registeredHandlerMap map[handlerID]handlerWrapper
|
type registeredHandlerMap map[types.HandlerID]handlerWrapper
|
||||||
|
|
||||||
// NewDefaultServer returns a new dns server
|
// NewDefaultServer returns a new dns server
|
||||||
func NewDefaultServer(
|
func NewDefaultServer(
|
||||||
@ -171,16 +171,14 @@ func newDefaultServer(
|
|||||||
handlerChain := NewHandlerChain()
|
handlerChain := NewHandlerChain()
|
||||||
ctx, stop := context.WithCancel(ctx)
|
ctx, stop := context.WithCancel(ctx)
|
||||||
defaultServer := &DefaultServer{
|
defaultServer := &DefaultServer{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
ctxCancel: stop,
|
ctxCancel: stop,
|
||||||
disableSys: disableSys,
|
disableSys: disableSys,
|
||||||
service: dnsService,
|
service: dnsService,
|
||||||
handlerChain: handlerChain,
|
handlerChain: handlerChain,
|
||||||
extraDomains: make(map[domain.Domain]int),
|
extraDomains: make(map[domain.Domain]int),
|
||||||
dnsMuxMap: make(registeredHandlerMap),
|
dnsMuxMap: make(registeredHandlerMap),
|
||||||
localResolver: &localResolver{
|
localResolver: local.NewResolver(),
|
||||||
registeredMap: make(registrationMap),
|
|
||||||
},
|
|
||||||
wgInterface: wgInterface,
|
wgInterface: wgInterface,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
stateManager: stateManager,
|
stateManager: stateManager,
|
||||||
@ -403,7 +401,7 @@ func (s *DefaultServer) ProbeAvailability() {
|
|||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(mux handlerWithStop) {
|
go func(mux handlerWithStop) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
mux.probeAvailability()
|
mux.ProbeAvailability()
|
||||||
}(mux.handler)
|
}(mux.handler)
|
||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
@ -420,7 +418,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
|||||||
s.service.Stop()
|
s.service.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
localMuxUpdates, localRecordsByDomain, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("local handler updater: %w", err)
|
return fmt.Errorf("local handler updater: %w", err)
|
||||||
}
|
}
|
||||||
@ -434,7 +432,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
|||||||
s.updateMux(muxUpdates)
|
s.updateMux(muxUpdates)
|
||||||
|
|
||||||
// register local records
|
// register local records
|
||||||
s.updateLocalResolver(localRecordsByDomain)
|
s.localResolver.Update(localRecords)
|
||||||
|
|
||||||
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
|
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
|
||||||
|
|
||||||
@ -516,11 +514,9 @@ func (s *DefaultServer) handleErrNoGroupaAll(err error) {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) buildLocalHandlerUpdate(
|
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.SimpleRecord, error) {
|
||||||
customZones []nbdns.CustomZone,
|
|
||||||
) ([]handlerWrapper, map[string][]nbdns.SimpleRecord, error) {
|
|
||||||
var muxUpdates []handlerWrapper
|
var muxUpdates []handlerWrapper
|
||||||
localRecords := make(map[string][]nbdns.SimpleRecord)
|
var localRecords []nbdns.SimpleRecord
|
||||||
|
|
||||||
for _, customZone := range customZones {
|
for _, customZone := range customZones {
|
||||||
if len(customZone.Records) == 0 {
|
if len(customZone.Records) == 0 {
|
||||||
@ -534,17 +530,13 @@ func (s *DefaultServer) buildLocalHandlerUpdate(
|
|||||||
priority: PriorityMatchDomain,
|
priority: PriorityMatchDomain,
|
||||||
})
|
})
|
||||||
|
|
||||||
// group all records under this domain
|
|
||||||
for _, record := range customZone.Records {
|
for _, record := range customZone.Records {
|
||||||
var class uint16 = dns.ClassINET
|
|
||||||
if record.Class != nbdns.DefaultClass {
|
if record.Class != nbdns.DefaultClass {
|
||||||
log.Warnf("received an invalid class type: %s", record.Class)
|
log.Warnf("received an invalid class type: %s", record.Class)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
// zone records contain the fqdn, so we can just flatten them
|
||||||
key := buildRecordKey(record.Name, class, uint16(record.Type))
|
localRecords = append(localRecords, record)
|
||||||
|
|
||||||
localRecords[key] = append(localRecords[key], record)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -627,7 +619,7 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(handler.upstreamServers) == 0 {
|
if len(handler.upstreamServers) == 0 {
|
||||||
handler.stop()
|
handler.Stop()
|
||||||
log.Errorf("received a nameserver group with an invalid nameserver list")
|
log.Errorf("received a nameserver group with an invalid nameserver list")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -656,7 +648,7 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
|
|||||||
// this will introduce a short period of time when the server is not able to handle DNS requests
|
// this will introduce a short period of time when the server is not able to handle DNS requests
|
||||||
for _, existing := range s.dnsMuxMap {
|
for _, existing := range s.dnsMuxMap {
|
||||||
s.deregisterHandler([]string{existing.domain}, existing.priority)
|
s.deregisterHandler([]string{existing.domain}, existing.priority)
|
||||||
existing.handler.stop()
|
existing.handler.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
muxUpdateMap := make(registeredHandlerMap)
|
muxUpdateMap := make(registeredHandlerMap)
|
||||||
@ -667,7 +659,7 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
|
|||||||
containsRootUpdate = true
|
containsRootUpdate = true
|
||||||
}
|
}
|
||||||
s.registerHandler([]string{update.domain}, update.handler, update.priority)
|
s.registerHandler([]string{update.domain}, update.handler, update.priority)
|
||||||
muxUpdateMap[update.handler.id()] = update
|
muxUpdateMap[update.handler.ID()] = update
|
||||||
}
|
}
|
||||||
|
|
||||||
// If there's no root update and we had a root handler, restore it
|
// If there's no root update and we had a root handler, restore it
|
||||||
@ -683,33 +675,6 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
|
|||||||
s.dnsMuxMap = muxUpdateMap
|
s.dnsMuxMap = muxUpdateMap
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) updateLocalResolver(update map[string][]nbdns.SimpleRecord) {
|
|
||||||
// remove old records that are no longer present
|
|
||||||
for key := range s.localResolver.registeredMap {
|
|
||||||
_, found := update[key]
|
|
||||||
if !found {
|
|
||||||
s.localResolver.deleteRecord(key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
updatedMap := make(registrationMap)
|
|
||||||
for _, recs := range update {
|
|
||||||
for _, rec := range recs {
|
|
||||||
// convert the record to a dns.RR and register
|
|
||||||
key, err := s.localResolver.registerRecord(rec)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("got an error while registering the record (%s), error: %v",
|
|
||||||
rec.String(), err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
updatedMap[key] = struct{}{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
s.localResolver.registeredMap = updatedMap
|
|
||||||
}
|
|
||||||
|
|
||||||
func getNSHostPort(ns nbdns.NameServer) string {
|
func getNSHostPort(ns nbdns.NameServer) string {
|
||||||
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
|
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
|
||||||
}
|
}
|
||||||
|
@ -23,6 +23,9 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns/local"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
@ -107,6 +110,7 @@ func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamRe
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUpdateDNSServer(t *testing.T) {
|
func TestUpdateDNSServer(t *testing.T) {
|
||||||
|
|
||||||
nameServers := []nbdns.NameServer{
|
nameServers := []nbdns.NameServer{
|
||||||
{
|
{
|
||||||
IP: netip.MustParseAddr("8.8.8.8"),
|
IP: netip.MustParseAddr("8.8.8.8"),
|
||||||
@ -120,22 +124,21 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
dummyHandler := &localResolver{}
|
dummyHandler := local.NewResolver()
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
initUpstreamMap registeredHandlerMap
|
initUpstreamMap registeredHandlerMap
|
||||||
initLocalMap registrationMap
|
initLocalRecords []nbdns.SimpleRecord
|
||||||
initSerial uint64
|
initSerial uint64
|
||||||
inputSerial uint64
|
inputSerial uint64
|
||||||
inputUpdate nbdns.Config
|
inputUpdate nbdns.Config
|
||||||
shouldFail bool
|
shouldFail bool
|
||||||
expectedUpstreamMap registeredHandlerMap
|
expectedUpstreamMap registeredHandlerMap
|
||||||
expectedLocalMap registrationMap
|
expectedLocalQs []dns.Question
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Initial Config Should Succeed",
|
name: "Initial Config Should Succeed",
|
||||||
initLocalMap: make(registrationMap),
|
|
||||||
initUpstreamMap: make(registeredHandlerMap),
|
initUpstreamMap: make(registeredHandlerMap),
|
||||||
initSerial: 0,
|
initSerial: 0,
|
||||||
inputSerial: 1,
|
inputSerial: 1,
|
||||||
@ -159,30 +162,30 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedUpstreamMap: registeredHandlerMap{
|
expectedUpstreamMap: registeredHandlerMap{
|
||||||
generateDummyHandler("netbird.io", nameServers).id(): handlerWrapper{
|
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
|
||||||
domain: "netbird.io",
|
domain: "netbird.io",
|
||||||
handler: dummyHandler,
|
handler: dummyHandler,
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityMatchDomain,
|
||||||
},
|
},
|
||||||
dummyHandler.id(): handlerWrapper{
|
dummyHandler.ID(): handlerWrapper{
|
||||||
domain: "netbird.cloud",
|
domain: "netbird.cloud",
|
||||||
handler: dummyHandler,
|
handler: dummyHandler,
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityMatchDomain,
|
||||||
},
|
},
|
||||||
generateDummyHandler(".", nameServers).id(): handlerWrapper{
|
generateDummyHandler(".", nameServers).ID(): handlerWrapper{
|
||||||
domain: nbdns.RootZone,
|
domain: nbdns.RootZone,
|
||||||
handler: dummyHandler,
|
handler: dummyHandler,
|
||||||
priority: PriorityDefault,
|
priority: PriorityDefault,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
|
expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "New Config Should Succeed",
|
name: "New Config Should Succeed",
|
||||||
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
|
||||||
initUpstreamMap: registeredHandlerMap{
|
initUpstreamMap: registeredHandlerMap{
|
||||||
generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{
|
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||||
domain: buildRecordKey(zoneRecords[0].Name, 1, 1),
|
domain: "netbird.cloud",
|
||||||
handler: dummyHandler,
|
handler: dummyHandler,
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityMatchDomain,
|
||||||
},
|
},
|
||||||
@ -205,7 +208,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedUpstreamMap: registeredHandlerMap{
|
expectedUpstreamMap: registeredHandlerMap{
|
||||||
generateDummyHandler("netbird.io", nameServers).id(): handlerWrapper{
|
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
|
||||||
domain: "netbird.io",
|
domain: "netbird.io",
|
||||||
handler: dummyHandler,
|
handler: dummyHandler,
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityMatchDomain,
|
||||||
@ -216,22 +219,22 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
priority: PriorityMatchDomain,
|
priority: PriorityMatchDomain,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}},
|
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Smaller Config Serial Should Be Skipped",
|
name: "Smaller Config Serial Should Be Skipped",
|
||||||
initLocalMap: make(registrationMap),
|
initLocalRecords: []nbdns.SimpleRecord{},
|
||||||
initUpstreamMap: make(registeredHandlerMap),
|
initUpstreamMap: make(registeredHandlerMap),
|
||||||
initSerial: 2,
|
initSerial: 2,
|
||||||
inputSerial: 1,
|
inputSerial: 1,
|
||||||
shouldFail: true,
|
shouldFail: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
|
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
|
||||||
initLocalMap: make(registrationMap),
|
initLocalRecords: []nbdns.SimpleRecord{},
|
||||||
initUpstreamMap: make(registeredHandlerMap),
|
initUpstreamMap: make(registeredHandlerMap),
|
||||||
initSerial: 0,
|
initSerial: 0,
|
||||||
inputSerial: 1,
|
inputSerial: 1,
|
||||||
inputUpdate: nbdns.Config{
|
inputUpdate: nbdns.Config{
|
||||||
ServiceEnable: true,
|
ServiceEnable: true,
|
||||||
CustomZones: []nbdns.CustomZone{
|
CustomZones: []nbdns.CustomZone{
|
||||||
@ -249,11 +252,11 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
shouldFail: true,
|
shouldFail: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Invalid NS Group Nameservers list Should Fail",
|
name: "Invalid NS Group Nameservers list Should Fail",
|
||||||
initLocalMap: make(registrationMap),
|
initLocalRecords: []nbdns.SimpleRecord{},
|
||||||
initUpstreamMap: make(registeredHandlerMap),
|
initUpstreamMap: make(registeredHandlerMap),
|
||||||
initSerial: 0,
|
initSerial: 0,
|
||||||
inputSerial: 1,
|
inputSerial: 1,
|
||||||
inputUpdate: nbdns.Config{
|
inputUpdate: nbdns.Config{
|
||||||
ServiceEnable: true,
|
ServiceEnable: true,
|
||||||
CustomZones: []nbdns.CustomZone{
|
CustomZones: []nbdns.CustomZone{
|
||||||
@ -271,11 +274,11 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
shouldFail: true,
|
shouldFail: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Invalid Custom Zone Records list Should Skip",
|
name: "Invalid Custom Zone Records list Should Skip",
|
||||||
initLocalMap: make(registrationMap),
|
initLocalRecords: []nbdns.SimpleRecord{},
|
||||||
initUpstreamMap: make(registeredHandlerMap),
|
initUpstreamMap: make(registeredHandlerMap),
|
||||||
initSerial: 0,
|
initSerial: 0,
|
||||||
inputSerial: 1,
|
inputSerial: 1,
|
||||||
inputUpdate: nbdns.Config{
|
inputUpdate: nbdns.Config{
|
||||||
ServiceEnable: true,
|
ServiceEnable: true,
|
||||||
CustomZones: []nbdns.CustomZone{
|
CustomZones: []nbdns.CustomZone{
|
||||||
@ -290,17 +293,17 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).id(): handlerWrapper{
|
expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).ID(): handlerWrapper{
|
||||||
domain: ".",
|
domain: ".",
|
||||||
handler: dummyHandler,
|
handler: dummyHandler,
|
||||||
priority: PriorityDefault,
|
priority: PriorityDefault,
|
||||||
}},
|
}},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Empty Config Should Succeed and Clean Maps",
|
name: "Empty Config Should Succeed and Clean Maps",
|
||||||
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
|
||||||
initUpstreamMap: registeredHandlerMap{
|
initUpstreamMap: registeredHandlerMap{
|
||||||
generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{
|
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||||
domain: zoneRecords[0].Name,
|
domain: zoneRecords[0].Name,
|
||||||
handler: dummyHandler,
|
handler: dummyHandler,
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityMatchDomain,
|
||||||
@ -310,13 +313,13 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
inputSerial: 1,
|
inputSerial: 1,
|
||||||
inputUpdate: nbdns.Config{ServiceEnable: true},
|
inputUpdate: nbdns.Config{ServiceEnable: true},
|
||||||
expectedUpstreamMap: make(registeredHandlerMap),
|
expectedUpstreamMap: make(registeredHandlerMap),
|
||||||
expectedLocalMap: make(registrationMap),
|
expectedLocalQs: []dns.Question{},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Disabled Service Should clean map",
|
name: "Disabled Service Should clean map",
|
||||||
initLocalMap: registrationMap{"netbird.cloud": struct{}{}},
|
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
|
||||||
initUpstreamMap: registeredHandlerMap{
|
initUpstreamMap: registeredHandlerMap{
|
||||||
generateDummyHandler(zoneRecords[0].Name, nameServers).id(): handlerWrapper{
|
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||||
domain: zoneRecords[0].Name,
|
domain: zoneRecords[0].Name,
|
||||||
handler: dummyHandler,
|
handler: dummyHandler,
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityMatchDomain,
|
||||||
@ -326,7 +329,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
inputSerial: 1,
|
inputSerial: 1,
|
||||||
inputUpdate: nbdns.Config{ServiceEnable: false},
|
inputUpdate: nbdns.Config{ServiceEnable: false},
|
||||||
expectedUpstreamMap: make(registeredHandlerMap),
|
expectedUpstreamMap: make(registeredHandlerMap),
|
||||||
expectedLocalMap: make(registrationMap),
|
expectedLocalQs: []dns.Question{},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -377,7 +380,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
dnsServer.dnsMuxMap = testCase.initUpstreamMap
|
dnsServer.dnsMuxMap = testCase.initUpstreamMap
|
||||||
dnsServer.localResolver.registeredMap = testCase.initLocalMap
|
dnsServer.localResolver.Update(testCase.initLocalRecords)
|
||||||
dnsServer.updateSerial = testCase.initSerial
|
dnsServer.updateSerial = testCase.initSerial
|
||||||
|
|
||||||
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
|
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
|
||||||
@ -399,15 +402,23 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(dnsServer.localResolver.registeredMap) != len(testCase.expectedLocalMap) {
|
var responseMSG *dns.Msg
|
||||||
t.Fatalf("update local failed, registered map size is different than expected, want %d, got %d", len(testCase.expectedLocalMap), len(dnsServer.localResolver.registeredMap))
|
responseWriter := &test.MockResponseWriter{
|
||||||
|
WriteMsgFunc: func(m *dns.Msg) error {
|
||||||
|
responseMSG = m
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, q := range testCase.expectedLocalQs {
|
||||||
|
dnsServer.localResolver.ServeDNS(responseWriter, &dns.Msg{
|
||||||
|
Question: []dns.Question{q},
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
for key := range testCase.expectedLocalMap {
|
if len(testCase.expectedLocalQs) > 0 {
|
||||||
_, found := dnsServer.localResolver.registeredMap[key]
|
assert.NotNil(t, responseMSG, "response message should not be nil")
|
||||||
if !found {
|
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "response code should be success")
|
||||||
t.Fatalf("update local failed, key %s was not found in the localResolver.registeredMap: %#v", key, dnsServer.localResolver.registeredMap)
|
assert.NotEmpty(t, responseMSG.Answer, "response message should have answers")
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -491,11 +502,12 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
|||||||
dnsServer.dnsMuxMap = registeredHandlerMap{
|
dnsServer.dnsMuxMap = registeredHandlerMap{
|
||||||
"id1": handlerWrapper{
|
"id1": handlerWrapper{
|
||||||
domain: zoneRecords[0].Name,
|
domain: zoneRecords[0].Name,
|
||||||
handler: &localResolver{},
|
handler: &local.Resolver{},
|
||||||
priority: PriorityMatchDomain,
|
priority: PriorityMatchDomain,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
dnsServer.localResolver.registeredMap = registrationMap{"netbird.cloud": struct{}{}}
|
//dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}}
|
||||||
|
dnsServer.localResolver.Update([]nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}})
|
||||||
dnsServer.updateSerial = 0
|
dnsServer.updateSerial = 0
|
||||||
|
|
||||||
nameServers := []nbdns.NameServer{
|
nameServers := []nbdns.NameServer{
|
||||||
@ -582,7 +594,7 @@ func TestDNSServerStartStop(t *testing.T) {
|
|||||||
}
|
}
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
defer dnsServer.Stop()
|
defer dnsServer.Stop()
|
||||||
_, err = dnsServer.localResolver.registerRecord(zoneRecords[0])
|
err = dnsServer.localResolver.RegisterRecord(zoneRecords[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
@ -630,13 +642,11 @@ func TestDNSServerStartStop(t *testing.T) {
|
|||||||
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
||||||
hostManager := &mockHostConfigurator{}
|
hostManager := &mockHostConfigurator{}
|
||||||
server := DefaultServer{
|
server := DefaultServer{
|
||||||
ctx: context.Background(),
|
ctx: context.Background(),
|
||||||
service: NewServiceViaMemory(&mocWGIface{}),
|
service: NewServiceViaMemory(&mocWGIface{}),
|
||||||
localResolver: &localResolver{
|
localResolver: local.NewResolver(),
|
||||||
registeredMap: make(registrationMap),
|
handlerChain: NewHandlerChain(),
|
||||||
},
|
hostManager: hostManager,
|
||||||
handlerChain: NewHandlerChain(),
|
|
||||||
hostManager: hostManager,
|
|
||||||
currentConfig: HostDNSConfig{
|
currentConfig: HostDNSConfig{
|
||||||
Domains: []DomainConfig{
|
Domains: []DomainConfig{
|
||||||
{false, "domain0", false},
|
{false, "domain0", false},
|
||||||
@ -1004,7 +1014,7 @@ func TestHandlerChain_DomainPriorities(t *testing.T) {
|
|||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
r := new(dns.Msg)
|
r := new(dns.Msg)
|
||||||
r.SetQuestion(tc.query, dns.TypeA)
|
r.SetQuestion(tc.query, dns.TypeA)
|
||||||
w := &ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w := &ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||||
|
|
||||||
if mh, ok := tc.expectedHandler.(*MockHandler); ok {
|
if mh, ok := tc.expectedHandler.(*MockHandler); ok {
|
||||||
mh.On("ServeDNS", mock.Anything, r).Once()
|
mh.On("ServeDNS", mock.Anything, r).Once()
|
||||||
@ -1037,9 +1047,9 @@ type mockHandler struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
|
func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
|
||||||
func (m *mockHandler) stop() {}
|
func (m *mockHandler) Stop() {}
|
||||||
func (m *mockHandler) probeAvailability() {}
|
func (m *mockHandler) ProbeAvailability() {}
|
||||||
func (m *mockHandler) id() handlerID { return handlerID(m.Id) }
|
func (m *mockHandler) ID() types.HandlerID { return types.HandlerID(m.Id) }
|
||||||
|
|
||||||
type mockService struct{}
|
type mockService struct{}
|
||||||
|
|
||||||
@ -1113,7 +1123,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
initialHandlers registeredHandlerMap
|
initialHandlers registeredHandlerMap
|
||||||
updates []handlerWrapper
|
updates []handlerWrapper
|
||||||
expectedHandlers map[string]string // map[handlerID]domain
|
expectedHandlers map[string]string // map[HandlerID]domain
|
||||||
description string
|
description string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
@ -1409,7 +1419,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
|
|
||||||
// Check each expected handler
|
// Check each expected handler
|
||||||
for id, expectedDomain := range tt.expectedHandlers {
|
for id, expectedDomain := range tt.expectedHandlers {
|
||||||
handler, exists := server.dnsMuxMap[handlerID(id)]
|
handler, exists := server.dnsMuxMap[types.HandlerID(id)]
|
||||||
assert.True(t, exists, "Expected handler %s not found", id)
|
assert.True(t, exists, "Expected handler %s not found", id)
|
||||||
if exists {
|
if exists {
|
||||||
assert.Equal(t, expectedDomain, handler.domain,
|
assert.Equal(t, expectedDomain, handler.domain,
|
||||||
@ -1418,9 +1428,9 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Verify no unexpected handlers exist
|
// Verify no unexpected handlers exist
|
||||||
for handlerID := range server.dnsMuxMap {
|
for HandlerID := range server.dnsMuxMap {
|
||||||
_, expected := tt.expectedHandlers[string(handlerID)]
|
_, expected := tt.expectedHandlers[string(HandlerID)]
|
||||||
assert.True(t, expected, "Unexpected handler found: %s", handlerID)
|
assert.True(t, expected, "Unexpected handler found: %s", HandlerID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify the handlerChain state and order
|
// Verify the handlerChain state and order
|
||||||
@ -1696,7 +1706,7 @@ func TestExtraDomains(t *testing.T) {
|
|||||||
handlerChain: NewHandlerChain(),
|
handlerChain: NewHandlerChain(),
|
||||||
wgInterface: &mocWGIface{},
|
wgInterface: &mocWGIface{},
|
||||||
hostManager: mockHostConfig,
|
hostManager: mockHostConfig,
|
||||||
localResolver: &localResolver{},
|
localResolver: &local.Resolver{},
|
||||||
service: mockSvc,
|
service: mockSvc,
|
||||||
statusRecorder: peer.NewRecorder("test"),
|
statusRecorder: peer.NewRecorder("test"),
|
||||||
extraDomains: make(map[domain.Domain]int),
|
extraDomains: make(map[domain.Domain]int),
|
||||||
@ -1781,7 +1791,7 @@ func TestExtraDomainsRefCounting(t *testing.T) {
|
|||||||
ctx: context.Background(),
|
ctx: context.Background(),
|
||||||
handlerChain: NewHandlerChain(),
|
handlerChain: NewHandlerChain(),
|
||||||
hostManager: mockHostConfig,
|
hostManager: mockHostConfig,
|
||||||
localResolver: &localResolver{},
|
localResolver: &local.Resolver{},
|
||||||
service: mockSvc,
|
service: mockSvc,
|
||||||
statusRecorder: peer.NewRecorder("test"),
|
statusRecorder: peer.NewRecorder("test"),
|
||||||
extraDomains: make(map[domain.Domain]int),
|
extraDomains: make(map[domain.Domain]int),
|
||||||
@ -1833,7 +1843,7 @@ func TestUpdateConfigWithExistingExtraDomains(t *testing.T) {
|
|||||||
ctx: context.Background(),
|
ctx: context.Background(),
|
||||||
handlerChain: NewHandlerChain(),
|
handlerChain: NewHandlerChain(),
|
||||||
hostManager: mockHostConfig,
|
hostManager: mockHostConfig,
|
||||||
localResolver: &localResolver{},
|
localResolver: &local.Resolver{},
|
||||||
service: mockSvc,
|
service: mockSvc,
|
||||||
statusRecorder: peer.NewRecorder("test"),
|
statusRecorder: peer.NewRecorder("test"),
|
||||||
extraDomains: make(map[domain.Domain]int),
|
extraDomains: make(map[domain.Domain]int),
|
||||||
@ -1916,7 +1926,7 @@ func TestDomainCaseHandling(t *testing.T) {
|
|||||||
ctx: context.Background(),
|
ctx: context.Background(),
|
||||||
handlerChain: NewHandlerChain(),
|
handlerChain: NewHandlerChain(),
|
||||||
hostManager: mockHostConfig,
|
hostManager: mockHostConfig,
|
||||||
localResolver: &localResolver{},
|
localResolver: &local.Resolver{},
|
||||||
service: mockSvc,
|
service: mockSvc,
|
||||||
statusRecorder: peer.NewRecorder("test"),
|
statusRecorder: peer.NewRecorder("test"),
|
||||||
extraDomains: make(map[domain.Domain]int),
|
extraDomains: make(map[domain.Domain]int),
|
||||||
|
@ -30,9 +30,12 @@ const (
|
|||||||
systemdDbusSetDNSMethodSuffix = systemdDbusLinkInterface + ".SetDNS"
|
systemdDbusSetDNSMethodSuffix = systemdDbusLinkInterface + ".SetDNS"
|
||||||
systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute"
|
systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute"
|
||||||
systemdDbusSetDomainsMethodSuffix = systemdDbusLinkInterface + ".SetDomains"
|
systemdDbusSetDomainsMethodSuffix = systemdDbusLinkInterface + ".SetDomains"
|
||||||
|
systemdDbusSetDNSSECMethodSuffix = systemdDbusLinkInterface + ".SetDNSSEC"
|
||||||
systemdDbusResolvConfModeForeign = "foreign"
|
systemdDbusResolvConfModeForeign = "foreign"
|
||||||
|
|
||||||
dbusErrorUnknownObject = "org.freedesktop.DBus.Error.UnknownObject"
|
dbusErrorUnknownObject = "org.freedesktop.DBus.Error.UnknownObject"
|
||||||
|
|
||||||
|
dnsSecDisabled = "no"
|
||||||
)
|
)
|
||||||
|
|
||||||
type systemdDbusConfigurator struct {
|
type systemdDbusConfigurator struct {
|
||||||
@ -95,9 +98,13 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
|
|||||||
Family: unix.AF_INET,
|
Family: unix.AF_INET,
|
||||||
Address: ipAs4[:],
|
Address: ipAs4[:],
|
||||||
}
|
}
|
||||||
err = s.callLinkMethod(systemdDbusSetDNSMethodSuffix, []systemdDbusDNSInput{defaultLinkInput})
|
if err = s.callLinkMethod(systemdDbusSetDNSMethodSuffix, []systemdDbusDNSInput{defaultLinkInput}); err != nil {
|
||||||
if err != nil {
|
return fmt.Errorf("set interface DNS server %s:%d: %w", config.ServerIP, config.ServerPort, err)
|
||||||
return fmt.Errorf("setting the interface DNS server %s:%d failed with error: %w", config.ServerIP, config.ServerPort, err)
|
}
|
||||||
|
|
||||||
|
// We don't support dnssec. On some machines this is default on so we explicitly set it to off
|
||||||
|
if err = s.callLinkMethod(systemdDbusSetDNSSECMethodSuffix, dnsSecDisabled); err != nil {
|
||||||
|
log.Warnf("failed to set DNSSEC to 'no': %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
26
client/internal/dns/test/mock.go
Normal file
26
client/internal/dns/test/mock.go
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
package test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MockResponseWriter struct {
|
||||||
|
WriteMsgFunc func(m *dns.Msg) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rw *MockResponseWriter) WriteMsg(m *dns.Msg) error {
|
||||||
|
if rw.WriteMsgFunc != nil {
|
||||||
|
return rw.WriteMsgFunc(m)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rw *MockResponseWriter) LocalAddr() net.Addr { return nil }
|
||||||
|
func (rw *MockResponseWriter) RemoteAddr() net.Addr { return nil }
|
||||||
|
func (rw *MockResponseWriter) Write([]byte) (int, error) { return 0, nil }
|
||||||
|
func (rw *MockResponseWriter) Close() error { return nil }
|
||||||
|
func (rw *MockResponseWriter) TsigStatus() error { return nil }
|
||||||
|
func (rw *MockResponseWriter) TsigTimersOnly(bool) {}
|
||||||
|
func (rw *MockResponseWriter) Hijack() {}
|
3
client/internal/dns/types/types.go
Normal file
3
client/internal/dns/types/types.go
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
type HandlerID string
|
@ -19,6 +19,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
)
|
)
|
||||||
@ -81,21 +82,21 @@ func (u *upstreamResolverBase) String() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ID returns the unique handler ID
|
// ID returns the unique handler ID
|
||||||
func (u *upstreamResolverBase) id() handlerID {
|
func (u *upstreamResolverBase) ID() types.HandlerID {
|
||||||
servers := slices.Clone(u.upstreamServers)
|
servers := slices.Clone(u.upstreamServers)
|
||||||
slices.Sort(servers)
|
slices.Sort(servers)
|
||||||
|
|
||||||
hash := sha256.New()
|
hash := sha256.New()
|
||||||
hash.Write([]byte(u.domain + ":"))
|
hash.Write([]byte(u.domain + ":"))
|
||||||
hash.Write([]byte(strings.Join(servers, ",")))
|
hash.Write([]byte(strings.Join(servers, ",")))
|
||||||
return handlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
|
return types.HandlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *upstreamResolverBase) MatchSubdomains() bool {
|
func (u *upstreamResolverBase) MatchSubdomains() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *upstreamResolverBase) stop() {
|
func (u *upstreamResolverBase) Stop() {
|
||||||
log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
|
log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
|
||||||
u.cancel()
|
u.cancel()
|
||||||
}
|
}
|
||||||
@ -198,9 +199,9 @@ func (u *upstreamResolverBase) checkUpstreamFails(err error) {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// probeAvailability tests all upstream servers simultaneously and
|
// ProbeAvailability tests all upstream servers simultaneously and
|
||||||
// disables the resolver if none work
|
// disables the resolver if none work
|
||||||
func (u *upstreamResolverBase) probeAvailability() {
|
func (u *upstreamResolverBase) ProbeAvailability() {
|
||||||
u.mutex.Lock()
|
u.mutex.Lock()
|
||||||
defer u.mutex.Unlock()
|
defer u.mutex.Unlock()
|
||||||
|
|
||||||
|
@ -8,6 +8,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
||||||
@ -66,7 +68,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var responseMSG *dns.Msg
|
var responseMSG *dns.Msg
|
||||||
responseWriter := &mockResponseWriter{
|
responseWriter := &test.MockResponseWriter{
|
||||||
WriteMsgFunc: func(m *dns.Msg) error {
|
WriteMsgFunc: func(m *dns.Msg) error {
|
||||||
responseMSG = m
|
responseMSG = m
|
||||||
return nil
|
return nil
|
||||||
@ -130,7 +132,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
|
|||||||
resolver.failsTillDeact = 0
|
resolver.failsTillDeact = 0
|
||||||
resolver.reactivatePeriod = time.Microsecond * 100
|
resolver.reactivatePeriod = time.Microsecond * 100
|
||||||
|
|
||||||
responseWriter := &mockResponseWriter{
|
responseWriter := &test.MockResponseWriter{
|
||||||
WriteMsgFunc: func(m *dns.Msg) error { return nil },
|
WriteMsgFunc: func(m *dns.Msg) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5,7 +5,6 @@ package dns
|
|||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
@ -18,5 +17,4 @@ type WGIface interface {
|
|||||||
IsUserspaceBind() bool
|
IsUserspaceBind() bool
|
||||||
GetFilter() device.PacketFilter
|
GetFilter() device.PacketFilter
|
||||||
GetDevice() *device.FilteredDevice
|
GetDevice() *device.FilteredDevice
|
||||||
GetStats(peerKey string) (configurer.WGStats, error)
|
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
@ -13,6 +12,5 @@ type WGIface interface {
|
|||||||
IsUserspaceBind() bool
|
IsUserspaceBind() bool
|
||||||
GetFilter() device.PacketFilter
|
GetFilter() device.PacketFilter
|
||||||
GetDevice() *device.FilteredDevice
|
GetDevice() *device.FilteredDevice
|
||||||
GetStats(peerKey string) (configurer.WGStats, error)
|
|
||||||
GetInterfaceGUIDString() (string, error)
|
GetInterfaceGUIDString() (string, error)
|
||||||
}
|
}
|
||||||
|
@ -33,6 +33,8 @@ type DNSForwarder struct {
|
|||||||
|
|
||||||
dnsServer *dns.Server
|
dnsServer *dns.Server
|
||||||
mux *dns.ServeMux
|
mux *dns.ServeMux
|
||||||
|
tcpServer *dns.Server
|
||||||
|
tcpMux *dns.ServeMux
|
||||||
|
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
fwdEntries []*ForwarderEntry
|
fwdEntries []*ForwarderEntry
|
||||||
@ -50,22 +52,41 @@ func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewall.Manager
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
|
func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
|
||||||
log.Infof("listen DNS forwarder on address=%s", f.listenAddress)
|
log.Infof("starting DNS forwarder on address=%s", f.listenAddress)
|
||||||
mux := dns.NewServeMux()
|
|
||||||
|
|
||||||
dnsServer := &dns.Server{
|
// UDP server
|
||||||
|
mux := dns.NewServeMux()
|
||||||
|
f.mux = mux
|
||||||
|
f.dnsServer = &dns.Server{
|
||||||
Addr: f.listenAddress,
|
Addr: f.listenAddress,
|
||||||
Net: "udp",
|
Net: "udp",
|
||||||
Handler: mux,
|
Handler: mux,
|
||||||
}
|
}
|
||||||
f.dnsServer = dnsServer
|
// TCP server
|
||||||
f.mux = mux
|
tcpMux := dns.NewServeMux()
|
||||||
|
f.tcpMux = tcpMux
|
||||||
|
f.tcpServer = &dns.Server{
|
||||||
|
Addr: f.listenAddress,
|
||||||
|
Net: "tcp",
|
||||||
|
Handler: tcpMux,
|
||||||
|
}
|
||||||
|
|
||||||
f.UpdateDomains(entries)
|
f.UpdateDomains(entries)
|
||||||
|
|
||||||
return dnsServer.ListenAndServe()
|
errCh := make(chan error, 2)
|
||||||
}
|
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
log.Infof("DNS UDP listener running on %s", f.listenAddress)
|
||||||
|
errCh <- f.dnsServer.ListenAndServe()
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
log.Infof("DNS TCP listener running on %s", f.listenAddress)
|
||||||
|
errCh <- f.tcpServer.ListenAndServe()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// return the first error we get (e.g. bind failure or shutdown)
|
||||||
|
return <-errCh
|
||||||
|
}
|
||||||
func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
|
func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
|
||||||
f.mutex.Lock()
|
f.mutex.Lock()
|
||||||
defer f.mutex.Unlock()
|
defer f.mutex.Unlock()
|
||||||
@ -77,31 +98,41 @@ func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
oldDomains := filterDomains(f.fwdEntries)
|
oldDomains := filterDomains(f.fwdEntries)
|
||||||
|
|
||||||
for _, d := range oldDomains {
|
for _, d := range oldDomains {
|
||||||
f.mux.HandleRemove(d.PunycodeString())
|
f.mux.HandleRemove(d.PunycodeString())
|
||||||
|
f.tcpMux.HandleRemove(d.PunycodeString())
|
||||||
}
|
}
|
||||||
|
|
||||||
newDomains := filterDomains(entries)
|
newDomains := filterDomains(entries)
|
||||||
for _, d := range newDomains {
|
for _, d := range newDomains {
|
||||||
f.mux.HandleFunc(d.PunycodeString(), f.handleDNSQuery)
|
f.mux.HandleFunc(d.PunycodeString(), f.handleDNSQueryUDP)
|
||||||
|
f.tcpMux.HandleFunc(d.PunycodeString(), f.handleDNSQueryTCP)
|
||||||
}
|
}
|
||||||
|
|
||||||
f.fwdEntries = entries
|
f.fwdEntries = entries
|
||||||
|
|
||||||
log.Debugf("Updated domains from %v to %v", oldDomains, newDomains)
|
log.Debugf("Updated domains from %v to %v", oldDomains, newDomains)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) Close(ctx context.Context) error {
|
func (f *DNSForwarder) Close(ctx context.Context) error {
|
||||||
if f.dnsServer == nil {
|
var result *multierror.Error
|
||||||
return nil
|
|
||||||
|
if f.dnsServer != nil {
|
||||||
|
if err := f.dnsServer.ShutdownContext(ctx); err != nil {
|
||||||
|
result = multierror.Append(result, fmt.Errorf("UDP shutdown: %w", err))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return f.dnsServer.ShutdownContext(ctx)
|
if f.tcpServer != nil {
|
||||||
|
if err := f.tcpServer.ShutdownContext(ctx); err != nil {
|
||||||
|
result = multierror.Append(result, fmt.Errorf("TCP shutdown: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns.Msg {
|
||||||
if len(query.Question) == 0 {
|
if len(query.Question) == 0 {
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
question := query.Question[0]
|
question := query.Question[0]
|
||||||
log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v",
|
log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v",
|
||||||
@ -123,20 +154,53 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
|||||||
if err := w.WriteMsg(resp); err != nil {
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
log.Errorf("failed to write DNS response: %v", err)
|
log.Errorf("failed to write DNS response: %v", err)
|
||||||
}
|
}
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
|
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
ips, err := net.DefaultResolver.LookupNetIP(ctx, network, domain)
|
ips, err := net.DefaultResolver.LookupNetIP(ctx, network, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.handleDNSError(w, resp, domain, err)
|
f.handleDNSError(w, query, resp, domain, err)
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
f.updateInternalState(domain, ips)
|
f.updateInternalState(domain, ips)
|
||||||
f.addIPsToResponse(resp, domain, ips)
|
f.addIPsToResponse(resp, domain, ips)
|
||||||
|
|
||||||
|
return resp
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
||||||
|
|
||||||
|
resp := f.handleDNSQuery(w, query)
|
||||||
|
if resp == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
opt := query.IsEdns0()
|
||||||
|
maxSize := dns.MinMsgSize
|
||||||
|
if opt != nil {
|
||||||
|
// client advertised a larger EDNS0 buffer
|
||||||
|
maxSize = int(opt.UDPSize())
|
||||||
|
}
|
||||||
|
|
||||||
|
// if our response is too big, truncate and set the TC bit
|
||||||
|
if resp.Len() > maxSize {
|
||||||
|
resp.Truncate(maxSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
|
log.Errorf("failed to write DNS response: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
|
||||||
|
resp := f.handleDNSQuery(w, query)
|
||||||
|
if resp == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
log.Errorf("failed to write DNS response: %v", err)
|
log.Errorf("failed to write DNS response: %v", err)
|
||||||
}
|
}
|
||||||
@ -179,7 +243,7 @@ func (f *DNSForwarder) updateFirewall(matchingEntries []*ForwarderEntry, prefixe
|
|||||||
}
|
}
|
||||||
|
|
||||||
// handleDNSError processes DNS lookup errors and sends an appropriate error response
|
// handleDNSError processes DNS lookup errors and sends an appropriate error response
|
||||||
func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, resp *dns.Msg, domain string, err error) {
|
func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, query, resp *dns.Msg, domain string, err error) {
|
||||||
var dnsErr *net.DNSError
|
var dnsErr *net.DNSError
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
@ -191,7 +255,7 @@ func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, resp *dns.Msg, domai
|
|||||||
}
|
}
|
||||||
|
|
||||||
if dnsErr.Server != "" {
|
if dnsErr.Server != "" {
|
||||||
log.Warnf("failed to resolve query for domain=%s server=%s: %v", domain, dnsErr.Server, err)
|
log.Warnf("failed to resolve query for type=%s domain=%s server=%s: %v", dns.TypeToString[query.Question[0].Qtype], domain, dnsErr.Server, err)
|
||||||
} else {
|
} else {
|
||||||
log.Warnf(errResolveFailed, domain, err)
|
log.Warnf(errResolveFailed, domain, err)
|
||||||
}
|
}
|
||||||
|
@ -33,6 +33,7 @@ type Manager struct {
|
|||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
|
|
||||||
fwRules []firewall.Rule
|
fwRules []firewall.Rule
|
||||||
|
tcpRules []firewall.Rule
|
||||||
dnsForwarder *DNSForwarder
|
dnsForwarder *DNSForwarder
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -107,6 +108,13 @@ func (m *Manager) allowDNSFirewall() error {
|
|||||||
}
|
}
|
||||||
m.fwRules = dnsRules
|
m.fwRules = dnsRules
|
||||||
|
|
||||||
|
tcpRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolTCP, nil, dport, firewall.ActionAccept, "")
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to add allow DNS router rules, err: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
m.tcpRules = tcpRules
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -117,7 +125,13 @@ func (m *Manager) dropDNSFirewall() error {
|
|||||||
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
|
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
for _, rule := range m.tcpRules {
|
||||||
|
if err := m.firewall.DeletePeerRule(rule); err != nil {
|
||||||
|
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
m.fwRules = nil
|
m.fwRules = nil
|
||||||
|
m.tcpRules = nil
|
||||||
return nberrors.FormatErrorOrNil(mErr)
|
return nberrors.FormatErrorOrNil(mErr)
|
||||||
}
|
}
|
||||||
|
@ -38,6 +38,7 @@ import (
|
|||||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||||
@ -122,6 +123,8 @@ type EngineConfig struct {
|
|||||||
DisableFirewall bool
|
DisableFirewall bool
|
||||||
|
|
||||||
BlockLANAccess bool
|
BlockLANAccess bool
|
||||||
|
|
||||||
|
LazyConnectionEnabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
|
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
|
||||||
@ -134,6 +137,8 @@ type Engine struct {
|
|||||||
// peerConns is a map that holds all the peers that are known to this peer
|
// peerConns is a map that holds all the peers that are known to this peer
|
||||||
peerStore *peerstore.Store
|
peerStore *peerstore.Store
|
||||||
|
|
||||||
|
connMgr *ConnMgr
|
||||||
|
|
||||||
beforePeerHook nbnet.AddHookFunc
|
beforePeerHook nbnet.AddHookFunc
|
||||||
afterPeerHook nbnet.RemoveHookFunc
|
afterPeerHook nbnet.RemoveHookFunc
|
||||||
|
|
||||||
@ -170,7 +175,8 @@ type Engine struct {
|
|||||||
sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error)
|
sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error)
|
||||||
sshServer nbssh.Server
|
sshServer nbssh.Server
|
||||||
|
|
||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
|
peerConnDispatcher *dispatcher.ConnectionDispatcher
|
||||||
|
|
||||||
firewall firewallManager.Manager
|
firewall firewallManager.Manager
|
||||||
routeManager routemanager.Manager
|
routeManager routemanager.Manager
|
||||||
@ -262,6 +268,10 @@ func (e *Engine) Stop() error {
|
|||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
|
if e.connMgr != nil {
|
||||||
|
e.connMgr.Close()
|
||||||
|
}
|
||||||
|
|
||||||
// stopping network monitor first to avoid starting the engine again
|
// stopping network monitor first to avoid starting the engine again
|
||||||
if e.networkMonitor != nil {
|
if e.networkMonitor != nil {
|
||||||
e.networkMonitor.Stop()
|
e.networkMonitor.Stop()
|
||||||
@ -297,8 +307,7 @@ func (e *Engine) Stop() error {
|
|||||||
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
|
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
|
||||||
e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{})
|
e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{})
|
||||||
|
|
||||||
err := e.removeAllPeers()
|
if err := e.removeAllPeers(); err != nil {
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to remove all peers: %s", err)
|
return fmt.Errorf("failed to remove all peers: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -405,8 +414,7 @@ func (e *Engine) Start() error {
|
|||||||
|
|
||||||
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
|
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
|
||||||
|
|
||||||
err = e.wgInterfaceCreate()
|
if err = e.wgInterfaceCreate(); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
|
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
|
||||||
e.close()
|
e.close()
|
||||||
return fmt.Errorf("create wg interface: %w", err)
|
return fmt.Errorf("create wg interface: %w", err)
|
||||||
@ -442,6 +450,11 @@ func (e *Engine) Start() error {
|
|||||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
e.peerConnDispatcher = dispatcher.NewConnectionDispatcher()
|
||||||
|
|
||||||
|
e.connMgr = NewConnMgr(e.config, e.statusRecorder, e.peerStore, wgIface, e.peerConnDispatcher)
|
||||||
|
e.connMgr.Start(e.ctx)
|
||||||
|
|
||||||
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
|
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
|
||||||
e.srWatcher.Start()
|
e.srWatcher.Start()
|
||||||
|
|
||||||
@ -450,7 +463,6 @@ func (e *Engine) Start() error {
|
|||||||
|
|
||||||
// starting network monitor at the very last to avoid disruptions
|
// starting network monitor at the very last to avoid disruptions
|
||||||
e.startNetworkMonitor()
|
e.startNetworkMonitor()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -550,6 +562,16 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
|||||||
var modified []*mgmProto.RemotePeerConfig
|
var modified []*mgmProto.RemotePeerConfig
|
||||||
for _, p := range peersUpdate {
|
for _, p := range peersUpdate {
|
||||||
peerPubKey := p.GetWgPubKey()
|
peerPubKey := p.GetWgPubKey()
|
||||||
|
currentPeer, ok := e.peerStore.PeerConn(peerPubKey)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if currentPeer.AgentVersionString() != p.AgentVersion {
|
||||||
|
modified = append(modified, p)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
allowedIPs, ok := e.peerStore.AllowedIPs(peerPubKey)
|
allowedIPs, ok := e.peerStore.AllowedIPs(peerPubKey)
|
||||||
if !ok {
|
if !ok {
|
||||||
continue
|
continue
|
||||||
@ -559,8 +581,7 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
err := e.statusRecorder.UpdatePeerFQDN(peerPubKey, p.GetFqdn())
|
if err := e.statusRecorder.UpdatePeerFQDN(peerPubKey, p.GetFqdn()); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Warnf("error updating peer's %s fqdn in the status recorder, got error: %v", peerPubKey, err)
|
log.Warnf("error updating peer's %s fqdn in the status recorder, got error: %v", peerPubKey, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -621,16 +642,11 @@ func (e *Engine) removePeer(peerKey string) error {
|
|||||||
e.sshServer.RemoveAuthorizedKey(peerKey)
|
e.sshServer.RemoveAuthorizedKey(peerKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() {
|
e.connMgr.RemovePeerConn(peerKey)
|
||||||
err := e.statusRecorder.RemovePeer(peerKey)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("received error when removing peer %s from status recorder: %v", peerKey, err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
conn, exists := e.peerStore.Remove(peerKey)
|
err := e.statusRecorder.RemovePeer(peerKey)
|
||||||
if exists {
|
if err != nil {
|
||||||
conn.Close()
|
log.Warnf("received error when removing peer %s from status recorder: %v", peerKey, err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -952,12 +968,24 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := e.connMgr.UpdatedRemoteFeatureFlag(e.ctx, networkMap.GetPeerConfig().GetLazyConnectionEnabled()); err != nil {
|
||||||
|
log.Errorf("failed to update lazy connection feature flag: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if e.firewall != nil {
|
if e.firewall != nil {
|
||||||
if localipfw, ok := e.firewall.(localIpUpdater); ok {
|
if localipfw, ok := e.firewall.(localIpUpdater); ok {
|
||||||
if err := localipfw.UpdateLocalIPs(); err != nil {
|
if err := localipfw.UpdateLocalIPs(); err != nil {
|
||||||
log.Errorf("failed to update local IPs: %v", err)
|
log.Errorf("failed to update local IPs: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If we got empty rules list but management did not set the networkMap.FirewallRulesIsEmpty flag,
|
||||||
|
// then the mgmt server is older than the client, and we need to allow all traffic for routes.
|
||||||
|
// This needs to be toggled before applying routes.
|
||||||
|
isLegacy := len(networkMap.RoutesFirewallRules) == 0 && !networkMap.RoutesFirewallRulesIsEmpty
|
||||||
|
if err := e.firewall.SetLegacyManagement(isLegacy); err != nil {
|
||||||
|
log.Errorf("failed to set legacy management flag: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
|
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
|
||||||
@ -976,7 +1004,8 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries)
|
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries)
|
||||||
|
|
||||||
// Ingress forward rules
|
// Ingress forward rules
|
||||||
if err := e.updateForwardRules(networkMap.GetForwardingRules()); err != nil {
|
forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules())
|
||||||
|
if err != nil {
|
||||||
log.Errorf("failed to update forward rules, err: %v", err)
|
log.Errorf("failed to update forward rules, err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1022,6 +1051,10 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
|
||||||
|
excludedLazyPeers := e.toExcludedLazyPeers(routes, forwardingRules, networkMap.GetRemotePeers())
|
||||||
|
e.connMgr.SetExcludeList(excludedLazyPeers)
|
||||||
|
|
||||||
protoDNSConfig := networkMap.GetDNSConfig()
|
protoDNSConfig := networkMap.GetDNSConfig()
|
||||||
if protoDNSConfig == nil {
|
if protoDNSConfig == nil {
|
||||||
protoDNSConfig = &mgmProto.DNSConfig{}
|
protoDNSConfig = &mgmProto.DNSConfig{}
|
||||||
@ -1155,7 +1188,7 @@ func (e *Engine) updateOfflinePeers(offlinePeers []*mgmProto.RemotePeerConfig) {
|
|||||||
IP: strings.Join(offlinePeer.GetAllowedIps(), ","),
|
IP: strings.Join(offlinePeer.GetAllowedIps(), ","),
|
||||||
PubKey: offlinePeer.GetWgPubKey(),
|
PubKey: offlinePeer.GetWgPubKey(),
|
||||||
FQDN: offlinePeer.GetFqdn(),
|
FQDN: offlinePeer.GetFqdn(),
|
||||||
ConnStatus: peer.StatusDisconnected,
|
ConnStatus: peer.StatusIdle,
|
||||||
ConnStatusUpdate: time.Now(),
|
ConnStatusUpdate: time.Now(),
|
||||||
Mux: new(sync.RWMutex),
|
Mux: new(sync.RWMutex),
|
||||||
}
|
}
|
||||||
@ -1191,12 +1224,17 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
|
|||||||
peerIPs = append(peerIPs, allowedNetIP)
|
peerIPs = append(peerIPs, allowedNetIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := e.createPeerConn(peerKey, peerIPs)
|
conn, err := e.createPeerConn(peerKey, peerIPs, peerConfig.AgentVersion)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create peer connection: %w", err)
|
return fmt.Errorf("create peer connection: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if ok := e.peerStore.AddPeerConn(peerKey, conn); !ok {
|
err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn, peerIPs[0].Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if exists := e.connMgr.AddPeerConn(e.ctx, peerKey, conn); exists {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
return fmt.Errorf("peer already exists: %s", peerKey)
|
return fmt.Errorf("peer already exists: %s", peerKey)
|
||||||
}
|
}
|
||||||
@ -1205,17 +1243,10 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
|
|||||||
conn.AddBeforeAddPeerHook(e.beforePeerHook)
|
conn.AddBeforeAddPeerHook(e.beforePeerHook)
|
||||||
conn.AddAfterRemovePeerHook(e.afterPeerHook)
|
conn.AddAfterRemovePeerHook(e.afterPeerHook)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
conn.Open()
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix) (*peer.Conn, error) {
|
func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentVersion string) (*peer.Conn, error) {
|
||||||
log.Debugf("creating peer connection %s", pubKey)
|
log.Debugf("creating peer connection %s", pubKey)
|
||||||
|
|
||||||
wgConfig := peer.WgConfig{
|
wgConfig := peer.WgConfig{
|
||||||
@ -1229,11 +1260,12 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix) (*peer
|
|||||||
// randomize connection timeout
|
// randomize connection timeout
|
||||||
timeout := time.Duration(rand.Intn(PeerConnectionTimeoutMax-PeerConnectionTimeoutMin)+PeerConnectionTimeoutMin) * time.Millisecond
|
timeout := time.Duration(rand.Intn(PeerConnectionTimeoutMax-PeerConnectionTimeoutMin)+PeerConnectionTimeoutMin) * time.Millisecond
|
||||||
config := peer.ConnConfig{
|
config := peer.ConnConfig{
|
||||||
Key: pubKey,
|
Key: pubKey,
|
||||||
LocalKey: e.config.WgPrivateKey.PublicKey().String(),
|
LocalKey: e.config.WgPrivateKey.PublicKey().String(),
|
||||||
Timeout: timeout,
|
AgentVersion: agentVersion,
|
||||||
WgConfig: wgConfig,
|
Timeout: timeout,
|
||||||
LocalWgPort: e.config.WgPort,
|
WgConfig: wgConfig,
|
||||||
|
LocalWgPort: e.config.WgPort,
|
||||||
RosenpassConfig: peer.RosenpassConfig{
|
RosenpassConfig: peer.RosenpassConfig{
|
||||||
PubKey: e.getRosenpassPubKey(),
|
PubKey: e.getRosenpassPubKey(),
|
||||||
Addr: e.getRosenpassAddr(),
|
Addr: e.getRosenpassAddr(),
|
||||||
@ -1249,7 +1281,16 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix) (*peer
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager, e.srWatcher, e.connSemaphore)
|
serviceDependencies := peer.ServiceDependencies{
|
||||||
|
StatusRecorder: e.statusRecorder,
|
||||||
|
Signaler: e.signaler,
|
||||||
|
IFaceDiscover: e.mobileDep.IFaceDiscover,
|
||||||
|
RelayManager: e.relayManager,
|
||||||
|
SrWatcher: e.srWatcher,
|
||||||
|
Semaphore: e.connSemaphore,
|
||||||
|
PeerConnDispatcher: e.peerConnDispatcher,
|
||||||
|
}
|
||||||
|
peerConn, err := peer.NewConn(config, serviceDependencies)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -1270,7 +1311,7 @@ func (e *Engine) receiveSignalEvents() {
|
|||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
conn, ok := e.peerStore.PeerConn(msg.Key)
|
conn, ok := e.connMgr.OnSignalMsg(e.ctx, msg.Key)
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("wrongly addressed message %s", msg.Key)
|
return fmt.Errorf("wrongly addressed message %s", msg.Key)
|
||||||
}
|
}
|
||||||
@ -1578,13 +1619,39 @@ func (e *Engine) getRosenpassAddr() string {
|
|||||||
// RunHealthProbes executes health checks for Signal, Management, Relay and WireGuard services
|
// RunHealthProbes executes health checks for Signal, Management, Relay and WireGuard services
|
||||||
// and updates the status recorder with the latest states.
|
// and updates the status recorder with the latest states.
|
||||||
func (e *Engine) RunHealthProbes() bool {
|
func (e *Engine) RunHealthProbes() bool {
|
||||||
|
e.syncMsgMux.Lock()
|
||||||
|
|
||||||
signalHealthy := e.signal.IsHealthy()
|
signalHealthy := e.signal.IsHealthy()
|
||||||
log.Debugf("signal health check: healthy=%t", signalHealthy)
|
log.Debugf("signal health check: healthy=%t", signalHealthy)
|
||||||
|
|
||||||
managementHealthy := e.mgmClient.IsHealthy()
|
managementHealthy := e.mgmClient.IsHealthy()
|
||||||
log.Debugf("management health check: healthy=%t", managementHealthy)
|
log.Debugf("management health check: healthy=%t", managementHealthy)
|
||||||
|
|
||||||
results := append(e.probeSTUNs(), e.probeTURNs()...)
|
stuns := slices.Clone(e.STUNs)
|
||||||
|
turns := slices.Clone(e.TURNs)
|
||||||
|
|
||||||
|
if e.wgInterface != nil {
|
||||||
|
stats, err := e.wgInterface.GetStats()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to get wireguard stats: %v", err)
|
||||||
|
e.syncMsgMux.Unlock()
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, key := range e.peerStore.PeersPubKey() {
|
||||||
|
// wgStats could be zero value, in which case we just reset the stats
|
||||||
|
wgStats, ok := stats[key]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := e.statusRecorder.UpdateWireGuardPeerState(key, wgStats); err != nil {
|
||||||
|
log.Debugf("failed to update wg stats for peer %s: %s", key, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
|
results := e.probeICE(stuns, turns)
|
||||||
e.statusRecorder.UpdateRelayStates(results)
|
e.statusRecorder.UpdateRelayStates(results)
|
||||||
|
|
||||||
relayHealthy := true
|
relayHealthy := true
|
||||||
@ -1596,37 +1663,16 @@ func (e *Engine) RunHealthProbes() bool {
|
|||||||
}
|
}
|
||||||
log.Debugf("relay health check: healthy=%t", relayHealthy)
|
log.Debugf("relay health check: healthy=%t", relayHealthy)
|
||||||
|
|
||||||
for _, key := range e.peerStore.PeersPubKey() {
|
|
||||||
wgStats, err := e.wgInterface.GetStats(key)
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to get wg stats for peer %s: %s", key, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// wgStats could be zero value, in which case we just reset the stats
|
|
||||||
if err := e.statusRecorder.UpdateWireGuardPeerState(key, wgStats); err != nil {
|
|
||||||
log.Debugf("failed to update wg stats for peer %s: %s", key, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
allHealthy := signalHealthy && managementHealthy && relayHealthy
|
allHealthy := signalHealthy && managementHealthy && relayHealthy
|
||||||
log.Debugf("all health checks completed: healthy=%t", allHealthy)
|
log.Debugf("all health checks completed: healthy=%t", allHealthy)
|
||||||
return allHealthy
|
return allHealthy
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) probeSTUNs() []relay.ProbeResult {
|
func (e *Engine) probeICE(stuns, turns []*stun.URI) []relay.ProbeResult {
|
||||||
e.syncMsgMux.Lock()
|
return append(
|
||||||
stuns := slices.Clone(e.STUNs)
|
relay.ProbeAll(e.ctx, relay.ProbeSTUN, stuns),
|
||||||
e.syncMsgMux.Unlock()
|
relay.ProbeAll(e.ctx, relay.ProbeSTUN, turns)...,
|
||||||
|
)
|
||||||
return relay.ProbeAll(e.ctx, relay.ProbeSTUN, stuns)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *Engine) probeTURNs() []relay.ProbeResult {
|
|
||||||
e.syncMsgMux.Lock()
|
|
||||||
turns := slices.Clone(e.TURNs)
|
|
||||||
e.syncMsgMux.Unlock()
|
|
||||||
|
|
||||||
return relay.ProbeAll(e.ctx, relay.ProbeTURN, turns)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// restartEngine restarts the engine by cancelling the client context
|
// restartEngine restarts the engine by cancelling the client context
|
||||||
@ -1813,21 +1859,21 @@ func (e *Engine) Address() (netip.Addr, error) {
|
|||||||
return ip.Unmap(), nil
|
return ip.Unmap(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) error {
|
func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewallManager.ForwardRule, error) {
|
||||||
if e.firewall == nil {
|
if e.firewall == nil {
|
||||||
log.Warn("firewall is disabled, not updating forwarding rules")
|
log.Warn("firewall is disabled, not updating forwarding rules")
|
||||||
return nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(rules) == 0 {
|
if len(rules) == 0 {
|
||||||
if e.ingressGatewayMgr == nil {
|
if e.ingressGatewayMgr == nil {
|
||||||
return nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err := e.ingressGatewayMgr.Close()
|
err := e.ingressGatewayMgr.Close()
|
||||||
e.ingressGatewayMgr = nil
|
e.ingressGatewayMgr = nil
|
||||||
e.statusRecorder.SetIngressGwMgr(nil)
|
e.statusRecorder.SetIngressGwMgr(nil)
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if e.ingressGatewayMgr == nil {
|
if e.ingressGatewayMgr == nil {
|
||||||
@ -1878,7 +1924,33 @@ func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) error {
|
|||||||
log.Errorf("failed to update forwarding rules: %v", err)
|
log.Errorf("failed to update forwarding rules: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return forwardingRules, nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) toExcludedLazyPeers(routes []*route.Route, rules []firewallManager.ForwardRule, peers []*mgmProto.RemotePeerConfig) []string {
|
||||||
|
excludedPeers := make([]string, 0)
|
||||||
|
for _, r := range routes {
|
||||||
|
if r.Peer == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
log.Infof("exclude router peer from lazy connection: %s", r.Peer)
|
||||||
|
excludedPeers = append(excludedPeers, r.Peer)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, r := range rules {
|
||||||
|
ip := r.TranslatedAddress
|
||||||
|
for _, p := range peers {
|
||||||
|
for _, allowedIP := range p.GetAllowedIps() {
|
||||||
|
if allowedIP != ip.String() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
log.Infof("exclude forwarder peer from lazy connection: %s", p.GetWgPubKey())
|
||||||
|
excludedPeers = append(excludedPeers, p.GetWgPubKey())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return excludedPeers
|
||||||
}
|
}
|
||||||
|
|
||||||
// isChecksEqual checks if two slices of checks are equal.
|
// isChecksEqual checks if two slices of checks are equal.
|
||||||
|
@ -28,8 +28,6 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
@ -38,6 +36,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
@ -53,6 +52,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/settings"
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
signal "github.com/netbirdio/netbird/signal/client"
|
signal "github.com/netbirdio/netbird/signal/client"
|
||||||
@ -93,7 +93,7 @@ type MockWGIface struct {
|
|||||||
GetFilterFunc func() device.PacketFilter
|
GetFilterFunc func() device.PacketFilter
|
||||||
GetDeviceFunc func() *device.FilteredDevice
|
GetDeviceFunc func() *device.FilteredDevice
|
||||||
GetWGDeviceFunc func() *wgdevice.Device
|
GetWGDeviceFunc func() *wgdevice.Device
|
||||||
GetStatsFunc func(peerKey string) (configurer.WGStats, error)
|
GetStatsFunc func() (map[string]configurer.WGStats, error)
|
||||||
GetInterfaceGUIDStringFunc func() (string, error)
|
GetInterfaceGUIDStringFunc func() (string, error)
|
||||||
GetProxyFunc func() wgproxy.Proxy
|
GetProxyFunc func() wgproxy.Proxy
|
||||||
GetNetFunc func() *netstack.Net
|
GetNetFunc func() *netstack.Net
|
||||||
@ -171,8 +171,8 @@ func (m *MockWGIface) GetWGDevice() *wgdevice.Device {
|
|||||||
return m.GetWGDeviceFunc()
|
return m.GetWGDeviceFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) {
|
func (m *MockWGIface) GetStats() (map[string]configurer.WGStats, error) {
|
||||||
return m.GetStatsFunc(peerKey)
|
return m.GetStatsFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockWGIface) GetProxy() wgproxy.Proxy {
|
func (m *MockWGIface) GetProxy() wgproxy.Proxy {
|
||||||
@ -378,6 +378,9 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
UpdatePeerFunc: func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
}
|
}
|
||||||
engine.wgInterface = wgIface
|
engine.wgInterface = wgIface
|
||||||
engine.routeManager = routemanager.NewManager(routemanager.ManagerConfig{
|
engine.routeManager = routemanager.NewManager(routemanager.ManagerConfig{
|
||||||
@ -400,6 +403,8 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
|||||||
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn})
|
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn})
|
||||||
engine.ctx = ctx
|
engine.ctx = ctx
|
||||||
engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{})
|
engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{})
|
||||||
|
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, wgIface, dispatcher.NewConnectionDispatcher())
|
||||||
|
engine.connMgr.Start(ctx)
|
||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
name string
|
name string
|
||||||
@ -770,6 +775,8 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
|||||||
|
|
||||||
engine.routeManager = mockRouteManager
|
engine.routeManager = mockRouteManager
|
||||||
engine.dnsServer = &dns.MockServer{}
|
engine.dnsServer = &dns.MockServer{}
|
||||||
|
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface, dispatcher.NewConnectionDispatcher())
|
||||||
|
engine.connMgr.Start(ctx)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
exitErr := engine.Stop()
|
exitErr := engine.Stop()
|
||||||
@ -966,6 +973,8 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
engine.dnsServer = mockDNSServer
|
engine.dnsServer = mockDNSServer
|
||||||
|
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface, dispatcher.NewConnectionDispatcher())
|
||||||
|
engine.connMgr.Start(ctx)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
exitErr := engine.Stop()
|
exitErr := engine.Stop()
|
||||||
@ -1476,7 +1485,7 @@ func getConnectedPeers(e *Engine) int {
|
|||||||
i := 0
|
i := 0
|
||||||
for _, id := range e.peerStore.PeersPubKey() {
|
for _, id := range e.peerStore.PeersPubKey() {
|
||||||
conn, _ := e.peerStore.PeerConn(id)
|
conn, _ := e.peerStore.PeerConn(id)
|
||||||
if conn.Status() == peer.StatusConnected {
|
if conn.IsConnected() {
|
||||||
i++
|
i++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -35,6 +35,6 @@ type wgIfaceBase interface {
|
|||||||
GetFilter() device.PacketFilter
|
GetFilter() device.PacketFilter
|
||||||
GetDevice() *device.FilteredDevice
|
GetDevice() *device.FilteredDevice
|
||||||
GetWGDevice() *wgdevice.Device
|
GetWGDevice() *wgdevice.Device
|
||||||
GetStats(peerKey string) (configurer.WGStats, error)
|
GetStats() (map[string]configurer.WGStats, error)
|
||||||
GetNet() *netstack.Net
|
GetNet() *netstack.Net
|
||||||
}
|
}
|
||||||
|
9
client/internal/lazyconn/activity/listen_ip.go
Normal file
9
client/internal/lazyconn/activity/listen_ip.go
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
//go:build !linux || android
|
||||||
|
|
||||||
|
package activity
|
||||||
|
|
||||||
|
import "net"
|
||||||
|
|
||||||
|
var (
|
||||||
|
listenIP = net.IP{127, 0, 0, 1}
|
||||||
|
)
|
10
client/internal/lazyconn/activity/listen_ip_linux.go
Normal file
10
client/internal/lazyconn/activity/listen_ip_linux.go
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package activity
|
||||||
|
|
||||||
|
import "net"
|
||||||
|
|
||||||
|
var (
|
||||||
|
// use this ip to avoid eBPF proxy congestion
|
||||||
|
listenIP = net.IP{127, 0, 1, 1}
|
||||||
|
)
|
106
client/internal/lazyconn/activity/listener.go
Normal file
106
client/internal/lazyconn/activity/listener.go
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
package activity
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Listener it is not a thread safe implementation, do not call Close before ReadPackets. It will cause blocking
|
||||||
|
type Listener struct {
|
||||||
|
wgIface lazyconn.WGIface
|
||||||
|
peerCfg lazyconn.PeerConfig
|
||||||
|
conn *net.UDPConn
|
||||||
|
endpoint *net.UDPAddr
|
||||||
|
done sync.Mutex
|
||||||
|
|
||||||
|
isClosed atomic.Bool // use to avoid error log when closing the listener
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewListener(wgIface lazyconn.WGIface, cfg lazyconn.PeerConfig) (*Listener, error) {
|
||||||
|
d := &Listener{
|
||||||
|
wgIface: wgIface,
|
||||||
|
peerCfg: cfg,
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := d.newConn()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to creating activity listener: %v", err)
|
||||||
|
}
|
||||||
|
d.conn = conn
|
||||||
|
d.endpoint = conn.LocalAddr().(*net.UDPAddr)
|
||||||
|
|
||||||
|
if err := d.createEndpoint(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
d.done.Lock()
|
||||||
|
cfg.Log.Infof("created activity listener: %s", conn.LocalAddr().(*net.UDPAddr).String())
|
||||||
|
return d, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Listener) ReadPackets() {
|
||||||
|
for {
|
||||||
|
n, remoteAddr, err := d.conn.ReadFromUDP(make([]byte, 1))
|
||||||
|
if err != nil {
|
||||||
|
if d.isClosed.Load() {
|
||||||
|
d.peerCfg.Log.Debugf("exit from activity listener")
|
||||||
|
} else {
|
||||||
|
d.peerCfg.Log.Errorf("failed to read from activity listener: %s", err)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if n < 1 {
|
||||||
|
d.peerCfg.Log.Warnf("received %d bytes from %s, too short", n, remoteAddr)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := d.removeEndpoint(); err != nil {
|
||||||
|
d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = d.conn.Close() // do not care err because some cases it will return "use of closed network connection"
|
||||||
|
d.done.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Listener) Close() {
|
||||||
|
d.peerCfg.Log.Infof("closing listener: %s", d.conn.LocalAddr().String())
|
||||||
|
d.isClosed.Store(true)
|
||||||
|
|
||||||
|
if err := d.conn.Close(); err != nil {
|
||||||
|
d.peerCfg.Log.Errorf("failed to close UDP listener: %s", err)
|
||||||
|
}
|
||||||
|
d.done.Lock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Listener) removeEndpoint() error {
|
||||||
|
d.peerCfg.Log.Debugf("removing lazy endpoint: %s", d.endpoint.String())
|
||||||
|
return d.wgIface.RemovePeer(d.peerCfg.PublicKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Listener) createEndpoint() error {
|
||||||
|
d.peerCfg.Log.Debugf("creating lazy endpoint: %s", d.endpoint.String())
|
||||||
|
return d.wgIface.UpdatePeer(d.peerCfg.PublicKey, d.peerCfg.AllowedIPs, 0, d.endpoint, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Listener) newConn() (*net.UDPConn, error) {
|
||||||
|
addr := &net.UDPAddr{
|
||||||
|
Port: 0,
|
||||||
|
IP: listenIP,
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := net.ListenUDP("udp", addr)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to create activity listener on %s: %s", addr, err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn, nil
|
||||||
|
}
|
41
client/internal/lazyconn/activity/listener_test.go
Normal file
41
client/internal/lazyconn/activity/listener_test.go
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
package activity
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewListener(t *testing.T) {
|
||||||
|
peer := &MocPeer{
|
||||||
|
PeerID: "examplePublicKey1",
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := lazyconn.PeerConfig{
|
||||||
|
PublicKey: peer.PeerID,
|
||||||
|
PeerConnID: peer.ConnID(),
|
||||||
|
Log: log.WithField("peer", "examplePublicKey1"),
|
||||||
|
}
|
||||||
|
|
||||||
|
l, err := NewListener(MocWGIface{}, cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create listener: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
chanClosed := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(chanClosed)
|
||||||
|
l.ReadPackets()
|
||||||
|
}()
|
||||||
|
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
l.Close()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-chanClosed:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
}
|
||||||
|
}
|
95
client/internal/lazyconn/activity/manager.go
Normal file
95
client/internal/lazyconn/activity/manager.go
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
package activity
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||||
|
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Manager struct {
|
||||||
|
OnActivityChan chan peerid.ConnID
|
||||||
|
|
||||||
|
wgIface lazyconn.WGIface
|
||||||
|
|
||||||
|
peers map[peerid.ConnID]*Listener
|
||||||
|
done chan struct{}
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewManager(wgIface lazyconn.WGIface) *Manager {
|
||||||
|
m := &Manager{
|
||||||
|
OnActivityChan: make(chan peerid.ConnID, 1),
|
||||||
|
wgIface: wgIface,
|
||||||
|
peers: make(map[peerid.ConnID]*Listener),
|
||||||
|
done: make(chan struct{}),
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) MonitorPeerActivity(peerCfg lazyconn.PeerConfig) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if _, ok := m.peers[peerCfg.PeerConnID]; ok {
|
||||||
|
log.Warnf("activity listener already exists for: %s", peerCfg.PublicKey)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
listener, err := NewListener(m.wgIface, peerCfg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
m.peers[peerCfg.PeerConnID] = listener
|
||||||
|
|
||||||
|
go m.waitForTraffic(listener, peerCfg.PeerConnID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) RemovePeer(log *log.Entry, peerConnID peerid.ConnID) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
listener, ok := m.peers[peerConnID]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Debugf("removing activity listener")
|
||||||
|
delete(m.peers, peerConnID)
|
||||||
|
listener.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Close() {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
close(m.done)
|
||||||
|
for peerID, listener := range m.peers {
|
||||||
|
delete(m.peers, peerID)
|
||||||
|
listener.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) waitForTraffic(listener *Listener, peerConnID peerid.ConnID) {
|
||||||
|
listener.ReadPackets()
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
if _, ok := m.peers[peerConnID]; !ok {
|
||||||
|
m.mu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
delete(m.peers, peerConnID)
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
m.notify(peerConnID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) notify(peerConnID peerid.ConnID) {
|
||||||
|
select {
|
||||||
|
case <-m.done:
|
||||||
|
case m.OnActivityChan <- peerConnID:
|
||||||
|
}
|
||||||
|
}
|
162
client/internal/lazyconn/activity/manager_test.go
Normal file
162
client/internal/lazyconn/activity/manager_test.go
Normal file
@ -0,0 +1,162 @@
|
|||||||
|
package activity
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||||
|
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MocPeer struct {
|
||||||
|
PeerID string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MocPeer) ConnID() peerid.ConnID {
|
||||||
|
return peerid.ConnID(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
type MocWGIface struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m MocWGIface) RemovePeer(string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m MocWGIface) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error {
|
||||||
|
return nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_MonitorPeerActivity(t *testing.T) {
|
||||||
|
mocWgInterface := &MocWGIface{}
|
||||||
|
|
||||||
|
peer1 := &MocPeer{
|
||||||
|
PeerID: "examplePublicKey1",
|
||||||
|
}
|
||||||
|
mgr := NewManager(mocWgInterface)
|
||||||
|
defer mgr.Close()
|
||||||
|
peerCfg1 := lazyconn.PeerConfig{
|
||||||
|
PublicKey: peer1.PeerID,
|
||||||
|
PeerConnID: peer1.ConnID(),
|
||||||
|
Log: log.WithField("peer", "examplePublicKey1"),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mgr.MonitorPeerActivity(peerCfg1); err != nil {
|
||||||
|
t.Fatalf("failed to monitor peer activity: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := trigger(mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String()); err != nil {
|
||||||
|
t.Fatalf("failed to trigger activity: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case peerConnID := <-mgr.OnActivityChan:
|
||||||
|
if peerConnID != peerCfg1.PeerConnID {
|
||||||
|
t.Fatalf("unexpected peerConnID: %v", peerConnID)
|
||||||
|
}
|
||||||
|
case <-time.After(1 * time.Second):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_RemovePeerActivity(t *testing.T) {
|
||||||
|
mocWgInterface := &MocWGIface{}
|
||||||
|
|
||||||
|
peer1 := &MocPeer{
|
||||||
|
PeerID: "examplePublicKey1",
|
||||||
|
}
|
||||||
|
mgr := NewManager(mocWgInterface)
|
||||||
|
defer mgr.Close()
|
||||||
|
|
||||||
|
peerCfg1 := lazyconn.PeerConfig{
|
||||||
|
PublicKey: peer1.PeerID,
|
||||||
|
PeerConnID: peer1.ConnID(),
|
||||||
|
Log: log.WithField("peer", "examplePublicKey1"),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mgr.MonitorPeerActivity(peerCfg1); err != nil {
|
||||||
|
t.Fatalf("failed to monitor peer activity: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
addr := mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String()
|
||||||
|
|
||||||
|
mgr.RemovePeer(peerCfg1.Log, peerCfg1.PeerConnID)
|
||||||
|
|
||||||
|
if err := trigger(addr); err != nil {
|
||||||
|
t.Fatalf("failed to trigger activity: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-mgr.OnActivityChan:
|
||||||
|
t.Fatal("should not have active activity")
|
||||||
|
case <-time.After(1 * time.Second):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_MultiPeerActivity(t *testing.T) {
|
||||||
|
mocWgInterface := &MocWGIface{}
|
||||||
|
|
||||||
|
peer1 := &MocPeer{
|
||||||
|
PeerID: "examplePublicKey1",
|
||||||
|
}
|
||||||
|
mgr := NewManager(mocWgInterface)
|
||||||
|
defer mgr.Close()
|
||||||
|
|
||||||
|
peerCfg1 := lazyconn.PeerConfig{
|
||||||
|
PublicKey: peer1.PeerID,
|
||||||
|
PeerConnID: peer1.ConnID(),
|
||||||
|
Log: log.WithField("peer", "examplePublicKey1"),
|
||||||
|
}
|
||||||
|
|
||||||
|
peer2 := &MocPeer{}
|
||||||
|
peerCfg2 := lazyconn.PeerConfig{
|
||||||
|
PublicKey: peer2.PeerID,
|
||||||
|
PeerConnID: peer2.ConnID(),
|
||||||
|
Log: log.WithField("peer", "examplePublicKey2"),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mgr.MonitorPeerActivity(peerCfg1); err != nil {
|
||||||
|
t.Fatalf("failed to monitor peer activity: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mgr.MonitorPeerActivity(peerCfg2); err != nil {
|
||||||
|
t.Fatalf("failed to monitor peer activity: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := trigger(mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String()); err != nil {
|
||||||
|
t.Fatalf("failed to trigger activity: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := trigger(mgr.peers[peerCfg2.PeerConnID].conn.LocalAddr().String()); err != nil {
|
||||||
|
t.Fatalf("failed to trigger activity: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
select {
|
||||||
|
case <-mgr.OnActivityChan:
|
||||||
|
case <-time.After(1 * time.Second):
|
||||||
|
t.Fatal("timed out waiting for activity")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func trigger(addr string) error {
|
||||||
|
// Create a connection to the destination UDP address and port
|
||||||
|
conn, err := net.Dial("udp", addr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
// Write the bytes to the UDP connection
|
||||||
|
_, err = conn.Write([]byte{0x01, 0x02, 0x03, 0x04, 0x05})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
32
client/internal/lazyconn/doc.go
Normal file
32
client/internal/lazyconn/doc.go
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
/*
|
||||||
|
Package lazyconn provides mechanisms for managing lazy connections, which activate on demand to optimize resource usage and establish connections efficiently.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The package includes a `Manager` component responsible for:
|
||||||
|
- Managing lazy connections activated on-demand
|
||||||
|
- Managing inactivity monitors for lazy connections (based on peer disconnection events)
|
||||||
|
- Maintaining a list of excluded peers that should always have permanent connections
|
||||||
|
- Handling remote peer connection initiatives based on peer signaling
|
||||||
|
|
||||||
|
## Thread-Safe Operations
|
||||||
|
|
||||||
|
The `Manager` ensures thread safety across multiple operations, categorized by caller:
|
||||||
|
|
||||||
|
- **Engine (single goroutine)**:
|
||||||
|
- `AddPeer`: Adds a peer to the connection manager.
|
||||||
|
- `RemovePeer`: Removes a peer from the connection manager.
|
||||||
|
- `ActivatePeer`: Activates a lazy connection for a peer. This come from Signal client
|
||||||
|
- `ExcludePeer`: Marks peers for a permanent connection. Like router peers and other peers that should always have a connection.
|
||||||
|
|
||||||
|
- **Connection Dispatcher (any peer routine)**:
|
||||||
|
- `onPeerConnected`: Suspend the inactivity monitor for an active peer connection.
|
||||||
|
- `onPeerDisconnected`: Starts the inactivity monitor for a disconnected peer.
|
||||||
|
|
||||||
|
- **Activity Manager**:
|
||||||
|
- `onPeerActivity`: Run peer.Open(context).
|
||||||
|
|
||||||
|
- **Inactivity Monitor**:
|
||||||
|
- `onPeerInactivityTimedOut`: Close peer connection and restart activity monitor.
|
||||||
|
*/
|
||||||
|
package lazyconn
|
26
client/internal/lazyconn/env.go
Normal file
26
client/internal/lazyconn/env.go
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
package lazyconn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
EnvEnableLazyConn = "NB_ENABLE_EXPERIMENTAL_LAZY_CONN"
|
||||||
|
EnvInactivityThreshold = "NB_LAZY_CONN_INACTIVITY_THRESHOLD"
|
||||||
|
)
|
||||||
|
|
||||||
|
func IsLazyConnEnabledByEnv() bool {
|
||||||
|
val := os.Getenv(EnvEnableLazyConn)
|
||||||
|
if val == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
enabled, err := strconv.ParseBool(val)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to parse %s: %v", EnvEnableLazyConn, err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return enabled
|
||||||
|
}
|
70
client/internal/lazyconn/inactivity/inactivity.go
Normal file
70
client/internal/lazyconn/inactivity/inactivity.go
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
package inactivity
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
peer "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
DefaultInactivityThreshold = 60 * time.Minute // idle after 1 hour inactivity
|
||||||
|
MinimumInactivityThreshold = 3 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
|
type Monitor struct {
|
||||||
|
id peer.ConnID
|
||||||
|
timer *time.Timer
|
||||||
|
cancel context.CancelFunc
|
||||||
|
inactivityThreshold time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewInactivityMonitor(peerID peer.ConnID, threshold time.Duration) *Monitor {
|
||||||
|
i := &Monitor{
|
||||||
|
id: peerID,
|
||||||
|
timer: time.NewTimer(0),
|
||||||
|
inactivityThreshold: threshold,
|
||||||
|
}
|
||||||
|
i.timer.Stop()
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *Monitor) Start(ctx context.Context, timeoutChan chan peer.ConnID) {
|
||||||
|
i.timer.Reset(i.inactivityThreshold)
|
||||||
|
defer i.timer.Stop()
|
||||||
|
|
||||||
|
ctx, i.cancel = context.WithCancel(ctx)
|
||||||
|
defer func() {
|
||||||
|
defer i.cancel()
|
||||||
|
select {
|
||||||
|
case <-i.timer.C:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-i.timer.C:
|
||||||
|
select {
|
||||||
|
case timeoutChan <- i.id:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *Monitor) Stop() {
|
||||||
|
if i.cancel == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
i.cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *Monitor) PauseTimer() {
|
||||||
|
i.timer.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *Monitor) ResetTimer() {
|
||||||
|
i.timer.Reset(i.inactivityThreshold)
|
||||||
|
}
|
156
client/internal/lazyconn/inactivity/inactivity_test.go
Normal file
156
client/internal/lazyconn/inactivity/inactivity_test.go
Normal file
@ -0,0 +1,156 @@
|
|||||||
|
package inactivity
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MocPeer struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MocPeer) ConnID() peerid.ConnID {
|
||||||
|
return peerid.ConnID(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInactivityMonitor(t *testing.T) {
|
||||||
|
tCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||||
|
defer testTimeoutCancel()
|
||||||
|
|
||||||
|
p := &MocPeer{}
|
||||||
|
im := NewInactivityMonitor(p.ConnID(), time.Second*2)
|
||||||
|
|
||||||
|
timeoutChan := make(chan peerid.ConnID)
|
||||||
|
|
||||||
|
exitChan := make(chan struct{})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer close(exitChan)
|
||||||
|
im.Start(tCtx, timeoutChan)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-timeoutChan:
|
||||||
|
case <-tCtx.Done():
|
||||||
|
t.Fatal("timeout")
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-exitChan:
|
||||||
|
case <-tCtx.Done():
|
||||||
|
t.Fatal("timeout")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReuseInactivityMonitor(t *testing.T) {
|
||||||
|
p := &MocPeer{}
|
||||||
|
im := NewInactivityMonitor(p.ConnID(), time.Second*2)
|
||||||
|
|
||||||
|
timeoutChan := make(chan peerid.ConnID)
|
||||||
|
|
||||||
|
for i := 2; i > 0; i-- {
|
||||||
|
exitChan := make(chan struct{})
|
||||||
|
|
||||||
|
testTimeoutCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer close(exitChan)
|
||||||
|
im.Start(testTimeoutCtx, timeoutChan)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-timeoutChan:
|
||||||
|
case <-testTimeoutCtx.Done():
|
||||||
|
t.Fatal("timeout")
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-exitChan:
|
||||||
|
case <-testTimeoutCtx.Done():
|
||||||
|
t.Fatal("timeout")
|
||||||
|
}
|
||||||
|
testTimeoutCancel()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStopInactivityMonitor(t *testing.T) {
|
||||||
|
tCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||||
|
defer testTimeoutCancel()
|
||||||
|
|
||||||
|
p := &MocPeer{}
|
||||||
|
im := NewInactivityMonitor(p.ConnID(), DefaultInactivityThreshold)
|
||||||
|
|
||||||
|
timeoutChan := make(chan peerid.ConnID)
|
||||||
|
|
||||||
|
exitChan := make(chan struct{})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer close(exitChan)
|
||||||
|
im.Start(tCtx, timeoutChan)
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
time.Sleep(3 * time.Second)
|
||||||
|
im.Stop()
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-timeoutChan:
|
||||||
|
t.Fatal("unexpected timeout")
|
||||||
|
case <-exitChan:
|
||||||
|
case <-tCtx.Done():
|
||||||
|
t.Fatal("timeout")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPauseInactivityMonitor(t *testing.T) {
|
||||||
|
tCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*10)
|
||||||
|
defer testTimeoutCancel()
|
||||||
|
|
||||||
|
p := &MocPeer{}
|
||||||
|
trashHold := time.Second * 3
|
||||||
|
im := NewInactivityMonitor(p.ConnID(), trashHold)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
timeoutChan := make(chan peerid.ConnID)
|
||||||
|
|
||||||
|
exitChan := make(chan struct{})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer close(exitChan)
|
||||||
|
im.Start(ctx, timeoutChan)
|
||||||
|
}()
|
||||||
|
|
||||||
|
time.Sleep(1 * time.Second) // grant time to start the monitor
|
||||||
|
im.PauseTimer()
|
||||||
|
|
||||||
|
// check to do not receive timeout
|
||||||
|
thresholdCtx, thresholdCancel := context.WithTimeout(context.Background(), trashHold+time.Second)
|
||||||
|
defer thresholdCancel()
|
||||||
|
select {
|
||||||
|
case <-exitChan:
|
||||||
|
t.Fatal("unexpected exit")
|
||||||
|
case <-timeoutChan:
|
||||||
|
t.Fatal("unexpected timeout")
|
||||||
|
case <-thresholdCtx.Done():
|
||||||
|
// test ok
|
||||||
|
case <-tCtx.Done():
|
||||||
|
t.Fatal("test timed out")
|
||||||
|
}
|
||||||
|
|
||||||
|
// test reset timer
|
||||||
|
im.ResetTimer()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-tCtx.Done():
|
||||||
|
t.Fatal("test timed out")
|
||||||
|
case <-exitChan:
|
||||||
|
t.Fatal("unexpected exit")
|
||||||
|
case <-timeoutChan:
|
||||||
|
// expected timeout
|
||||||
|
}
|
||||||
|
}
|
404
client/internal/lazyconn/manager/manager.go
Normal file
404
client/internal/lazyconn/manager/manager.go
Normal file
@ -0,0 +1,404 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/lazyconn/activity"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/lazyconn/inactivity"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
|
||||||
|
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
watcherActivity watcherType = iota
|
||||||
|
watcherInactivity
|
||||||
|
)
|
||||||
|
|
||||||
|
type watcherType int
|
||||||
|
|
||||||
|
type managedPeer struct {
|
||||||
|
peerCfg *lazyconn.PeerConfig
|
||||||
|
expectedWatcher watcherType
|
||||||
|
}
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
InactivityThreshold *time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// Manager manages lazy connections
|
||||||
|
// It is responsible for:
|
||||||
|
// - Managing lazy connections activated on-demand
|
||||||
|
// - Managing inactivity monitors for lazy connections (based on peer disconnection events)
|
||||||
|
// - Maintaining a list of excluded peers that should always have permanent connections
|
||||||
|
// - Handling connection establishment based on peer signaling
|
||||||
|
type Manager struct {
|
||||||
|
peerStore *peerstore.Store
|
||||||
|
connStateDispatcher *dispatcher.ConnectionDispatcher
|
||||||
|
inactivityThreshold time.Duration
|
||||||
|
|
||||||
|
connStateListener *dispatcher.ConnectionListener
|
||||||
|
managedPeers map[string]*lazyconn.PeerConfig
|
||||||
|
managedPeersByConnID map[peerid.ConnID]*managedPeer
|
||||||
|
excludes map[string]lazyconn.PeerConfig
|
||||||
|
managedPeersMu sync.Mutex
|
||||||
|
|
||||||
|
activityManager *activity.Manager
|
||||||
|
inactivityMonitors map[peerid.ConnID]*inactivity.Monitor
|
||||||
|
|
||||||
|
cancel context.CancelFunc
|
||||||
|
onInactive chan peerid.ConnID
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewManager(config Config, peerStore *peerstore.Store, wgIface lazyconn.WGIface, connStateDispatcher *dispatcher.ConnectionDispatcher) *Manager {
|
||||||
|
log.Infof("setup lazy connection service")
|
||||||
|
m := &Manager{
|
||||||
|
peerStore: peerStore,
|
||||||
|
connStateDispatcher: connStateDispatcher,
|
||||||
|
inactivityThreshold: inactivity.DefaultInactivityThreshold,
|
||||||
|
managedPeers: make(map[string]*lazyconn.PeerConfig),
|
||||||
|
managedPeersByConnID: make(map[peerid.ConnID]*managedPeer),
|
||||||
|
excludes: make(map[string]lazyconn.PeerConfig),
|
||||||
|
activityManager: activity.NewManager(wgIface),
|
||||||
|
inactivityMonitors: make(map[peerid.ConnID]*inactivity.Monitor),
|
||||||
|
onInactive: make(chan peerid.ConnID),
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.InactivityThreshold != nil {
|
||||||
|
if *config.InactivityThreshold >= inactivity.MinimumInactivityThreshold {
|
||||||
|
m.inactivityThreshold = *config.InactivityThreshold
|
||||||
|
} else {
|
||||||
|
log.Warnf("inactivity threshold is too low, using %v", m.inactivityThreshold)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.connStateListener = &dispatcher.ConnectionListener{
|
||||||
|
OnConnected: m.onPeerConnected,
|
||||||
|
OnDisconnected: m.onPeerDisconnected,
|
||||||
|
}
|
||||||
|
|
||||||
|
connStateDispatcher.AddListener(m.connStateListener)
|
||||||
|
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start starts the manager and listens for peer activity and inactivity events
|
||||||
|
func (m *Manager) Start(ctx context.Context) {
|
||||||
|
defer m.close()
|
||||||
|
|
||||||
|
ctx, m.cancel = context.WithCancel(ctx)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case peerConnID := <-m.activityManager.OnActivityChan:
|
||||||
|
m.onPeerActivity(ctx, peerConnID)
|
||||||
|
case peerConnID := <-m.onInactive:
|
||||||
|
m.onPeerInactivityTimedOut(peerConnID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExcludePeer marks peers for a permanent connection
|
||||||
|
// It removes peers from the managed list if they are added to the exclude list
|
||||||
|
// Adds them back to the managed list and start the inactivity listener if they are removed from the exclude list. In
|
||||||
|
// this case, we suppose that the connection status is connected or connecting.
|
||||||
|
// If the peer is not exists yet in the managed list then the responsibility is the upper layer to call the AddPeer function
|
||||||
|
func (m *Manager) ExcludePeer(ctx context.Context, peerConfigs []lazyconn.PeerConfig) []string {
|
||||||
|
m.managedPeersMu.Lock()
|
||||||
|
defer m.managedPeersMu.Unlock()
|
||||||
|
|
||||||
|
added := make([]string, 0)
|
||||||
|
excludes := make(map[string]lazyconn.PeerConfig, len(peerConfigs))
|
||||||
|
|
||||||
|
for _, peerCfg := range peerConfigs {
|
||||||
|
log.Infof("update excluded lazy connection list with peer: %s", peerCfg.PublicKey)
|
||||||
|
excludes[peerCfg.PublicKey] = peerCfg
|
||||||
|
}
|
||||||
|
|
||||||
|
// if a peer is newly added to the exclude list, remove from the managed peers list
|
||||||
|
for pubKey, peerCfg := range excludes {
|
||||||
|
if _, wasExcluded := m.excludes[pubKey]; wasExcluded {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
added = append(added, pubKey)
|
||||||
|
peerCfg.Log.Infof("peer newly added to lazy connection exclude list")
|
||||||
|
m.removePeer(pubKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// if a peer has been removed from exclude list then it should be added to the managed peers
|
||||||
|
for pubKey, peerCfg := range m.excludes {
|
||||||
|
if _, stillExcluded := excludes[pubKey]; stillExcluded {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
peerCfg.Log.Infof("peer removed from lazy connection exclude list")
|
||||||
|
|
||||||
|
if err := m.addActivePeer(ctx, peerCfg); err != nil {
|
||||||
|
log.Errorf("failed to add peer to lazy connection manager: %s", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.excludes = excludes
|
||||||
|
return added
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) AddPeer(peerCfg lazyconn.PeerConfig) (bool, error) {
|
||||||
|
m.managedPeersMu.Lock()
|
||||||
|
defer m.managedPeersMu.Unlock()
|
||||||
|
|
||||||
|
peerCfg.Log.Debugf("adding peer to lazy connection manager")
|
||||||
|
|
||||||
|
_, exists := m.excludes[peerCfg.PublicKey]
|
||||||
|
if exists {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := m.managedPeers[peerCfg.PublicKey]; ok {
|
||||||
|
peerCfg.Log.Warnf("peer already managed")
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.activityManager.MonitorPeerActivity(peerCfg); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
im := inactivity.NewInactivityMonitor(peerCfg.PeerConnID, m.inactivityThreshold)
|
||||||
|
m.inactivityMonitors[peerCfg.PeerConnID] = im
|
||||||
|
|
||||||
|
m.managedPeers[peerCfg.PublicKey] = &peerCfg
|
||||||
|
m.managedPeersByConnID[peerCfg.PeerConnID] = &managedPeer{
|
||||||
|
peerCfg: &peerCfg,
|
||||||
|
expectedWatcher: watcherActivity,
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddActivePeers adds a list of peers to the lazy connection manager
|
||||||
|
// suppose these peers was in connected or in connecting states
|
||||||
|
func (m *Manager) AddActivePeers(ctx context.Context, peerCfg []lazyconn.PeerConfig) error {
|
||||||
|
m.managedPeersMu.Lock()
|
||||||
|
defer m.managedPeersMu.Unlock()
|
||||||
|
|
||||||
|
for _, cfg := range peerCfg {
|
||||||
|
if _, ok := m.managedPeers[cfg.PublicKey]; ok {
|
||||||
|
cfg.Log.Errorf("peer already managed")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.addActivePeer(ctx, cfg); err != nil {
|
||||||
|
cfg.Log.Errorf("failed to add peer to lazy connection manager: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) RemovePeer(peerID string) {
|
||||||
|
m.managedPeersMu.Lock()
|
||||||
|
defer m.managedPeersMu.Unlock()
|
||||||
|
|
||||||
|
m.removePeer(peerID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ActivatePeer activates a peer connection when a signal message is received
|
||||||
|
func (m *Manager) ActivatePeer(ctx context.Context, peerID string) (found bool) {
|
||||||
|
m.managedPeersMu.Lock()
|
||||||
|
defer m.managedPeersMu.Unlock()
|
||||||
|
|
||||||
|
cfg, ok := m.managedPeers[peerID]
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
mp, ok := m.managedPeersByConnID[cfg.PeerConnID]
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// signal messages coming continuously after success activation, with this avoid the multiple activation
|
||||||
|
if mp.expectedWatcher == watcherInactivity {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
mp.expectedWatcher = watcherInactivity
|
||||||
|
|
||||||
|
m.activityManager.RemovePeer(cfg.Log, cfg.PeerConnID)
|
||||||
|
|
||||||
|
im, ok := m.inactivityMonitors[cfg.PeerConnID]
|
||||||
|
if !ok {
|
||||||
|
cfg.Log.Errorf("inactivity monitor not found for peer")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
mp.peerCfg.Log.Infof("starting inactivity monitor")
|
||||||
|
go im.Start(ctx, m.onInactive)
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) addActivePeer(ctx context.Context, peerCfg lazyconn.PeerConfig) error {
|
||||||
|
if _, ok := m.managedPeers[peerCfg.PublicKey]; ok {
|
||||||
|
peerCfg.Log.Warnf("peer already managed")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
im := inactivity.NewInactivityMonitor(peerCfg.PeerConnID, m.inactivityThreshold)
|
||||||
|
m.inactivityMonitors[peerCfg.PeerConnID] = im
|
||||||
|
|
||||||
|
m.managedPeers[peerCfg.PublicKey] = &peerCfg
|
||||||
|
m.managedPeersByConnID[peerCfg.PeerConnID] = &managedPeer{
|
||||||
|
peerCfg: &peerCfg,
|
||||||
|
expectedWatcher: watcherInactivity,
|
||||||
|
}
|
||||||
|
|
||||||
|
peerCfg.Log.Infof("starting inactivity monitor on peer that has been removed from exclude list")
|
||||||
|
go im.Start(ctx, m.onInactive)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) removePeer(peerID string) {
|
||||||
|
cfg, ok := m.managedPeers[peerID]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.Log.Infof("removing lazy peer")
|
||||||
|
|
||||||
|
if im, ok := m.inactivityMonitors[cfg.PeerConnID]; ok {
|
||||||
|
im.Stop()
|
||||||
|
delete(m.inactivityMonitors, cfg.PeerConnID)
|
||||||
|
cfg.Log.Debugf("inactivity monitor stopped")
|
||||||
|
}
|
||||||
|
|
||||||
|
m.activityManager.RemovePeer(cfg.Log, cfg.PeerConnID)
|
||||||
|
delete(m.managedPeers, peerID)
|
||||||
|
delete(m.managedPeersByConnID, cfg.PeerConnID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) close() {
|
||||||
|
m.managedPeersMu.Lock()
|
||||||
|
defer m.managedPeersMu.Unlock()
|
||||||
|
|
||||||
|
m.cancel()
|
||||||
|
|
||||||
|
m.connStateDispatcher.RemoveListener(m.connStateListener)
|
||||||
|
m.activityManager.Close()
|
||||||
|
for _, iw := range m.inactivityMonitors {
|
||||||
|
iw.Stop()
|
||||||
|
}
|
||||||
|
m.inactivityMonitors = make(map[peerid.ConnID]*inactivity.Monitor)
|
||||||
|
m.managedPeers = make(map[string]*lazyconn.PeerConfig)
|
||||||
|
m.managedPeersByConnID = make(map[peerid.ConnID]*managedPeer)
|
||||||
|
log.Infof("lazy connection manager closed")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) onPeerActivity(ctx context.Context, peerConnID peerid.ConnID) {
|
||||||
|
m.managedPeersMu.Lock()
|
||||||
|
defer m.managedPeersMu.Unlock()
|
||||||
|
|
||||||
|
mp, ok := m.managedPeersByConnID[peerConnID]
|
||||||
|
if !ok {
|
||||||
|
log.Errorf("peer not found by conn id: %v", peerConnID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if mp.expectedWatcher != watcherActivity {
|
||||||
|
mp.peerCfg.Log.Warnf("ignore activity event")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
mp.peerCfg.Log.Infof("detected peer activity")
|
||||||
|
|
||||||
|
mp.expectedWatcher = watcherInactivity
|
||||||
|
|
||||||
|
mp.peerCfg.Log.Infof("starting inactivity monitor")
|
||||||
|
go m.inactivityMonitors[peerConnID].Start(ctx, m.onInactive)
|
||||||
|
|
||||||
|
m.peerStore.PeerConnOpen(ctx, mp.peerCfg.PublicKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) onPeerInactivityTimedOut(peerConnID peerid.ConnID) {
|
||||||
|
m.managedPeersMu.Lock()
|
||||||
|
defer m.managedPeersMu.Unlock()
|
||||||
|
|
||||||
|
mp, ok := m.managedPeersByConnID[peerConnID]
|
||||||
|
if !ok {
|
||||||
|
log.Errorf("peer not found by id: %v", peerConnID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if mp.expectedWatcher != watcherInactivity {
|
||||||
|
mp.peerCfg.Log.Warnf("ignore inactivity event")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
mp.peerCfg.Log.Infof("connection timed out")
|
||||||
|
|
||||||
|
// this is blocking operation, potentially can be optimized
|
||||||
|
m.peerStore.PeerConnClose(mp.peerCfg.PublicKey)
|
||||||
|
|
||||||
|
mp.peerCfg.Log.Infof("start activity monitor")
|
||||||
|
|
||||||
|
mp.expectedWatcher = watcherActivity
|
||||||
|
|
||||||
|
// just in case free up
|
||||||
|
m.inactivityMonitors[peerConnID].PauseTimer()
|
||||||
|
|
||||||
|
if err := m.activityManager.MonitorPeerActivity(*mp.peerCfg); err != nil {
|
||||||
|
mp.peerCfg.Log.Errorf("failed to create activity monitor: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) onPeerConnected(peerConnID peerid.ConnID) {
|
||||||
|
m.managedPeersMu.Lock()
|
||||||
|
defer m.managedPeersMu.Unlock()
|
||||||
|
|
||||||
|
mp, ok := m.managedPeersByConnID[peerConnID]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if mp.expectedWatcher != watcherInactivity {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
iw, ok := m.inactivityMonitors[mp.peerCfg.PeerConnID]
|
||||||
|
if !ok {
|
||||||
|
mp.peerCfg.Log.Errorf("inactivity monitor not found for peer")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
mp.peerCfg.Log.Infof("peer connected, pausing inactivity monitor while connection is not disconnected")
|
||||||
|
iw.PauseTimer()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) onPeerDisconnected(peerConnID peerid.ConnID) {
|
||||||
|
m.managedPeersMu.Lock()
|
||||||
|
defer m.managedPeersMu.Unlock()
|
||||||
|
|
||||||
|
mp, ok := m.managedPeersByConnID[peerConnID]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if mp.expectedWatcher != watcherInactivity {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
iw, ok := m.inactivityMonitors[mp.peerCfg.PeerConnID]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
mp.peerCfg.Log.Infof("reset inactivity monitor timer")
|
||||||
|
iw.ResetTimer()
|
||||||
|
}
|
16
client/internal/lazyconn/peercfg.go
Normal file
16
client/internal/lazyconn/peercfg.go
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
package lazyconn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer/id"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PeerConfig struct {
|
||||||
|
PublicKey string
|
||||||
|
AllowedIPs []netip.Prefix
|
||||||
|
PeerConnID id.ConnID
|
||||||
|
Log *log.Entry
|
||||||
|
}
|
41
client/internal/lazyconn/support.go
Normal file
41
client/internal/lazyconn/support.go
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
package lazyconn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-version"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
minVersion = version.Must(version.NewVersion("0.45.0"))
|
||||||
|
)
|
||||||
|
|
||||||
|
func IsSupported(agentVersion string) bool {
|
||||||
|
if agentVersion == "development" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// filter out versions like this: a6c5960, a7d5c522, d47be154
|
||||||
|
if !strings.Contains(agentVersion, ".") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
normalizedVersion := normalizeVersion(agentVersion)
|
||||||
|
inputVer, err := version.NewVersion(normalizedVersion)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return inputVer.GreaterThanOrEqual(minVersion)
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeVersion(version string) string {
|
||||||
|
// Remove prefixes like 'v' or 'a'
|
||||||
|
if len(version) > 0 && (version[0] == 'v' || version[0] == 'a') {
|
||||||
|
version = version[1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove any suffixes like '-dirty', '-dev', '-SNAPSHOT', etc.
|
||||||
|
parts := strings.Split(version, "-")
|
||||||
|
return parts[0]
|
||||||
|
}
|
31
client/internal/lazyconn/support_test.go
Normal file
31
client/internal/lazyconn/support_test.go
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
package lazyconn
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestIsSupported(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
version string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"development", true},
|
||||||
|
{"0.45.0", true},
|
||||||
|
{"v0.45.0", true},
|
||||||
|
{"0.45.1", true},
|
||||||
|
{"0.45.1-SNAPSHOT-559e6731", true},
|
||||||
|
{"v0.45.1-dev", true},
|
||||||
|
{"a7d5c522", false},
|
||||||
|
{"0.9.6", false},
|
||||||
|
{"0.9.6-SNAPSHOT", false},
|
||||||
|
{"0.9.6-SNAPSHOT-2033650", false},
|
||||||
|
{"meta_wt_version", false},
|
||||||
|
{"v0.31.1-dev", false},
|
||||||
|
{"", false},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.version, func(t *testing.T) {
|
||||||
|
if got := IsSupported(tt.version); got != tt.want {
|
||||||
|
t.Errorf("IsSupported() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
14
client/internal/lazyconn/wgiface.go
Normal file
14
client/internal/lazyconn/wgiface.go
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
package lazyconn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
)
|
||||||
|
|
||||||
|
type WGIface interface {
|
||||||
|
RemovePeer(peerKey string) error
|
||||||
|
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||||
|
}
|
@ -19,7 +19,7 @@ import (
|
|||||||
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
||||||
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
|
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to open routing socket: %v", err)
|
return fmt.Errorf("open routing socket: %v", err)
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
err := unix.Close(fd)
|
err := unix.Close(fd)
|
||||||
|
@ -13,7 +13,7 @@ import (
|
|||||||
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
||||||
routeMonitor, err := systemops.NewRouteMonitor(ctx)
|
routeMonitor, err := systemops.NewRouteMonitor(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create route monitor: %w", err)
|
return fmt.Errorf("create route monitor: %w", err)
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := routeMonitor.Stop(); err != nil {
|
if err := routeMonitor.Stop(); err != nil {
|
||||||
@ -38,35 +38,49 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) er
|
|||||||
}
|
}
|
||||||
|
|
||||||
func routeChanged(route systemops.RouteUpdate, nexthopv4, nexthopv6 systemops.Nexthop) bool {
|
func routeChanged(route systemops.RouteUpdate, nexthopv4, nexthopv6 systemops.Nexthop) bool {
|
||||||
intf := "<nil>"
|
if intf := route.NextHop.Intf; intf != nil && isSoftInterface(intf.Name) {
|
||||||
if route.Interface != nil {
|
log.Debugf("Network monitor: ignoring default route change for next hop with soft interface %s", route.NextHop)
|
||||||
intf = route.Interface.Name
|
return false
|
||||||
if isSoftInterface(intf) {
|
}
|
||||||
log.Debugf("Network monitor: ignoring default route change for soft interface %s", intf)
|
|
||||||
return false
|
// TODO: for the empty nexthop ip (on-link), determine the family differently
|
||||||
}
|
nexthop := nexthopv4
|
||||||
|
if route.NextHop.IP.Is6() {
|
||||||
|
nexthop = nexthopv6
|
||||||
}
|
}
|
||||||
|
|
||||||
switch route.Type {
|
switch route.Type {
|
||||||
case systemops.RouteModified:
|
case systemops.RouteModified, systemops.RouteAdded:
|
||||||
// TODO: get routing table to figure out if our route is affected for modified routes
|
return handleRouteAddedOrModified(route, nexthop)
|
||||||
log.Infof("Network monitor: default route changed: via %s, interface %s", route.NextHop, intf)
|
|
||||||
return true
|
|
||||||
case systemops.RouteAdded:
|
|
||||||
if route.NextHop.Is4() && route.NextHop != nexthopv4.IP || route.NextHop.Is6() && route.NextHop != nexthopv6.IP {
|
|
||||||
log.Infof("Network monitor: default route added: via %s, interface %s", route.NextHop, intf)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
case systemops.RouteDeleted:
|
case systemops.RouteDeleted:
|
||||||
if nexthopv4.Intf != nil && route.NextHop == nexthopv4.IP || nexthopv6.Intf != nil && route.NextHop == nexthopv6.IP {
|
return handleRouteDeleted(route, nexthop)
|
||||||
log.Infof("Network monitor: default route removed: via %s, interface %s", route.NextHop, intf)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func handleRouteAddedOrModified(route systemops.RouteUpdate, nexthop systemops.Nexthop) bool {
|
||||||
|
// For added/modified routes, we care about different next hops
|
||||||
|
if !nexthop.Equal(route.NextHop) {
|
||||||
|
action := "changed"
|
||||||
|
if route.Type == systemops.RouteAdded {
|
||||||
|
action = "added"
|
||||||
|
}
|
||||||
|
log.Infof("Network monitor: default route %s: via %s", action, route.NextHop)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleRouteDeleted(route systemops.RouteUpdate, nexthop systemops.Nexthop) bool {
|
||||||
|
// For deleted routes, we care about our tracked next hop being deleted
|
||||||
|
if nexthop.Equal(route.NextHop) {
|
||||||
|
log.Infof("Network monitor: default route removed: via %s", route.NextHop)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func isSoftInterface(name string) bool {
|
func isSoftInterface(name string) bool {
|
||||||
return strings.Contains(strings.ToLower(name), "isatap") || strings.Contains(strings.ToLower(name), "teredo")
|
return strings.Contains(strings.ToLower(name), "isatap") || strings.Contains(strings.ToLower(name), "teredo")
|
||||||
}
|
}
|
||||||
|
404
client/internal/networkmonitor/check_change_windows_test.go
Normal file
404
client/internal/networkmonitor/check_change_windows_test.go
Normal file
@ -0,0 +1,404 @@
|
|||||||
|
package networkmonitor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRouteChanged(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
route systemops.RouteUpdate
|
||||||
|
nexthopv4 systemops.Nexthop
|
||||||
|
nexthopv6 systemops.Nexthop
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "soft interface should be ignored",
|
||||||
|
route: systemops.RouteUpdate{
|
||||||
|
Type: systemops.RouteModified,
|
||||||
|
Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
|
||||||
|
NextHop: systemops.Nexthop{
|
||||||
|
IP: netip.MustParseAddr("192.168.1.1"),
|
||||||
|
Intf: &net.Interface{
|
||||||
|
Name: "ISATAP-Interface", // isSoftInterface checks name
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
nexthopv4: systemops.Nexthop{
|
||||||
|
IP: netip.MustParseAddr("192.168.1.2"),
|
||||||
|
},
|
||||||
|
nexthopv6: systemops.Nexthop{
|
||||||
|
IP: netip.MustParseAddr("2001:db8::1"),
|
||||||
|
},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "modified route with different v4 nexthop IP should return true",
|
||||||
|
route: systemops.RouteUpdate{
|
||||||
|
Type: systemops.RouteModified,
|
||||||
|
Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
|
||||||
|
NextHop: systemops.Nexthop{
|
||||||
|
IP: netip.MustParseAddr("192.168.1.1"),
|
||||||
|
Intf: &net.Interface{
|
||||||
|
Index: 1, Name: "eth0",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
nexthopv4: systemops.Nexthop{
|
||||||
|
IP: netip.MustParseAddr("192.168.1.2"),
|
||||||
|
Intf: &net.Interface{
|
||||||
|
Index: 1, Name: "eth0",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
nexthopv6: systemops.Nexthop{
|
||||||
|
IP: netip.MustParseAddr("2001:db8::1"),
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "modified route with same v4 nexthop (IP and Intf Index) should return false",
|
||||||
|
route: systemops.RouteUpdate{
|
||||||
|
Type: systemops.RouteModified,
|
||||||
|
Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
|
||||||
|
NextHop: systemops.Nexthop{
|
||||||
|
IP: netip.MustParseAddr("192.168.1.1"),
|
||||||
|
Intf: &net.Interface{
|
||||||
|
Index: 1, Name: "eth0",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
nexthopv4: systemops.Nexthop{
|
||||||
|
IP: netip.MustParseAddr("192.168.1.1"),
|
||||||
|
Intf: &net.Interface{
|
||||||
|
Index: 1, Name: "eth0",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
nexthopv6: systemops.Nexthop{
|
||||||
|
IP: netip.MustParseAddr("2001:db8::1"),
|
||||||
|
},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "added route with different v6 nexthop IP should return true",
|
||||||
|
route: systemops.RouteUpdate{
|
||||||
|
Type: systemops.RouteAdded,
|
||||||
|
Destination: netip.PrefixFrom(netip.IPv6Unspecified(), 0),
|
||||||
|
NextHop: systemops.Nexthop{
|
||||||
|
IP: netip.MustParseAddr("2001:db8::2"),
|
||||||
|
Intf: &net.Interface{
|
||||||
|
Index: 1, Name: "eth0",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
nexthopv4: systemops.Nexthop{
|
||||||
|
IP: netip.MustParseAddr("192.168.1.1"),
|
||||||
|
},
|
||||||
|
nexthopv6: systemops.Nexthop{
|
||||||
|
IP: netip.MustParseAddr("2001:db8::1"),
|
||||||
|
Intf: &net.Interface{
|
||||||
|
Index: 1, Name: "eth0",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "added route with same v6 nexthop (IP and Intf Index) should return false",
|
||||||
|
route: systemops.RouteUpdate{
|
||||||
|
Type: systemops.RouteAdded,
|
||||||
|
Destination: netip.PrefixFrom(netip.IPv6Unspecified(), 0),
|
||||||
|
NextHop: systemops.Nexthop{
|
||||||
|
IP: netip.MustParseAddr("2001:db8::1"),
|
||||||
|
Intf: &net.Interface{
|
||||||
|
Index: 1, Name: "eth0",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
nexthopv4: systemops.Nexthop{
|
||||||
|
IP: netip.MustParseAddr("192.168.1.1"),
|
||||||
|
},
|
||||||
|
nexthopv6: systemops.Nexthop{
|
||||||
|
IP: netip.MustParseAddr("2001:db8::1"),
|
||||||
|
Intf: &net.Interface{
|
||||||
|
Index: 1, Name: "eth0",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "deleted route matching tracked v4 nexthop (IP and Intf Index) should return true",
|
||||||
|
route: systemops.RouteUpdate{
|
||||||
|
Type: systemops.RouteDeleted,
|
||||||
|
Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
|
||||||
|
NextHop: systemops.Nexthop{
|
||||||
|
IP: netip.MustParseAddr("192.168.1.1"),
|
||||||
|
Intf: &net.Interface{
|
||||||
|
Index: 1, Name: "eth0",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
nexthopv4: systemops.Nexthop{
|
||||||
|
IP: netip.MustParseAddr("192.168.1.1"),
|
||||||
|
Intf: &net.Interface{
|
||||||
|
Index: 1, Name: "eth0",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
nexthopv6: systemops.Nexthop{
|
||||||
|
IP: netip.MustParseAddr("2001:db8::1"),
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "deleted route not matching tracked v4 nexthop (different IP) should return false",
|
||||||
|
route: systemops.RouteUpdate{
|
||||||
|
Type: systemops.RouteDeleted,
|
||||||
|
Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
|
||||||
|
NextHop: systemops.Nexthop{
|
||||||
|
IP: netip.MustParseAddr("192.168.1.3"), // Different IP
|
||||||
|
Intf: &net.Interface{
|
||||||
|
Index: 1, Name: "eth0",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
nexthopv4: systemops.Nexthop{
|
||||||
|
IP: netip.MustParseAddr("192.168.1.1"),
|
||||||
|
Intf: &net.Interface{
|
||||||
|
Index: 1, Name: "eth0",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
nexthopv6: systemops.Nexthop{
|
||||||
|
IP: netip.MustParseAddr("2001:db8::1"),
|
||||||
|
},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "modified v4 route with same IP, different Intf Index should return true",
|
||||||
|
route: systemops.RouteUpdate{
|
||||||
|
Type: systemops.RouteModified,
|
||||||
|
Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
|
||||||
|
NextHop: systemops.Nexthop{
|
||||||
|
IP: netip.MustParseAddr("192.168.1.1"),
|
||||||
|
Intf: &net.Interface{Index: 2, Name: "eth1"}, // Different Intf Index
|
||||||
|
},
|
||||||
|
},
|
||||||
|
nexthopv4: systemops.Nexthop{
|
||||||
|
IP: netip.MustParseAddr("192.168.1.1"),
|
||||||
|
Intf: &net.Interface{Index: 1, Name: "eth0"},
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "modified v4 route with same IP, one Intf nil, other non-nil should return true",
|
||||||
|
route: systemops.RouteUpdate{
|
||||||
|
Type: systemops.RouteModified,
|
||||||
|
Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
|
||||||
|
NextHop: systemops.Nexthop{
|
||||||
|
IP: netip.MustParseAddr("192.168.1.1"),
|
||||||
|
Intf: nil, // Intf is nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
nexthopv4: systemops.Nexthop{
|
||||||
|
IP: netip.MustParseAddr("192.168.1.1"),
|
||||||
|
Intf: &net.Interface{Index: 1, Name: "eth0"}, // Tracked Intf is not nil
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "added v4 route with same IP, different Intf Index should return true",
|
||||||
|
route: systemops.RouteUpdate{
|
||||||
|
Type: systemops.RouteAdded,
|
||||||
|
Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
|
||||||
|
NextHop: systemops.Nexthop{
|
||||||
|
IP: netip.MustParseAddr("192.168.1.1"),
|
||||||
|
Intf: &net.Interface{Index: 2, Name: "eth1"}, // Different Intf Index
|
||||||
|
},
|
||||||
|
},
|
||||||
|
nexthopv4: systemops.Nexthop{
|
||||||
|
IP: netip.MustParseAddr("192.168.1.1"),
|
||||||
|
Intf: &net.Interface{Index: 1, Name: "eth0"},
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "deleted v4 route with same IP, different Intf Index should return false",
|
||||||
|
route: systemops.RouteUpdate{
|
||||||
|
Type: systemops.RouteDeleted,
|
||||||
|
Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
|
||||||
|
NextHop: systemops.Nexthop{ // This is the route being deleted
|
||||||
|
IP: netip.MustParseAddr("192.168.1.1"),
|
||||||
|
Intf: &net.Interface{Index: 1, Name: "eth0"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
nexthopv4: systemops.Nexthop{ // This is our tracked nexthop
|
||||||
|
IP: netip.MustParseAddr("192.168.1.1"),
|
||||||
|
Intf: &net.Interface{Index: 2, Name: "eth1"}, // Different Intf Index
|
||||||
|
},
|
||||||
|
expected: false, // Because nexthopv4.Equal(route.NextHop) will be false
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "modified v6 route with different IP, same Intf Index should return true",
|
||||||
|
route: systemops.RouteUpdate{
|
||||||
|
Type: systemops.RouteModified,
|
||||||
|
Destination: netip.PrefixFrom(netip.IPv6Unspecified(), 0),
|
||||||
|
NextHop: systemops.Nexthop{
|
||||||
|
IP: netip.MustParseAddr("2001:db8::3"), // Different IP
|
||||||
|
Intf: &net.Interface{Index: 1, Name: "eth0"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
nexthopv6: systemops.Nexthop{
|
||||||
|
IP: netip.MustParseAddr("2001:db8::1"),
|
||||||
|
Intf: &net.Interface{Index: 1, Name: "eth0"},
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "modified v6 route with same IP, different Intf Index should return true",
|
||||||
|
route: systemops.RouteUpdate{
|
||||||
|
Type: systemops.RouteModified,
|
||||||
|
Destination: netip.PrefixFrom(netip.IPv6Unspecified(), 0),
|
||||||
|
NextHop: systemops.Nexthop{
|
||||||
|
IP: netip.MustParseAddr("2001:db8::1"),
|
||||||
|
Intf: &net.Interface{Index: 2, Name: "eth1"}, // Different Intf Index
|
||||||
|
},
|
||||||
|
},
|
||||||
|
nexthopv6: systemops.Nexthop{
|
||||||
|
IP: netip.MustParseAddr("2001:db8::1"),
|
||||||
|
Intf: &net.Interface{Index: 1, Name: "eth0"},
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "modified v6 route with same IP, same Intf Index should return false",
|
||||||
|
route: systemops.RouteUpdate{
|
||||||
|
Type: systemops.RouteModified,
|
||||||
|
Destination: netip.PrefixFrom(netip.IPv6Unspecified(), 0),
|
||||||
|
NextHop: systemops.Nexthop{
|
||||||
|
IP: netip.MustParseAddr("2001:db8::1"),
|
||||||
|
Intf: &net.Interface{Index: 1, Name: "eth0"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
nexthopv6: systemops.Nexthop{
|
||||||
|
IP: netip.MustParseAddr("2001:db8::1"),
|
||||||
|
Intf: &net.Interface{Index: 1, Name: "eth0"},
|
||||||
|
},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "deleted v6 route matching tracked nexthop (IP and Intf Index) should return true",
|
||||||
|
route: systemops.RouteUpdate{
|
||||||
|
Type: systemops.RouteDeleted,
|
||||||
|
Destination: netip.PrefixFrom(netip.IPv6Unspecified(), 0),
|
||||||
|
NextHop: systemops.Nexthop{
|
||||||
|
IP: netip.MustParseAddr("2001:db8::1"),
|
||||||
|
Intf: &net.Interface{Index: 1, Name: "eth0"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
nexthopv6: systemops.Nexthop{
|
||||||
|
IP: netip.MustParseAddr("2001:db8::1"),
|
||||||
|
Intf: &net.Interface{Index: 1, Name: "eth0"},
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "deleted v6 route not matching tracked nexthop (different IP) should return false",
|
||||||
|
route: systemops.RouteUpdate{
|
||||||
|
Type: systemops.RouteDeleted,
|
||||||
|
Destination: netip.PrefixFrom(netip.IPv6Unspecified(), 0),
|
||||||
|
NextHop: systemops.Nexthop{
|
||||||
|
IP: netip.MustParseAddr("2001:db8::3"), // Different IP
|
||||||
|
Intf: &net.Interface{Index: 1, Name: "eth0"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
nexthopv6: systemops.Nexthop{
|
||||||
|
IP: netip.MustParseAddr("2001:db8::1"),
|
||||||
|
Intf: &net.Interface{Index: 1, Name: "eth0"},
|
||||||
|
},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "deleted v6 route not matching tracked nexthop (same IP, different Intf Index) should return false",
|
||||||
|
route: systemops.RouteUpdate{
|
||||||
|
Type: systemops.RouteDeleted,
|
||||||
|
Destination: netip.PrefixFrom(netip.IPv6Unspecified(), 0),
|
||||||
|
NextHop: systemops.Nexthop{ // This is the route being deleted
|
||||||
|
IP: netip.MustParseAddr("2001:db8::1"),
|
||||||
|
Intf: &net.Interface{Index: 1, Name: "eth0"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
nexthopv6: systemops.Nexthop{ // This is our tracked nexthop
|
||||||
|
IP: netip.MustParseAddr("2001:db8::1"),
|
||||||
|
Intf: &net.Interface{Index: 2, Name: "eth1"}, // Different Intf Index
|
||||||
|
},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown route type should return false",
|
||||||
|
route: systemops.RouteUpdate{
|
||||||
|
Type: systemops.RouteUpdateType(99), // Unknown type
|
||||||
|
Destination: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
|
||||||
|
NextHop: systemops.Nexthop{
|
||||||
|
IP: netip.MustParseAddr("192.168.1.1"),
|
||||||
|
Intf: &net.Interface{Index: 1, Name: "eth0"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
nexthopv4: systemops.Nexthop{
|
||||||
|
IP: netip.MustParseAddr("192.168.1.2"), // Different from route.NextHop
|
||||||
|
Intf: &net.Interface{Index: 1, Name: "eth0"},
|
||||||
|
},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := routeChanged(tt.route, tt.nexthopv4, tt.nexthopv6)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsSoftInterface(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
ifname string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "ISATAP interface should be detected",
|
||||||
|
ifname: "ISATAP tunnel adapter",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "lowercase soft interface should be detected",
|
||||||
|
ifname: "isatap.{14A5CF17-CA72-43EC-B4EA-B4B093641B7D}",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Teredo interface should be detected",
|
||||||
|
ifname: "Teredo Tunneling Pseudo-Interface",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "regular interface should not be detected as soft",
|
||||||
|
ifname: "eth0",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "another regular interface should not be detected as soft",
|
||||||
|
ifname: "wlan0",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := isSoftInterface(tt.ifname)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
@ -118,9 +118,12 @@ func (nw *NetworkMonitor) Stop() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (nw *NetworkMonitor) checkChanges(ctx context.Context, event chan struct{}, nexthop4 systemops.Nexthop, nexthop6 systemops.Nexthop) {
|
func (nw *NetworkMonitor) checkChanges(ctx context.Context, event chan struct{}, nexthop4 systemops.Nexthop, nexthop6 systemops.Nexthop) {
|
||||||
|
defer close(event)
|
||||||
for {
|
for {
|
||||||
if err := checkChangeFn(ctx, nexthop4, nexthop6); err != nil {
|
if err := checkChangeFn(ctx, nexthop4, nexthop6); err != nil {
|
||||||
close(event)
|
if !errors.Is(err, context.Canceled) {
|
||||||
|
log.Errorf("Network monitor: failed to check for changes: %v", err)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// prevent blocking
|
// prevent blocking
|
||||||
|
@ -17,8 +17,12 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer/conntype"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer/id"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer/worker"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
@ -26,32 +30,20 @@ import (
|
|||||||
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ConnPriority int
|
|
||||||
|
|
||||||
func (cp ConnPriority) String() string {
|
|
||||||
switch cp {
|
|
||||||
case connPriorityNone:
|
|
||||||
return "None"
|
|
||||||
case connPriorityRelay:
|
|
||||||
return "PriorityRelay"
|
|
||||||
case connPriorityICETurn:
|
|
||||||
return "PriorityICETurn"
|
|
||||||
case connPriorityICEP2P:
|
|
||||||
return "PriorityICEP2P"
|
|
||||||
default:
|
|
||||||
return fmt.Sprintf("ConnPriority(%d)", cp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
defaultWgKeepAlive = 25 * time.Second
|
defaultWgKeepAlive = 25 * time.Second
|
||||||
|
|
||||||
connPriorityNone ConnPriority = 0
|
|
||||||
connPriorityRelay ConnPriority = 1
|
|
||||||
connPriorityICETurn ConnPriority = 2
|
|
||||||
connPriorityICEP2P ConnPriority = 3
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type ServiceDependencies struct {
|
||||||
|
StatusRecorder *Status
|
||||||
|
Signaler *Signaler
|
||||||
|
IFaceDiscover stdnet.ExternalIFaceDiscover
|
||||||
|
RelayManager *relayClient.Manager
|
||||||
|
SrWatcher *guard.SRWatcher
|
||||||
|
Semaphore *semaphoregroup.SemaphoreGroup
|
||||||
|
PeerConnDispatcher *dispatcher.ConnectionDispatcher
|
||||||
|
}
|
||||||
|
|
||||||
type WgConfig struct {
|
type WgConfig struct {
|
||||||
WgListenPort int
|
WgListenPort int
|
||||||
RemoteKey string
|
RemoteKey string
|
||||||
@ -76,6 +68,8 @@ type ConnConfig struct {
|
|||||||
// LocalKey is a public key of a local peer
|
// LocalKey is a public key of a local peer
|
||||||
LocalKey string
|
LocalKey string
|
||||||
|
|
||||||
|
AgentVersion string
|
||||||
|
|
||||||
Timeout time.Duration
|
Timeout time.Duration
|
||||||
|
|
||||||
WgConfig WgConfig
|
WgConfig WgConfig
|
||||||
@ -89,22 +83,23 @@ type ConnConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Conn struct {
|
type Conn struct {
|
||||||
log *log.Entry
|
Log *log.Entry
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
ctxCancel context.CancelFunc
|
ctxCancel context.CancelFunc
|
||||||
config ConnConfig
|
config ConnConfig
|
||||||
statusRecorder *Status
|
statusRecorder *Status
|
||||||
signaler *Signaler
|
signaler *Signaler
|
||||||
|
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||||
relayManager *relayClient.Manager
|
relayManager *relayClient.Manager
|
||||||
handshaker *Handshaker
|
srWatcher *guard.SRWatcher
|
||||||
|
|
||||||
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
|
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
|
||||||
onDisconnected func(remotePeer string)
|
onDisconnected func(remotePeer string)
|
||||||
|
|
||||||
statusRelay *AtomicConnStatus
|
statusRelay *worker.AtomicWorkerStatus
|
||||||
statusICE *AtomicConnStatus
|
statusICE *worker.AtomicWorkerStatus
|
||||||
currentConnPriority ConnPriority
|
currentConnPriority conntype.ConnPriority
|
||||||
opened bool // this flag is used to prevent close in case of not opened connection
|
opened bool // this flag is used to prevent close in case of not opened connection
|
||||||
|
|
||||||
workerICE *WorkerICE
|
workerICE *WorkerICE
|
||||||
@ -120,9 +115,12 @@ type Conn struct {
|
|||||||
|
|
||||||
wgProxyICE wgproxy.Proxy
|
wgProxyICE wgproxy.Proxy
|
||||||
wgProxyRelay wgproxy.Proxy
|
wgProxyRelay wgproxy.Proxy
|
||||||
|
handshaker *Handshaker
|
||||||
|
|
||||||
guard *guard.Guard
|
guard *guard.Guard
|
||||||
semaphore *semaphoregroup.SemaphoreGroup
|
semaphore *semaphoregroup.SemaphoreGroup
|
||||||
|
peerConnDispatcher *dispatcher.ConnectionDispatcher
|
||||||
|
wg sync.WaitGroup
|
||||||
|
|
||||||
// debug purpose
|
// debug purpose
|
||||||
dumpState *stateDump
|
dumpState *stateDump
|
||||||
@ -130,91 +128,101 @@ type Conn struct {
|
|||||||
|
|
||||||
// NewConn creates a new not opened Conn to the remote peer.
|
// NewConn creates a new not opened Conn to the remote peer.
|
||||||
// To establish a connection run Conn.Open
|
// To establish a connection run Conn.Open
|
||||||
func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher, semaphore *semaphoregroup.SemaphoreGroup) (*Conn, error) {
|
func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
|
||||||
if len(config.WgConfig.AllowedIps) == 0 {
|
if len(config.WgConfig.AllowedIps) == 0 {
|
||||||
return nil, fmt.Errorf("allowed IPs is empty")
|
return nil, fmt.Errorf("allowed IPs is empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, ctxCancel := context.WithCancel(engineCtx)
|
|
||||||
connLog := log.WithField("peer", config.Key)
|
connLog := log.WithField("peer", config.Key)
|
||||||
|
|
||||||
var conn = &Conn{
|
var conn = &Conn{
|
||||||
log: connLog,
|
Log: connLog,
|
||||||
ctx: ctx,
|
config: config,
|
||||||
ctxCancel: ctxCancel,
|
statusRecorder: services.StatusRecorder,
|
||||||
config: config,
|
signaler: services.Signaler,
|
||||||
statusRecorder: statusRecorder,
|
iFaceDiscover: services.IFaceDiscover,
|
||||||
signaler: signaler,
|
relayManager: services.RelayManager,
|
||||||
relayManager: relayManager,
|
srWatcher: services.SrWatcher,
|
||||||
statusRelay: NewAtomicConnStatus(),
|
semaphore: services.Semaphore,
|
||||||
statusICE: NewAtomicConnStatus(),
|
peerConnDispatcher: services.PeerConnDispatcher,
|
||||||
semaphore: semaphore,
|
statusRelay: worker.NewAtomicStatus(),
|
||||||
dumpState: newStateDump(config.Key, connLog, statusRecorder),
|
statusICE: worker.NewAtomicStatus(),
|
||||||
|
dumpState: newStateDump(config.Key, connLog, services.StatusRecorder),
|
||||||
}
|
}
|
||||||
|
|
||||||
ctrl := isController(config)
|
|
||||||
conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, conn, relayManager, conn.dumpState)
|
|
||||||
|
|
||||||
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
|
||||||
workerICE, err := NewWorkerICE(ctx, connLog, config, conn, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
conn.workerICE = workerICE
|
|
||||||
|
|
||||||
conn.handshaker = NewHandshaker(ctx, connLog, config, signaler, conn.workerICE, conn.workerRelay)
|
|
||||||
|
|
||||||
conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer)
|
|
||||||
if os.Getenv("NB_FORCE_RELAY") != "true" {
|
|
||||||
conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer)
|
|
||||||
}
|
|
||||||
|
|
||||||
conn.guard = guard.NewGuard(connLog, ctrl, conn.isConnectedOnAllWay, config.Timeout, srWatcher)
|
|
||||||
|
|
||||||
go conn.handshaker.Listen()
|
|
||||||
|
|
||||||
go conn.dumpState.Start(ctx)
|
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Open opens connection to the remote peer
|
// Open opens connection to the remote peer
|
||||||
// It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will
|
// It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will
|
||||||
// be used.
|
// be used.
|
||||||
func (conn *Conn) Open() {
|
func (conn *Conn) Open(engineCtx context.Context) error {
|
||||||
conn.semaphore.Add(conn.ctx)
|
conn.semaphore.Add(engineCtx)
|
||||||
conn.log.Debugf("open connection to peer")
|
|
||||||
|
|
||||||
conn.mu.Lock()
|
conn.mu.Lock()
|
||||||
defer conn.mu.Unlock()
|
defer conn.mu.Unlock()
|
||||||
conn.opened = true
|
|
||||||
|
if conn.opened {
|
||||||
|
conn.semaphore.Done(engineCtx)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.ctx, conn.ctxCancel = context.WithCancel(engineCtx)
|
||||||
|
|
||||||
|
conn.workerRelay = NewWorkerRelay(conn.Log, isController(conn.config), conn.config, conn, conn.relayManager, conn.dumpState)
|
||||||
|
|
||||||
|
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
||||||
|
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
conn.workerICE = workerICE
|
||||||
|
|
||||||
|
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay)
|
||||||
|
|
||||||
|
conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer)
|
||||||
|
if os.Getenv("NB_FORCE_RELAY") != "true" {
|
||||||
|
conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.guard = guard.NewGuard(conn.Log, conn.isConnectedOnAllWay, conn.config.Timeout, conn.srWatcher)
|
||||||
|
|
||||||
|
conn.wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer conn.wg.Done()
|
||||||
|
conn.handshaker.Listen(conn.ctx)
|
||||||
|
}()
|
||||||
|
go conn.dumpState.Start(conn.ctx)
|
||||||
|
|
||||||
peerState := State{
|
peerState := State{
|
||||||
PubKey: conn.config.Key,
|
PubKey: conn.config.Key,
|
||||||
IP: conn.config.WgConfig.AllowedIps[0].Addr().String(),
|
|
||||||
ConnStatusUpdate: time.Now(),
|
ConnStatusUpdate: time.Now(),
|
||||||
ConnStatus: StatusDisconnected,
|
ConnStatus: StatusConnecting,
|
||||||
Mux: new(sync.RWMutex),
|
Mux: new(sync.RWMutex),
|
||||||
}
|
}
|
||||||
err := conn.statusRecorder.UpdatePeerState(peerState)
|
if err := conn.statusRecorder.UpdatePeerState(peerState); err != nil {
|
||||||
if err != nil {
|
conn.Log.Warnf("error while updating the state err: %v", err)
|
||||||
conn.log.Warnf("error while updating the state err: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
go conn.startHandshakeAndReconnect(conn.ctx)
|
conn.wg.Add(1)
|
||||||
}
|
go func() {
|
||||||
|
defer conn.wg.Done()
|
||||||
|
conn.waitInitialRandomSleepTime(conn.ctx)
|
||||||
|
conn.semaphore.Done(conn.ctx)
|
||||||
|
|
||||||
func (conn *Conn) startHandshakeAndReconnect(ctx context.Context) {
|
conn.dumpState.SendOffer()
|
||||||
defer conn.semaphore.Done(conn.ctx)
|
if err := conn.handshaker.sendOffer(); err != nil {
|
||||||
conn.waitInitialRandomSleepTime(ctx)
|
conn.Log.Errorf("failed to send initial offer: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
conn.dumpState.SendOffer()
|
conn.wg.Add(1)
|
||||||
err := conn.handshaker.sendOffer()
|
go func() {
|
||||||
if err != nil {
|
conn.guard.Start(conn.ctx, conn.onGuardEvent)
|
||||||
conn.log.Errorf("failed to send initial offer: %v", err)
|
conn.wg.Done()
|
||||||
}
|
}()
|
||||||
|
}()
|
||||||
go conn.guard.Start(ctx)
|
conn.opened = true
|
||||||
go conn.listenGuardEvent(ctx)
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close closes this peer Conn issuing a close event to the Conn closeCh
|
// Close closes this peer Conn issuing a close event to the Conn closeCh
|
||||||
@ -223,14 +231,14 @@ func (conn *Conn) Close() {
|
|||||||
defer conn.wgWatcherWg.Wait()
|
defer conn.wgWatcherWg.Wait()
|
||||||
defer conn.mu.Unlock()
|
defer conn.mu.Unlock()
|
||||||
|
|
||||||
conn.log.Infof("close peer connection")
|
|
||||||
conn.ctxCancel()
|
|
||||||
|
|
||||||
if !conn.opened {
|
if !conn.opened {
|
||||||
conn.log.Debugf("ignore close connection to peer")
|
conn.Log.Debugf("ignore close connection to peer")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
conn.Log.Infof("close peer connection")
|
||||||
|
conn.ctxCancel()
|
||||||
|
|
||||||
conn.workerRelay.DisableWgWatcher()
|
conn.workerRelay.DisableWgWatcher()
|
||||||
conn.workerRelay.CloseConn()
|
conn.workerRelay.CloseConn()
|
||||||
conn.workerICE.Close()
|
conn.workerICE.Close()
|
||||||
@ -238,7 +246,7 @@ func (conn *Conn) Close() {
|
|||||||
if conn.wgProxyRelay != nil {
|
if conn.wgProxyRelay != nil {
|
||||||
err := conn.wgProxyRelay.CloseConn()
|
err := conn.wgProxyRelay.CloseConn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.log.Errorf("failed to close wg proxy for relay: %v", err)
|
conn.Log.Errorf("failed to close wg proxy for relay: %v", err)
|
||||||
}
|
}
|
||||||
conn.wgProxyRelay = nil
|
conn.wgProxyRelay = nil
|
||||||
}
|
}
|
||||||
@ -246,13 +254,13 @@ func (conn *Conn) Close() {
|
|||||||
if conn.wgProxyICE != nil {
|
if conn.wgProxyICE != nil {
|
||||||
err := conn.wgProxyICE.CloseConn()
|
err := conn.wgProxyICE.CloseConn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.log.Errorf("failed to close wg proxy for ice: %v", err)
|
conn.Log.Errorf("failed to close wg proxy for ice: %v", err)
|
||||||
}
|
}
|
||||||
conn.wgProxyICE = nil
|
conn.wgProxyICE = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := conn.removeWgPeer(); err != nil {
|
if err := conn.removeWgPeer(); err != nil {
|
||||||
conn.log.Errorf("failed to remove wg endpoint: %v", err)
|
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.freeUpConnID()
|
conn.freeUpConnID()
|
||||||
@ -262,14 +270,16 @@ func (conn *Conn) Close() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
conn.setStatusToDisconnected()
|
conn.setStatusToDisconnected()
|
||||||
conn.log.Infof("peer connection has been closed")
|
conn.opened = false
|
||||||
|
conn.wg.Wait()
|
||||||
|
conn.Log.Infof("peer connection closed")
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
|
// OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
|
||||||
// doesn't block, discards the message if connection wasn't ready
|
// doesn't block, discards the message if connection wasn't ready
|
||||||
func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) bool {
|
func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) bool {
|
||||||
conn.dumpState.RemoteAnswer()
|
conn.dumpState.RemoteAnswer()
|
||||||
conn.log.Infof("OnRemoteAnswer, priority: %s, status ICE: %s, status relay: %s", conn.currentConnPriority, conn.statusICE, conn.statusRelay)
|
conn.Log.Infof("OnRemoteAnswer, priority: %s, status ICE: %s, status relay: %s", conn.currentConnPriority, conn.statusICE, conn.statusRelay)
|
||||||
return conn.handshaker.OnRemoteAnswer(answer)
|
return conn.handshaker.OnRemoteAnswer(answer)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -298,7 +308,7 @@ func (conn *Conn) SetOnDisconnected(handler func(remotePeer string)) {
|
|||||||
|
|
||||||
func (conn *Conn) OnRemoteOffer(offer OfferAnswer) bool {
|
func (conn *Conn) OnRemoteOffer(offer OfferAnswer) bool {
|
||||||
conn.dumpState.RemoteOffer()
|
conn.dumpState.RemoteOffer()
|
||||||
conn.log.Infof("OnRemoteOffer, on status ICE: %s, status Relay: %s", conn.statusICE, conn.statusRelay)
|
conn.Log.Infof("OnRemoteOffer, on status ICE: %s, status Relay: %s", conn.statusICE, conn.statusRelay)
|
||||||
return conn.handshaker.OnRemoteOffer(offer)
|
return conn.handshaker.OnRemoteOffer(offer)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -307,19 +317,24 @@ func (conn *Conn) WgConfig() WgConfig {
|
|||||||
return conn.config.WgConfig
|
return conn.config.WgConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
// Status returns current status of the Conn
|
// IsConnected unit tests only
|
||||||
func (conn *Conn) Status() ConnStatus {
|
// refactor unit test to use status recorder use refactor status recorded to manage connection status in peer.Conn
|
||||||
|
func (conn *Conn) IsConnected() bool {
|
||||||
conn.mu.Lock()
|
conn.mu.Lock()
|
||||||
defer conn.mu.Unlock()
|
defer conn.mu.Unlock()
|
||||||
return conn.evalStatus()
|
return conn.currentConnPriority != conntype.None
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) GetKey() string {
|
func (conn *Conn) GetKey() string {
|
||||||
return conn.config.Key
|
return conn.config.Key
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) ConnID() id.ConnID {
|
||||||
|
return id.ConnID(conn)
|
||||||
|
}
|
||||||
|
|
||||||
// configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
|
// configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
|
||||||
func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEConnInfo) {
|
func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConnInfo ICEConnInfo) {
|
||||||
conn.mu.Lock()
|
conn.mu.Lock()
|
||||||
defer conn.mu.Unlock()
|
defer conn.mu.Unlock()
|
||||||
|
|
||||||
@ -327,21 +342,21 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if remoteConnNil(conn.log, iceConnInfo.RemoteConn) {
|
if remoteConnNil(conn.Log, iceConnInfo.RemoteConn) {
|
||||||
conn.log.Errorf("remote ICE connection is nil")
|
conn.Log.Errorf("remote ICE connection is nil")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// this never should happen, because Relay is the lower priority and ICE always close the deprecated connection before upgrade
|
// this never should happen, because Relay is the lower priority and ICE always close the deprecated connection before upgrade
|
||||||
// todo consider to remove this check
|
// todo consider to remove this check
|
||||||
if conn.currentConnPriority > priority {
|
if conn.currentConnPriority > priority {
|
||||||
conn.log.Infof("current connection priority (%s) is higher than the new one (%s), do not upgrade connection", conn.currentConnPriority, priority)
|
conn.Log.Infof("current connection priority (%s) is higher than the new one (%s), do not upgrade connection", conn.currentConnPriority, priority)
|
||||||
conn.statusICE.Set(StatusConnected)
|
conn.statusICE.SetConnected()
|
||||||
conn.updateIceState(iceConnInfo)
|
conn.updateIceState(iceConnInfo)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.log.Infof("set ICE to active connection")
|
conn.Log.Infof("set ICE to active connection")
|
||||||
conn.dumpState.P2PConnected()
|
conn.dumpState.P2PConnected()
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -353,7 +368,7 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC
|
|||||||
conn.dumpState.NewLocalProxy()
|
conn.dumpState.NewLocalProxy()
|
||||||
wgProxy, err = conn.newProxy(iceConnInfo.RemoteConn)
|
wgProxy, err = conn.newProxy(iceConnInfo.RemoteConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
|
conn.Log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ep = wgProxy.EndpointAddr()
|
ep = wgProxy.EndpointAddr()
|
||||||
@ -369,7 +384,7 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := conn.runBeforeAddPeerHooks(ep.IP); err != nil {
|
if err := conn.runBeforeAddPeerHooks(ep.IP); err != nil {
|
||||||
conn.log.Errorf("Before add peer hook failed: %v", err)
|
conn.Log.Errorf("Before add peer hook failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.workerRelay.DisableWgWatcher()
|
conn.workerRelay.DisableWgWatcher()
|
||||||
@ -388,10 +403,16 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
wgConfigWorkaround()
|
wgConfigWorkaround()
|
||||||
|
|
||||||
|
oldState := conn.currentConnPriority
|
||||||
conn.currentConnPriority = priority
|
conn.currentConnPriority = priority
|
||||||
conn.statusICE.Set(StatusConnected)
|
conn.statusICE.SetConnected()
|
||||||
conn.updateIceState(iceConnInfo)
|
conn.updateIceState(iceConnInfo)
|
||||||
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
|
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
|
||||||
|
|
||||||
|
if oldState == conntype.None {
|
||||||
|
conn.peerConnDispatcher.NotifyConnected(conn.ConnID())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) onICEStateDisconnected() {
|
func (conn *Conn) onICEStateDisconnected() {
|
||||||
@ -402,22 +423,22 @@ func (conn *Conn) onICEStateDisconnected() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.log.Tracef("ICE connection state changed to disconnected")
|
conn.Log.Tracef("ICE connection state changed to disconnected")
|
||||||
|
|
||||||
if conn.wgProxyICE != nil {
|
if conn.wgProxyICE != nil {
|
||||||
if err := conn.wgProxyICE.CloseConn(); err != nil {
|
if err := conn.wgProxyICE.CloseConn(); err != nil {
|
||||||
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
|
conn.Log.Warnf("failed to close deprecated wg proxy conn: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// switch back to relay connection
|
// switch back to relay connection
|
||||||
if conn.isReadyToUpgrade() {
|
if conn.isReadyToUpgrade() {
|
||||||
conn.log.Infof("ICE disconnected, set Relay to active connection")
|
conn.Log.Infof("ICE disconnected, set Relay to active connection")
|
||||||
conn.dumpState.SwitchToRelay()
|
conn.dumpState.SwitchToRelay()
|
||||||
conn.wgProxyRelay.Work()
|
conn.wgProxyRelay.Work()
|
||||||
|
|
||||||
if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), conn.rosenpassRemoteKey); err != nil {
|
if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), conn.rosenpassRemoteKey); err != nil {
|
||||||
conn.log.Errorf("failed to switch to relay conn: %v", err)
|
conn.Log.Errorf("failed to switch to relay conn: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.wgWatcherWg.Add(1)
|
conn.wgWatcherWg.Add(1)
|
||||||
@ -425,17 +446,18 @@ func (conn *Conn) onICEStateDisconnected() {
|
|||||||
defer conn.wgWatcherWg.Done()
|
defer conn.wgWatcherWg.Done()
|
||||||
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
||||||
}()
|
}()
|
||||||
conn.currentConnPriority = connPriorityRelay
|
conn.currentConnPriority = conntype.Relay
|
||||||
} else {
|
} else {
|
||||||
conn.log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", connPriorityNone.String())
|
conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())
|
||||||
conn.currentConnPriority = connPriorityNone
|
conn.currentConnPriority = conntype.None
|
||||||
|
conn.peerConnDispatcher.NotifyDisconnected(conn.ConnID())
|
||||||
}
|
}
|
||||||
|
|
||||||
changed := conn.statusICE.Get() != StatusDisconnected
|
changed := conn.statusICE.Get() != worker.StatusDisconnected
|
||||||
if changed {
|
if changed {
|
||||||
conn.guard.SetICEConnDisconnected()
|
conn.guard.SetICEConnDisconnected()
|
||||||
}
|
}
|
||||||
conn.statusICE.Set(StatusDisconnected)
|
conn.statusICE.SetDisconnected()
|
||||||
|
|
||||||
peerState := State{
|
peerState := State{
|
||||||
PubKey: conn.config.Key,
|
PubKey: conn.config.Key,
|
||||||
@ -446,7 +468,7 @@ func (conn *Conn) onICEStateDisconnected() {
|
|||||||
|
|
||||||
err := conn.statusRecorder.UpdatePeerICEStateToDisconnected(peerState)
|
err := conn.statusRecorder.UpdatePeerICEStateToDisconnected(peerState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.log.Warnf("unable to set peer's state to disconnected ice, got error: %v", err)
|
conn.Log.Warnf("unable to set peer's state to disconnected ice, got error: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -456,41 +478,41 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
|||||||
|
|
||||||
if conn.ctx.Err() != nil {
|
if conn.ctx.Err() != nil {
|
||||||
if err := rci.relayedConn.Close(); err != nil {
|
if err := rci.relayedConn.Close(); err != nil {
|
||||||
conn.log.Warnf("failed to close unnecessary relayed connection: %v", err)
|
conn.Log.Warnf("failed to close unnecessary relayed connection: %v", err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.dumpState.RelayConnected()
|
conn.dumpState.RelayConnected()
|
||||||
conn.log.Debugf("Relay connection has been established, setup the WireGuard")
|
conn.Log.Debugf("Relay connection has been established, setup the WireGuard")
|
||||||
|
|
||||||
wgProxy, err := conn.newProxy(rci.relayedConn)
|
wgProxy, err := conn.newProxy(rci.relayedConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
|
conn.Log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
conn.dumpState.NewLocalProxy()
|
conn.dumpState.NewLocalProxy()
|
||||||
|
|
||||||
conn.log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
|
conn.Log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
|
||||||
|
|
||||||
if conn.isICEActive() {
|
if conn.isICEActive() {
|
||||||
conn.log.Infof("do not switch to relay because current priority is: %s", conn.currentConnPriority.String())
|
conn.Log.Debugf("do not switch to relay because current priority is: %s", conn.currentConnPriority.String())
|
||||||
conn.setRelayedProxy(wgProxy)
|
conn.setRelayedProxy(wgProxy)
|
||||||
conn.statusRelay.Set(StatusConnected)
|
conn.statusRelay.SetConnected()
|
||||||
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := conn.runBeforeAddPeerHooks(wgProxy.EndpointAddr().IP); err != nil {
|
if err := conn.runBeforeAddPeerHooks(wgProxy.EndpointAddr().IP); err != nil {
|
||||||
conn.log.Errorf("Before add peer hook failed: %v", err)
|
conn.Log.Errorf("Before add peer hook failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
wgProxy.Work()
|
wgProxy.Work()
|
||||||
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil {
|
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil {
|
||||||
if err := wgProxy.CloseConn(); err != nil {
|
if err := wgProxy.CloseConn(); err != nil {
|
||||||
conn.log.Warnf("Failed to close relay connection: %v", err)
|
conn.Log.Warnf("Failed to close relay connection: %v", err)
|
||||||
}
|
}
|
||||||
conn.log.Errorf("Failed to update WireGuard peer configuration: %v", err)
|
conn.Log.Errorf("Failed to update WireGuard peer configuration: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -502,12 +524,13 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
|||||||
|
|
||||||
wgConfigWorkaround()
|
wgConfigWorkaround()
|
||||||
conn.rosenpassRemoteKey = rci.rosenpassPubKey
|
conn.rosenpassRemoteKey = rci.rosenpassPubKey
|
||||||
conn.currentConnPriority = connPriorityRelay
|
conn.currentConnPriority = conntype.Relay
|
||||||
conn.statusRelay.Set(StatusConnected)
|
conn.statusRelay.SetConnected()
|
||||||
conn.setRelayedProxy(wgProxy)
|
conn.setRelayedProxy(wgProxy)
|
||||||
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||||
conn.log.Infof("start to communicate with peer via relay")
|
conn.Log.Infof("start to communicate with peer via relay")
|
||||||
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
|
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
|
||||||
|
conn.peerConnDispatcher.NotifyConnected(conn.ConnID())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) onRelayDisconnected() {
|
func (conn *Conn) onRelayDisconnected() {
|
||||||
@ -518,14 +541,15 @@ func (conn *Conn) onRelayDisconnected() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.log.Infof("relay connection is disconnected")
|
conn.Log.Debugf("relay connection is disconnected")
|
||||||
|
|
||||||
if conn.currentConnPriority == connPriorityRelay {
|
if conn.currentConnPriority == conntype.Relay {
|
||||||
conn.log.Infof("clean up WireGuard config")
|
conn.Log.Debugf("clean up WireGuard config")
|
||||||
if err := conn.removeWgPeer(); err != nil {
|
if err := conn.removeWgPeer(); err != nil {
|
||||||
conn.log.Errorf("failed to remove wg endpoint: %v", err)
|
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
|
||||||
}
|
}
|
||||||
conn.currentConnPriority = connPriorityNone
|
conn.currentConnPriority = conntype.None
|
||||||
|
conn.peerConnDispatcher.NotifyDisconnected(conn.ConnID())
|
||||||
}
|
}
|
||||||
|
|
||||||
if conn.wgProxyRelay != nil {
|
if conn.wgProxyRelay != nil {
|
||||||
@ -533,11 +557,11 @@ func (conn *Conn) onRelayDisconnected() {
|
|||||||
conn.wgProxyRelay = nil
|
conn.wgProxyRelay = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
changed := conn.statusRelay.Get() != StatusDisconnected
|
changed := conn.statusRelay.Get() != worker.StatusDisconnected
|
||||||
if changed {
|
if changed {
|
||||||
conn.guard.SetRelayedConnDisconnected()
|
conn.guard.SetRelayedConnDisconnected()
|
||||||
}
|
}
|
||||||
conn.statusRelay.Set(StatusDisconnected)
|
conn.statusRelay.SetDisconnected()
|
||||||
|
|
||||||
peerState := State{
|
peerState := State{
|
||||||
PubKey: conn.config.Key,
|
PubKey: conn.config.Key,
|
||||||
@ -546,22 +570,15 @@ func (conn *Conn) onRelayDisconnected() {
|
|||||||
ConnStatusUpdate: time.Now(),
|
ConnStatusUpdate: time.Now(),
|
||||||
}
|
}
|
||||||
if err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState); err != nil {
|
if err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState); err != nil {
|
||||||
conn.log.Warnf("unable to save peer's state to Relay disconnected, got error: %v", err)
|
conn.Log.Warnf("unable to save peer's state to Relay disconnected, got error: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) listenGuardEvent(ctx context.Context) {
|
func (conn *Conn) onGuardEvent() {
|
||||||
for {
|
conn.Log.Debugf("send offer to peer")
|
||||||
select {
|
conn.dumpState.SendOffer()
|
||||||
case <-conn.guard.Reconnect:
|
if err := conn.handshaker.SendOffer(); err != nil {
|
||||||
conn.log.Infof("send offer to peer")
|
conn.Log.Errorf("failed to send offer: %v", err)
|
||||||
conn.dumpState.SendOffer()
|
|
||||||
if err := conn.handshaker.SendOffer(); err != nil {
|
|
||||||
conn.log.Errorf("failed to send offer: %v", err)
|
|
||||||
}
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -588,7 +605,7 @@ func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []by
|
|||||||
|
|
||||||
err := conn.statusRecorder.UpdatePeerRelayedState(peerState)
|
err := conn.statusRecorder.UpdatePeerRelayedState(peerState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.log.Warnf("unable to save peer's Relay state, got error: %v", err)
|
conn.Log.Warnf("unable to save peer's Relay state, got error: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -607,17 +624,18 @@ func (conn *Conn) updateIceState(iceConnInfo ICEConnInfo) {
|
|||||||
|
|
||||||
err := conn.statusRecorder.UpdatePeerICEState(peerState)
|
err := conn.statusRecorder.UpdatePeerICEState(peerState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.log.Warnf("unable to save peer's ICE state, got error: %v", err)
|
conn.Log.Warnf("unable to save peer's ICE state, got error: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) setStatusToDisconnected() {
|
func (conn *Conn) setStatusToDisconnected() {
|
||||||
conn.statusRelay.Set(StatusDisconnected)
|
conn.statusRelay.SetDisconnected()
|
||||||
conn.statusICE.Set(StatusDisconnected)
|
conn.statusICE.SetDisconnected()
|
||||||
|
conn.currentConnPriority = conntype.None
|
||||||
|
|
||||||
peerState := State{
|
peerState := State{
|
||||||
PubKey: conn.config.Key,
|
PubKey: conn.config.Key,
|
||||||
ConnStatus: StatusDisconnected,
|
ConnStatus: StatusIdle,
|
||||||
ConnStatusUpdate: time.Now(),
|
ConnStatusUpdate: time.Now(),
|
||||||
Mux: new(sync.RWMutex),
|
Mux: new(sync.RWMutex),
|
||||||
}
|
}
|
||||||
@ -625,10 +643,10 @@ func (conn *Conn) setStatusToDisconnected() {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
// pretty common error because by that time Engine can already remove the peer and status won't be available.
|
// pretty common error because by that time Engine can already remove the peer and status won't be available.
|
||||||
// todo rethink status updates
|
// todo rethink status updates
|
||||||
conn.log.Debugf("error while updating peer's state, err: %v", err)
|
conn.Log.Debugf("error while updating peer's state, err: %v", err)
|
||||||
}
|
}
|
||||||
if err := conn.statusRecorder.UpdateWireGuardPeerState(conn.config.Key, configurer.WGStats{}); err != nil {
|
if err := conn.statusRecorder.UpdateWireGuardPeerState(conn.config.Key, configurer.WGStats{}); err != nil {
|
||||||
conn.log.Debugf("failed to reset wireguard stats for peer: %s", err)
|
conn.Log.Debugf("failed to reset wireguard stats for peer: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -656,27 +674,20 @@ func (conn *Conn) waitInitialRandomSleepTime(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) isRelayed() bool {
|
func (conn *Conn) isRelayed() bool {
|
||||||
if conn.statusRelay.Get() == StatusDisconnected && (conn.statusICE.Get() == StatusDisconnected || conn.statusICE.Get() == StatusConnecting) {
|
switch conn.currentConnPriority {
|
||||||
|
case conntype.Relay, conntype.ICETurn:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if conn.currentConnPriority == connPriorityICEP2P {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) evalStatus() ConnStatus {
|
func (conn *Conn) evalStatus() ConnStatus {
|
||||||
if conn.statusRelay.Get() == StatusConnected || conn.statusICE.Get() == StatusConnected {
|
if conn.statusRelay.Get() == worker.StatusConnected || conn.statusICE.Get() == worker.StatusConnected {
|
||||||
return StatusConnected
|
return StatusConnected
|
||||||
}
|
}
|
||||||
|
|
||||||
if conn.statusRelay.Get() == StatusConnecting || conn.statusICE.Get() == StatusConnecting {
|
return StatusConnecting
|
||||||
return StatusConnecting
|
|
||||||
}
|
|
||||||
|
|
||||||
return StatusDisconnected
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) isConnectedOnAllWay() (connected bool) {
|
func (conn *Conn) isConnectedOnAllWay() (connected bool) {
|
||||||
@ -689,12 +700,12 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if conn.statusICE.Get() == StatusDisconnected {
|
if conn.statusICE.Get() == worker.StatusDisconnected {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
|
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
|
||||||
if conn.statusRelay.Get() != StatusConnected {
|
if conn.statusRelay.Get() == worker.StatusDisconnected {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -716,7 +727,7 @@ func (conn *Conn) freeUpConnID() {
|
|||||||
if conn.connIDRelay != "" {
|
if conn.connIDRelay != "" {
|
||||||
for _, hook := range conn.afterRemovePeerHooks {
|
for _, hook := range conn.afterRemovePeerHooks {
|
||||||
if err := hook(conn.connIDRelay); err != nil {
|
if err := hook(conn.connIDRelay); err != nil {
|
||||||
conn.log.Errorf("After remove peer hook failed: %v", err)
|
conn.Log.Errorf("After remove peer hook failed: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
conn.connIDRelay = ""
|
conn.connIDRelay = ""
|
||||||
@ -725,7 +736,7 @@ func (conn *Conn) freeUpConnID() {
|
|||||||
if conn.connIDICE != "" {
|
if conn.connIDICE != "" {
|
||||||
for _, hook := range conn.afterRemovePeerHooks {
|
for _, hook := range conn.afterRemovePeerHooks {
|
||||||
if err := hook(conn.connIDICE); err != nil {
|
if err := hook(conn.connIDICE); err != nil {
|
||||||
conn.log.Errorf("After remove peer hook failed: %v", err)
|
conn.Log.Errorf("After remove peer hook failed: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
conn.connIDICE = ""
|
conn.connIDICE = ""
|
||||||
@ -733,7 +744,7 @@ func (conn *Conn) freeUpConnID() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
|
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
|
||||||
conn.log.Debugf("setup proxied WireGuard connection")
|
conn.Log.Debugf("setup proxied WireGuard connection")
|
||||||
udpAddr := &net.UDPAddr{
|
udpAddr := &net.UDPAddr{
|
||||||
IP: conn.config.WgConfig.AllowedIps[0].Addr().AsSlice(),
|
IP: conn.config.WgConfig.AllowedIps[0].Addr().AsSlice(),
|
||||||
Port: conn.config.WgConfig.WgListenPort,
|
Port: conn.config.WgConfig.WgListenPort,
|
||||||
@ -741,18 +752,18 @@ func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
|
|||||||
|
|
||||||
wgProxy := conn.config.WgConfig.WgInterface.GetProxy()
|
wgProxy := conn.config.WgConfig.WgInterface.GetProxy()
|
||||||
if err := wgProxy.AddTurnConn(conn.ctx, udpAddr, remoteConn); err != nil {
|
if err := wgProxy.AddTurnConn(conn.ctx, udpAddr, remoteConn); err != nil {
|
||||||
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
|
conn.Log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return wgProxy, nil
|
return wgProxy, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) isReadyToUpgrade() bool {
|
func (conn *Conn) isReadyToUpgrade() bool {
|
||||||
return conn.wgProxyRelay != nil && conn.currentConnPriority != connPriorityRelay
|
return conn.wgProxyRelay != nil && conn.currentConnPriority != conntype.Relay
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) isICEActive() bool {
|
func (conn *Conn) isICEActive() bool {
|
||||||
return (conn.currentConnPriority == connPriorityICEP2P || conn.currentConnPriority == connPriorityICETurn) && conn.statusICE.Get() == StatusConnected
|
return (conn.currentConnPriority == conntype.ICEP2P || conn.currentConnPriority == conntype.ICETurn) && conn.statusICE.Get() == worker.StatusConnected
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) removeWgPeer() error {
|
func (conn *Conn) removeWgPeer() error {
|
||||||
@ -760,10 +771,10 @@ func (conn *Conn) removeWgPeer() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
|
func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
|
||||||
conn.log.Warnf("Failed to update wg peer configuration: %v", err)
|
conn.Log.Warnf("Failed to update wg peer configuration: %v", err)
|
||||||
if wgProxy != nil {
|
if wgProxy != nil {
|
||||||
if ierr := wgProxy.CloseConn(); ierr != nil {
|
if ierr := wgProxy.CloseConn(); ierr != nil {
|
||||||
conn.log.Warnf("Failed to close wg proxy: %v", ierr)
|
conn.Log.Warnf("Failed to close wg proxy: %v", ierr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if conn.wgProxyRelay != nil {
|
if conn.wgProxyRelay != nil {
|
||||||
@ -773,16 +784,16 @@ func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
|
|||||||
|
|
||||||
func (conn *Conn) logTraceConnState() {
|
func (conn *Conn) logTraceConnState() {
|
||||||
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
|
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
|
||||||
conn.log.Tracef("connectivity guard check, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE)
|
conn.Log.Tracef("connectivity guard check, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE)
|
||||||
} else {
|
} else {
|
||||||
conn.log.Tracef("connectivity guard check, ice state: %s", conn.statusICE)
|
conn.Log.Tracef("connectivity guard check, ice state: %s", conn.statusICE)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) {
|
func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) {
|
||||||
if conn.wgProxyRelay != nil {
|
if conn.wgProxyRelay != nil {
|
||||||
if err := conn.wgProxyRelay.CloseConn(); err != nil {
|
if err := conn.wgProxyRelay.CloseConn(); err != nil {
|
||||||
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
|
conn.Log.Warnf("failed to close deprecated wg proxy conn: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
conn.wgProxyRelay = proxy
|
conn.wgProxyRelay = proxy
|
||||||
@ -793,6 +804,10 @@ func (conn *Conn) AllowedIP() netip.Addr {
|
|||||||
return conn.config.WgConfig.AllowedIps[0].Addr()
|
return conn.config.WgConfig.AllowedIps[0].Addr()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) AgentVersionString() string {
|
||||||
|
return conn.config.AgentVersion
|
||||||
|
}
|
||||||
|
|
||||||
func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
|
func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
|
||||||
if conn.config.RosenpassConfig.PubKey == nil {
|
if conn.config.RosenpassConfig.PubKey == nil {
|
||||||
return conn.config.WgConfig.PreSharedKey
|
return conn.config.WgConfig.PreSharedKey
|
||||||
@ -804,7 +819,7 @@ func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
|
|||||||
|
|
||||||
determKey, err := conn.rosenpassDetermKey()
|
determKey, err := conn.rosenpassDetermKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.log.Errorf("failed to generate Rosenpass initial key: %v", err)
|
conn.Log.Errorf("failed to generate Rosenpass initial key: %v", err)
|
||||||
return conn.config.WgConfig.PreSharedKey
|
return conn.config.WgConfig.PreSharedKey
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,58 +1,29 @@
|
|||||||
package peer
|
package peer
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"sync/atomic"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// StatusConnected indicate the peer is in connected state
|
// StatusIdle indicate the peer is in disconnected state
|
||||||
StatusConnected ConnStatus = iota
|
StatusIdle ConnStatus = iota
|
||||||
// StatusConnecting indicate the peer is in connecting state
|
// StatusConnecting indicate the peer is in connecting state
|
||||||
StatusConnecting
|
StatusConnecting
|
||||||
// StatusDisconnected indicate the peer is in disconnected state
|
// StatusConnected indicate the peer is in connected state
|
||||||
StatusDisconnected
|
StatusConnected
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConnStatus describe the status of a peer's connection
|
// ConnStatus describe the status of a peer's connection
|
||||||
type ConnStatus int32
|
type ConnStatus int32
|
||||||
|
|
||||||
// AtomicConnStatus is a thread-safe wrapper for ConnStatus
|
|
||||||
type AtomicConnStatus struct {
|
|
||||||
status atomic.Int32
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewAtomicConnStatus creates a new AtomicConnStatus with the given initial status
|
|
||||||
func NewAtomicConnStatus() *AtomicConnStatus {
|
|
||||||
acs := &AtomicConnStatus{}
|
|
||||||
acs.Set(StatusDisconnected)
|
|
||||||
return acs
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get returns the current connection status
|
|
||||||
func (acs *AtomicConnStatus) Get() ConnStatus {
|
|
||||||
return ConnStatus(acs.status.Load())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set updates the connection status
|
|
||||||
func (acs *AtomicConnStatus) Set(status ConnStatus) {
|
|
||||||
acs.status.Store(int32(status))
|
|
||||||
}
|
|
||||||
|
|
||||||
// String returns the string representation of the current status
|
|
||||||
func (acs *AtomicConnStatus) String() string {
|
|
||||||
return acs.Get().String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s ConnStatus) String() string {
|
func (s ConnStatus) String() string {
|
||||||
switch s {
|
switch s {
|
||||||
case StatusConnecting:
|
case StatusConnecting:
|
||||||
return "Connecting"
|
return "Connecting"
|
||||||
case StatusConnected:
|
case StatusConnected:
|
||||||
return "Connected"
|
return "Connected"
|
||||||
case StatusDisconnected:
|
case StatusIdle:
|
||||||
return "Disconnected"
|
return "Idle"
|
||||||
default:
|
default:
|
||||||
log.Errorf("unknown status: %d", s)
|
log.Errorf("unknown status: %d", s)
|
||||||
return "INVALID_PEER_CONNECTION_STATUS"
|
return "INVALID_PEER_CONNECTION_STATUS"
|
||||||
|
@ -14,7 +14,7 @@ func TestConnStatus_String(t *testing.T) {
|
|||||||
want string
|
want string
|
||||||
}{
|
}{
|
||||||
{"StatusConnected", StatusConnected, "Connected"},
|
{"StatusConnected", StatusConnected, "Connected"},
|
||||||
{"StatusDisconnected", StatusDisconnected, "Disconnected"},
|
{"StatusIdle", StatusIdle, "Idle"},
|
||||||
{"StatusConnecting", StatusConnecting, "Connecting"},
|
{"StatusConnecting", StatusConnecting, "Connecting"},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -24,5 +24,4 @@ func TestConnStatus_String(t *testing.T) {
|
|||||||
assert.Equal(t, got, table.want, "they should be equal")
|
assert.Equal(t, got, table.want, "they should be equal")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package peer
|
package peer
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
@ -11,6 +10,7 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer/ice"
|
"github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
@ -18,6 +18,8 @@ import (
|
|||||||
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var testDispatcher = dispatcher.NewConnectionDispatcher()
|
||||||
|
|
||||||
var connConf = ConnConfig{
|
var connConf = ConnConfig{
|
||||||
Key: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
Key: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||||
LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||||
@ -48,7 +50,13 @@ func TestNewConn_interfaceFilter(t *testing.T) {
|
|||||||
|
|
||||||
func TestConn_GetKey(t *testing.T) {
|
func TestConn_GetKey(t *testing.T) {
|
||||||
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
|
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
|
||||||
conn, err := NewConn(context.Background(), connConf, nil, nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
|
|
||||||
|
sd := ServiceDependencies{
|
||||||
|
SrWatcher: swWatcher,
|
||||||
|
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
|
||||||
|
PeerConnDispatcher: testDispatcher,
|
||||||
|
}
|
||||||
|
conn, err := NewConn(connConf, sd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -60,7 +68,13 @@ func TestConn_GetKey(t *testing.T) {
|
|||||||
|
|
||||||
func TestConn_OnRemoteOffer(t *testing.T) {
|
func TestConn_OnRemoteOffer(t *testing.T) {
|
||||||
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
|
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
|
||||||
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
|
sd := ServiceDependencies{
|
||||||
|
StatusRecorder: NewRecorder("https://mgm"),
|
||||||
|
SrWatcher: swWatcher,
|
||||||
|
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
|
||||||
|
PeerConnDispatcher: testDispatcher,
|
||||||
|
}
|
||||||
|
conn, err := NewConn(connConf, sd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -94,7 +108,13 @@ func TestConn_OnRemoteOffer(t *testing.T) {
|
|||||||
|
|
||||||
func TestConn_OnRemoteAnswer(t *testing.T) {
|
func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||||
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
|
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
|
||||||
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
|
sd := ServiceDependencies{
|
||||||
|
StatusRecorder: NewRecorder("https://mgm"),
|
||||||
|
SrWatcher: swWatcher,
|
||||||
|
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
|
||||||
|
PeerConnDispatcher: testDispatcher,
|
||||||
|
}
|
||||||
|
conn, err := NewConn(connConf, sd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -125,43 +145,6 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
|
|||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
func TestConn_Status(t *testing.T) {
|
|
||||||
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
|
|
||||||
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
tables := []struct {
|
|
||||||
name string
|
|
||||||
statusIce ConnStatus
|
|
||||||
statusRelay ConnStatus
|
|
||||||
want ConnStatus
|
|
||||||
}{
|
|
||||||
{"StatusConnected", StatusConnected, StatusConnected, StatusConnected},
|
|
||||||
{"StatusDisconnected", StatusDisconnected, StatusDisconnected, StatusDisconnected},
|
|
||||||
{"StatusConnecting", StatusConnecting, StatusConnecting, StatusConnecting},
|
|
||||||
{"StatusConnectingIce", StatusConnecting, StatusDisconnected, StatusConnecting},
|
|
||||||
{"StatusConnectingIceAlternative", StatusConnecting, StatusConnected, StatusConnected},
|
|
||||||
{"StatusConnectingRelay", StatusDisconnected, StatusConnecting, StatusConnecting},
|
|
||||||
{"StatusConnectingRelayAlternative", StatusConnected, StatusConnecting, StatusConnected},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, table := range tables {
|
|
||||||
t.Run(table.name, func(t *testing.T) {
|
|
||||||
si := NewAtomicConnStatus()
|
|
||||||
si.Set(table.statusIce)
|
|
||||||
conn.statusICE = si
|
|
||||||
|
|
||||||
sr := NewAtomicConnStatus()
|
|
||||||
sr.Set(table.statusRelay)
|
|
||||||
conn.statusRelay = sr
|
|
||||||
|
|
||||||
got := conn.Status()
|
|
||||||
assert.Equal(t, got, table.want, "they should be equal")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConn_presharedKey(t *testing.T) {
|
func TestConn_presharedKey(t *testing.T) {
|
||||||
conn1 := Conn{
|
conn1 := Conn{
|
||||||
|
29
client/internal/peer/conntype/priority.go
Normal file
29
client/internal/peer/conntype/priority.go
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
package conntype
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
None ConnPriority = 0
|
||||||
|
Relay ConnPriority = 1
|
||||||
|
ICETurn ConnPriority = 2
|
||||||
|
ICEP2P ConnPriority = 3
|
||||||
|
)
|
||||||
|
|
||||||
|
type ConnPriority int
|
||||||
|
|
||||||
|
func (cp ConnPriority) String() string {
|
||||||
|
switch cp {
|
||||||
|
case None:
|
||||||
|
return "None"
|
||||||
|
case Relay:
|
||||||
|
return "PriorityRelay"
|
||||||
|
case ICETurn:
|
||||||
|
return "PriorityICETurn"
|
||||||
|
case ICEP2P:
|
||||||
|
return "PriorityICEP2P"
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("ConnPriority(%d)", cp)
|
||||||
|
}
|
||||||
|
}
|
52
client/internal/peer/dispatcher/dispatcher.go
Normal file
52
client/internal/peer/dispatcher/dispatcher.go
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
package dispatcher
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer/id"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ConnectionListener struct {
|
||||||
|
OnConnected func(peerID id.ConnID)
|
||||||
|
OnDisconnected func(peerID id.ConnID)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ConnectionDispatcher struct {
|
||||||
|
listeners map[*ConnectionListener]struct{}
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConnectionDispatcher() *ConnectionDispatcher {
|
||||||
|
return &ConnectionDispatcher{
|
||||||
|
listeners: make(map[*ConnectionListener]struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ConnectionDispatcher) AddListener(listener *ConnectionListener) {
|
||||||
|
e.mu.Lock()
|
||||||
|
defer e.mu.Unlock()
|
||||||
|
e.listeners[listener] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ConnectionDispatcher) RemoveListener(listener *ConnectionListener) {
|
||||||
|
e.mu.Lock()
|
||||||
|
defer e.mu.Unlock()
|
||||||
|
|
||||||
|
delete(e.listeners, listener)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ConnectionDispatcher) NotifyConnected(peerConnID id.ConnID) {
|
||||||
|
e.mu.Lock()
|
||||||
|
defer e.mu.Unlock()
|
||||||
|
for listener := range e.listeners {
|
||||||
|
listener.OnConnected(peerConnID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ConnectionDispatcher) NotifyDisconnected(peerConnID id.ConnID) {
|
||||||
|
e.mu.Lock()
|
||||||
|
defer e.mu.Unlock()
|
||||||
|
for listener := range e.listeners {
|
||||||
|
listener.OnDisconnected(peerConnID)
|
||||||
|
}
|
||||||
|
}
|
@ -8,10 +8,6 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
reconnectMaxElapsedTime = 30 * time.Minute
|
|
||||||
)
|
|
||||||
|
|
||||||
type isConnectedFunc func() bool
|
type isConnectedFunc func() bool
|
||||||
|
|
||||||
// Guard is responsible for the reconnection logic.
|
// Guard is responsible for the reconnection logic.
|
||||||
@ -25,7 +21,6 @@ type isConnectedFunc func() bool
|
|||||||
type Guard struct {
|
type Guard struct {
|
||||||
Reconnect chan struct{}
|
Reconnect chan struct{}
|
||||||
log *log.Entry
|
log *log.Entry
|
||||||
isController bool
|
|
||||||
isConnectedOnAllWay isConnectedFunc
|
isConnectedOnAllWay isConnectedFunc
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
srWatcher *SRWatcher
|
srWatcher *SRWatcher
|
||||||
@ -33,11 +28,10 @@ type Guard struct {
|
|||||||
iCEConnDisconnected chan struct{}
|
iCEConnDisconnected chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewGuard(log *log.Entry, isController bool, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
|
func NewGuard(log *log.Entry, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
|
||||||
return &Guard{
|
return &Guard{
|
||||||
Reconnect: make(chan struct{}, 1),
|
Reconnect: make(chan struct{}, 1),
|
||||||
log: log,
|
log: log,
|
||||||
isController: isController,
|
|
||||||
isConnectedOnAllWay: isConnectedFn,
|
isConnectedOnAllWay: isConnectedFn,
|
||||||
timeout: timeout,
|
timeout: timeout,
|
||||||
srWatcher: srWatcher,
|
srWatcher: srWatcher,
|
||||||
@ -46,12 +40,8 @@ func NewGuard(log *log.Entry, isController bool, isConnectedFn isConnectedFunc,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *Guard) Start(ctx context.Context) {
|
func (g *Guard) Start(ctx context.Context, eventCallback func()) {
|
||||||
if g.isController {
|
g.reconnectLoopWithRetry(ctx, eventCallback)
|
||||||
g.reconnectLoopWithRetry(ctx)
|
|
||||||
} else {
|
|
||||||
g.listenForDisconnectEvents(ctx)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *Guard) SetRelayedConnDisconnected() {
|
func (g *Guard) SetRelayedConnDisconnected() {
|
||||||
@ -68,9 +58,9 @@ func (g *Guard) SetICEConnDisconnected() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// reconnectLoopWithRetry periodically check (max 30 min) the connection status.
|
// reconnectLoopWithRetry periodically check the connection status.
|
||||||
// Try to send offer while the P2P is not established or while the Relay is not connected if is it supported
|
// Try to send offer while the P2P is not established or while the Relay is not connected if is it supported
|
||||||
func (g *Guard) reconnectLoopWithRetry(ctx context.Context) {
|
func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
|
||||||
waitForInitialConnectionTry(ctx)
|
waitForInitialConnectionTry(ctx)
|
||||||
|
|
||||||
srReconnectedChan := g.srWatcher.NewListener()
|
srReconnectedChan := g.srWatcher.NewListener()
|
||||||
@ -93,7 +83,7 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !g.isConnectedOnAllWay() {
|
if !g.isConnectedOnAllWay() {
|
||||||
g.triggerOfferSending()
|
callback()
|
||||||
}
|
}
|
||||||
|
|
||||||
case <-g.relayedConnDisconnected:
|
case <-g.relayedConnDisconnected:
|
||||||
@ -121,39 +111,12 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// listenForDisconnectEvents is used when the peer is not a controller and it should reconnect to the peer
|
|
||||||
// when the connection is lost. It will try to establish a connection only once time if before the connection was established
|
|
||||||
// It track separately the ice and relay connection status. Just because a lower priority connection reestablished it does not
|
|
||||||
// mean that to switch to it. We always force to use the higher priority connection.
|
|
||||||
func (g *Guard) listenForDisconnectEvents(ctx context.Context) {
|
|
||||||
srReconnectedChan := g.srWatcher.NewListener()
|
|
||||||
defer g.srWatcher.RemoveListener(srReconnectedChan)
|
|
||||||
|
|
||||||
g.log.Infof("start listen for reconnect events...")
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-g.relayedConnDisconnected:
|
|
||||||
g.log.Debugf("Relay connection changed, triggering reconnect")
|
|
||||||
g.triggerOfferSending()
|
|
||||||
case <-g.iCEConnDisconnected:
|
|
||||||
g.log.Debugf("ICE state changed, try to send new offer")
|
|
||||||
g.triggerOfferSending()
|
|
||||||
case <-srReconnectedChan:
|
|
||||||
g.triggerOfferSending()
|
|
||||||
case <-ctx.Done():
|
|
||||||
g.log.Debugf("context is done, stop reconnect loop")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker {
|
func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker {
|
||||||
bo := backoff.WithContext(&backoff.ExponentialBackOff{
|
bo := backoff.WithContext(&backoff.ExponentialBackOff{
|
||||||
InitialInterval: 800 * time.Millisecond,
|
InitialInterval: 800 * time.Millisecond,
|
||||||
RandomizationFactor: 0.1,
|
RandomizationFactor: 0.1,
|
||||||
Multiplier: 2,
|
Multiplier: 2,
|
||||||
MaxInterval: g.timeout,
|
MaxInterval: g.timeout,
|
||||||
MaxElapsedTime: reconnectMaxElapsedTime,
|
|
||||||
Stop: backoff.Stop,
|
Stop: backoff.Stop,
|
||||||
Clock: backoff.SystemClock,
|
Clock: backoff.SystemClock,
|
||||||
}, ctx)
|
}, ctx)
|
||||||
@ -164,13 +127,6 @@ func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker {
|
|||||||
return ticker
|
return ticker
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *Guard) triggerOfferSending() {
|
|
||||||
select {
|
|
||||||
case g.Reconnect <- struct{}{}:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Give chance to the peer to establish the initial connection.
|
// Give chance to the peer to establish the initial connection.
|
||||||
// With it, we can decrease to send necessary offer
|
// With it, we can decrease to send necessary offer
|
||||||
func waitForInitialConnectionTry(ctx context.Context) {
|
func waitForInitialConnectionTry(ctx context.Context) {
|
||||||
|
@ -43,7 +43,6 @@ type OfferAnswer struct {
|
|||||||
|
|
||||||
type Handshaker struct {
|
type Handshaker struct {
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
ctx context.Context
|
|
||||||
log *log.Entry
|
log *log.Entry
|
||||||
config ConnConfig
|
config ConnConfig
|
||||||
signaler *Signaler
|
signaler *Signaler
|
||||||
@ -57,9 +56,8 @@ type Handshaker struct {
|
|||||||
remoteAnswerCh chan OfferAnswer
|
remoteAnswerCh chan OfferAnswer
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHandshaker(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay) *Handshaker {
|
func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay) *Handshaker {
|
||||||
return &Handshaker{
|
return &Handshaker{
|
||||||
ctx: ctx,
|
|
||||||
log: log,
|
log: log,
|
||||||
config: config,
|
config: config,
|
||||||
signaler: signaler,
|
signaler: signaler,
|
||||||
@ -74,10 +72,10 @@ func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAn
|
|||||||
h.onNewOfferListeners = append(h.onNewOfferListeners, offer)
|
h.onNewOfferListeners = append(h.onNewOfferListeners, offer)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handshaker) Listen() {
|
func (h *Handshaker) Listen(ctx context.Context) {
|
||||||
for {
|
for {
|
||||||
h.log.Info("wait for remote offer confirmation")
|
h.log.Info("wait for remote offer confirmation")
|
||||||
remoteOfferAnswer, err := h.waitForRemoteOfferConfirmation()
|
remoteOfferAnswer, err := h.waitForRemoteOfferConfirmation(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var connectionClosedError *ConnectionClosedError
|
var connectionClosedError *ConnectionClosedError
|
||||||
if errors.As(err, &connectionClosedError) {
|
if errors.As(err, &connectionClosedError) {
|
||||||
@ -127,7 +125,7 @@ func (h *Handshaker) OnRemoteAnswer(answer OfferAnswer) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handshaker) waitForRemoteOfferConfirmation() (*OfferAnswer, error) {
|
func (h *Handshaker) waitForRemoteOfferConfirmation(ctx context.Context) (*OfferAnswer, error) {
|
||||||
select {
|
select {
|
||||||
case remoteOfferAnswer := <-h.remoteOffersCh:
|
case remoteOfferAnswer := <-h.remoteOffersCh:
|
||||||
// received confirmation from the remote peer -> ready to proceed
|
// received confirmation from the remote peer -> ready to proceed
|
||||||
@ -137,7 +135,7 @@ func (h *Handshaker) waitForRemoteOfferConfirmation() (*OfferAnswer, error) {
|
|||||||
return &remoteOfferAnswer, nil
|
return &remoteOfferAnswer, nil
|
||||||
case remoteOfferAnswer := <-h.remoteAnswerCh:
|
case remoteOfferAnswer := <-h.remoteAnswerCh:
|
||||||
return &remoteOfferAnswer, nil
|
return &remoteOfferAnswer, nil
|
||||||
case <-h.ctx.Done():
|
case <-ctx.Done():
|
||||||
// closed externally
|
// closed externally
|
||||||
return nil, NewConnectionClosedError(h.config.Key)
|
return nil, NewConnectionClosedError(h.config.Key)
|
||||||
}
|
}
|
||||||
|
5
client/internal/peer/id/connid.go
Normal file
5
client/internal/peer/id/connid.go
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
package id
|
||||||
|
|
||||||
|
import "unsafe"
|
||||||
|
|
||||||
|
type ConnID unsafe.Pointer
|
@ -15,7 +15,7 @@ import (
|
|||||||
type WGIface interface {
|
type WGIface interface {
|
||||||
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||||
RemovePeer(peerKey string) error
|
RemovePeer(peerKey string) error
|
||||||
GetStats(peerKey string) (configurer.WGStats, error)
|
GetStats() (map[string]configurer.WGStats, error)
|
||||||
GetProxy() wgproxy.Proxy
|
GetProxy() wgproxy.Proxy
|
||||||
Address() wgaddr.Address
|
Address() wgaddr.Address
|
||||||
}
|
}
|
||||||
|
@ -135,14 +135,15 @@ type NSGroupState struct {
|
|||||||
|
|
||||||
// FullStatus contains the full state held by the Status instance
|
// FullStatus contains the full state held by the Status instance
|
||||||
type FullStatus struct {
|
type FullStatus struct {
|
||||||
Peers []State
|
Peers []State
|
||||||
ManagementState ManagementState
|
ManagementState ManagementState
|
||||||
SignalState SignalState
|
SignalState SignalState
|
||||||
LocalPeerState LocalPeerState
|
LocalPeerState LocalPeerState
|
||||||
RosenpassState RosenpassState
|
RosenpassState RosenpassState
|
||||||
Relays []relay.ProbeResult
|
Relays []relay.ProbeResult
|
||||||
NSGroupStates []NSGroupState
|
NSGroupStates []NSGroupState
|
||||||
NumOfForwardingRules int
|
NumOfForwardingRules int
|
||||||
|
LazyConnectionEnabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// Status holds a state of peers, signal, management connections and relays
|
// Status holds a state of peers, signal, management connections and relays
|
||||||
@ -164,6 +165,7 @@ type Status struct {
|
|||||||
rosenpassPermissive bool
|
rosenpassPermissive bool
|
||||||
nsGroupStates []NSGroupState
|
nsGroupStates []NSGroupState
|
||||||
resolvedDomainsStates map[domain.Domain]ResolvedDomainInfo
|
resolvedDomainsStates map[domain.Domain]ResolvedDomainInfo
|
||||||
|
lazyConnectionEnabled bool
|
||||||
|
|
||||||
// To reduce the number of notification invocation this bool will be true when need to call the notification
|
// To reduce the number of notification invocation this bool will be true when need to call the notification
|
||||||
// Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events
|
// Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events
|
||||||
@ -219,7 +221,7 @@ func (d *Status) ReplaceOfflinePeers(replacement []State) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddPeer adds peer to Daemon status map
|
// AddPeer adds peer to Daemon status map
|
||||||
func (d *Status) AddPeer(peerPubKey string, fqdn string) error {
|
func (d *Status) AddPeer(peerPubKey string, fqdn string, ip string) error {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
@ -229,7 +231,8 @@ func (d *Status) AddPeer(peerPubKey string, fqdn string) error {
|
|||||||
}
|
}
|
||||||
d.peers[peerPubKey] = State{
|
d.peers[peerPubKey] = State{
|
||||||
PubKey: peerPubKey,
|
PubKey: peerPubKey,
|
||||||
ConnStatus: StatusDisconnected,
|
IP: ip,
|
||||||
|
ConnStatus: StatusIdle,
|
||||||
FQDN: fqdn,
|
FQDN: fqdn,
|
||||||
Mux: new(sync.RWMutex),
|
Mux: new(sync.RWMutex),
|
||||||
}
|
}
|
||||||
@ -511,9 +514,9 @@ func shouldSkipNotify(receivedConnStatus ConnStatus, curr State) bool {
|
|||||||
switch {
|
switch {
|
||||||
case receivedConnStatus == StatusConnecting:
|
case receivedConnStatus == StatusConnecting:
|
||||||
return true
|
return true
|
||||||
case receivedConnStatus == StatusDisconnected && curr.ConnStatus == StatusConnecting:
|
case receivedConnStatus == StatusIdle && curr.ConnStatus == StatusConnecting:
|
||||||
return true
|
return true
|
||||||
case receivedConnStatus == StatusDisconnected && curr.ConnStatus == StatusDisconnected:
|
case receivedConnStatus == StatusIdle && curr.ConnStatus == StatusIdle:
|
||||||
return curr.IP != ""
|
return curr.IP != ""
|
||||||
default:
|
default:
|
||||||
return false
|
return false
|
||||||
@ -689,6 +692,12 @@ func (d *Status) UpdateRosenpass(rosenpassEnabled, rosenpassPermissive bool) {
|
|||||||
d.rosenpassEnabled = rosenpassEnabled
|
d.rosenpassEnabled = rosenpassEnabled
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *Status) UpdateLazyConnection(enabled bool) {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
d.lazyConnectionEnabled = enabled
|
||||||
|
}
|
||||||
|
|
||||||
// MarkSignalDisconnected sets SignalState to disconnected
|
// MarkSignalDisconnected sets SignalState to disconnected
|
||||||
func (d *Status) MarkSignalDisconnected(err error) {
|
func (d *Status) MarkSignalDisconnected(err error) {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
@ -761,6 +770,12 @@ func (d *Status) GetRosenpassState() RosenpassState {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *Status) GetLazyConnection() bool {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
return d.lazyConnectionEnabled
|
||||||
|
}
|
||||||
|
|
||||||
func (d *Status) GetManagementState() ManagementState {
|
func (d *Status) GetManagementState() ManagementState {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.Unlock()
|
||||||
@ -872,12 +887,13 @@ func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo
|
|||||||
// GetFullStatus gets full status
|
// GetFullStatus gets full status
|
||||||
func (d *Status) GetFullStatus() FullStatus {
|
func (d *Status) GetFullStatus() FullStatus {
|
||||||
fullStatus := FullStatus{
|
fullStatus := FullStatus{
|
||||||
ManagementState: d.GetManagementState(),
|
ManagementState: d.GetManagementState(),
|
||||||
SignalState: d.GetSignalState(),
|
SignalState: d.GetSignalState(),
|
||||||
Relays: d.GetRelayStates(),
|
Relays: d.GetRelayStates(),
|
||||||
RosenpassState: d.GetRosenpassState(),
|
RosenpassState: d.GetRosenpassState(),
|
||||||
NSGroupStates: d.GetDNSStates(),
|
NSGroupStates: d.GetDNSStates(),
|
||||||
NumOfForwardingRules: len(d.ForwardingRules()),
|
NumOfForwardingRules: len(d.ForwardingRules()),
|
||||||
|
LazyConnectionEnabled: d.GetLazyConnection(),
|
||||||
}
|
}
|
||||||
|
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
|
@ -10,22 +10,24 @@ import (
|
|||||||
|
|
||||||
func TestAddPeer(t *testing.T) {
|
func TestAddPeer(t *testing.T) {
|
||||||
key := "abc"
|
key := "abc"
|
||||||
|
ip := "100.108.254.1"
|
||||||
status := NewRecorder("https://mgm")
|
status := NewRecorder("https://mgm")
|
||||||
err := status.AddPeer(key, "abc.netbird")
|
err := status.AddPeer(key, "abc.netbird", ip)
|
||||||
assert.NoError(t, err, "shouldn't return error")
|
assert.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
_, exists := status.peers[key]
|
_, exists := status.peers[key]
|
||||||
assert.True(t, exists, "value was found")
|
assert.True(t, exists, "value was found")
|
||||||
|
|
||||||
err = status.AddPeer(key, "abc.netbird")
|
err = status.AddPeer(key, "abc.netbird", ip)
|
||||||
|
|
||||||
assert.Error(t, err, "should return error on duplicate")
|
assert.Error(t, err, "should return error on duplicate")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetPeer(t *testing.T) {
|
func TestGetPeer(t *testing.T) {
|
||||||
key := "abc"
|
key := "abc"
|
||||||
|
ip := "100.108.254.1"
|
||||||
status := NewRecorder("https://mgm")
|
status := NewRecorder("https://mgm")
|
||||||
err := status.AddPeer(key, "abc.netbird")
|
err := status.AddPeer(key, "abc.netbird", ip)
|
||||||
assert.NoError(t, err, "shouldn't return error")
|
assert.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
peerStatus, err := status.GetPeer(key)
|
peerStatus, err := status.GetPeer(key)
|
||||||
|
@ -2,6 +2,7 @@ package peer
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -20,7 +21,7 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type WGInterfaceStater interface {
|
type WGInterfaceStater interface {
|
||||||
GetStats(key string) (configurer.WGStats, error)
|
GetStats() (map[string]configurer.WGStats, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type WGWatcher struct {
|
type WGWatcher struct {
|
||||||
@ -146,9 +147,13 @@ func (w *WGWatcher) handshakeCheck(lastHandshake time.Time) (*time.Time, bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (w *WGWatcher) wgState() (time.Time, error) {
|
func (w *WGWatcher) wgState() (time.Time, error) {
|
||||||
wgState, err := w.wgIfaceStater.GetStats(w.peerKey)
|
wgStates, err := w.wgIfaceStater.GetStats()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return time.Time{}, err
|
return time.Time{}, err
|
||||||
}
|
}
|
||||||
|
wgState, ok := wgStates[w.peerKey]
|
||||||
|
if !ok {
|
||||||
|
return time.Time{}, fmt.Errorf("peer %s not found in WireGuard endpoints", w.peerKey)
|
||||||
|
}
|
||||||
return wgState.LastHandshake, nil
|
return wgState.LastHandshake, nil
|
||||||
}
|
}
|
||||||
|
@ -11,26 +11,11 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type MocWgIface struct {
|
type MocWgIface struct {
|
||||||
initial bool
|
stop bool
|
||||||
lastHandshake time.Time
|
|
||||||
stop bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MocWgIface) GetStats(key string) (configurer.WGStats, error) {
|
func (m *MocWgIface) GetStats() (map[string]configurer.WGStats, error) {
|
||||||
if !m.initial {
|
return map[string]configurer.WGStats{}, nil
|
||||||
m.initial = true
|
|
||||||
return configurer.WGStats{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if !m.stop {
|
|
||||||
m.lastHandshake = time.Now()
|
|
||||||
}
|
|
||||||
|
|
||||||
stats := configurer.WGStats{
|
|
||||||
LastHandshake: m.lastHandshake,
|
|
||||||
}
|
|
||||||
|
|
||||||
return stats, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MocWgIface) disconnect() {
|
func (m *MocWgIface) disconnect() {
|
||||||
|
55
client/internal/peer/worker/state.go
Normal file
55
client/internal/peer/worker/state.go
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
package worker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
StatusDisconnected Status = iota
|
||||||
|
StatusConnected
|
||||||
|
)
|
||||||
|
|
||||||
|
type Status int32
|
||||||
|
|
||||||
|
func (s Status) String() string {
|
||||||
|
switch s {
|
||||||
|
case StatusDisconnected:
|
||||||
|
return "Disconnected"
|
||||||
|
case StatusConnected:
|
||||||
|
return "Connected"
|
||||||
|
default:
|
||||||
|
log.Errorf("unknown status: %d", s)
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AtomicWorkerStatus is a thread-safe wrapper for worker status
|
||||||
|
type AtomicWorkerStatus struct {
|
||||||
|
status atomic.Int32
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAtomicStatus() *AtomicWorkerStatus {
|
||||||
|
acs := &AtomicWorkerStatus{}
|
||||||
|
acs.SetDisconnected()
|
||||||
|
return acs
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns the current connection status
|
||||||
|
func (acs *AtomicWorkerStatus) Get() Status {
|
||||||
|
return Status(acs.status.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (acs *AtomicWorkerStatus) SetConnected() {
|
||||||
|
acs.status.Store(int32(StatusConnected))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (acs *AtomicWorkerStatus) SetDisconnected() {
|
||||||
|
acs.status.Store(int32(StatusDisconnected))
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns the string representation of the current status
|
||||||
|
func (acs *AtomicWorkerStatus) String() string {
|
||||||
|
return acs.Get().String()
|
||||||
|
}
|
@ -14,6 +14,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer/conntype"
|
||||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
@ -397,10 +398,10 @@ func isRelayed(pair *ice.CandidatePair) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func selectedPriority(pair *ice.CandidatePair) ConnPriority {
|
func selectedPriority(pair *ice.CandidatePair) conntype.ConnPriority {
|
||||||
if isRelayed(pair) {
|
if isRelayed(pair) {
|
||||||
return connPriorityICETurn
|
return conntype.ICETurn
|
||||||
} else {
|
} else {
|
||||||
return connPriorityICEP2P
|
return conntype.ICEP2P
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package peerstore
|
package peerstore
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@ -79,6 +80,32 @@ func (s *Store) PeerConn(pubKey string) (*peer.Conn, bool) {
|
|||||||
return p, true
|
return p, true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Store) PeerConnOpen(ctx context.Context, pubKey string) {
|
||||||
|
s.peerConnsMu.RLock()
|
||||||
|
defer s.peerConnsMu.RUnlock()
|
||||||
|
|
||||||
|
p, ok := s.peerConns[pubKey]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// this can be blocked because of the connect open limiter semaphore
|
||||||
|
if err := p.Open(ctx); err != nil {
|
||||||
|
p.Log.Errorf("failed to open peer connection: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) PeerConnClose(pubKey string) {
|
||||||
|
s.peerConnsMu.RLock()
|
||||||
|
defer s.peerConnsMu.RUnlock()
|
||||||
|
|
||||||
|
p, ok := s.peerConns[pubKey]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.Close()
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Store) PeersPubKey() []string {
|
func (s *Store) PeersPubKey() []string {
|
||||||
s.peerConnsMu.RLock()
|
s.peerConnsMu.RLock()
|
||||||
defer s.peerConnsMu.RUnlock()
|
defer s.peerConnsMu.RUnlock()
|
||||||
|
@ -12,6 +12,7 @@ import (
|
|||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
mgm "github.com/netbirdio/netbird/management/client"
|
mgm "github.com/netbirdio/netbird/management/client"
|
||||||
|
"github.com/netbirdio/netbird/management/client/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
// PKCEAuthorizationFlow represents PKCE Authorization Flow information
|
// PKCEAuthorizationFlow represents PKCE Authorization Flow information
|
||||||
@ -41,6 +42,8 @@ type PKCEAuthProviderConfig struct {
|
|||||||
ClientCertPair *tls.Certificate
|
ClientCertPair *tls.Certificate
|
||||||
// DisablePromptLogin makes the PKCE flow to not prompt the user for login
|
// DisablePromptLogin makes the PKCE flow to not prompt the user for login
|
||||||
DisablePromptLogin bool
|
DisablePromptLogin bool
|
||||||
|
// LoginFlag is used to configure the PKCE flow login behavior
|
||||||
|
LoginFlag common.LoginFlag
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it
|
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it
|
||||||
@ -100,6 +103,7 @@ func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL
|
|||||||
UseIDToken: protoPKCEAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
|
UseIDToken: protoPKCEAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
|
||||||
ClientCertPair: clientCert,
|
ClientCertPair: clientCert,
|
||||||
DisablePromptLogin: protoPKCEAuthorizationFlow.GetProviderConfig().GetDisablePromptLogin(),
|
DisablePromptLogin: protoPKCEAuthorizationFlow.GetProviderConfig().GetDisablePromptLogin(),
|
||||||
|
LoginFlag: common.LoginFlag(protoPKCEAuthorizationFlow.GetProviderConfig().GetLoginFlag()),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3,7 +3,6 @@ package iface
|
|||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
@ -18,5 +17,4 @@ type wgIfaceBase interface {
|
|||||||
IsUserspaceBind() bool
|
IsUserspaceBind() bool
|
||||||
GetFilter() device.PacketFilter
|
GetFilter() device.PacketFilter
|
||||||
GetDevice() *device.FilteredDevice
|
GetDevice() *device.FilteredDevice
|
||||||
GetStats(peerKey string) (configurer.WGStats, error)
|
|
||||||
}
|
}
|
||||||
|
@ -32,6 +32,10 @@ func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
|
|||||||
func (n *Notifier) SetInitialClientRoutes(clientRoutes []*route.Route) {
|
func (n *Notifier) SetInitialClientRoutes(clientRoutes []*route.Route) {
|
||||||
nets := make([]string, 0)
|
nets := make([]string, 0)
|
||||||
for _, r := range clientRoutes {
|
for _, r := range clientRoutes {
|
||||||
|
// filter out domain routes
|
||||||
|
if r.IsDynamic() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
nets = append(nets, r.Network.String())
|
nets = append(nets, r.Network.String())
|
||||||
}
|
}
|
||||||
sort.Strings(nets)
|
sort.Strings(nets)
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package systemops
|
package systemops
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
@ -15,6 +16,20 @@ type Nexthop struct {
|
|||||||
Intf *net.Interface
|
Intf *net.Interface
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Equal checks if two nexthops are equal.
|
||||||
|
func (n Nexthop) Equal(other Nexthop) bool {
|
||||||
|
return n.IP == other.IP && (n.Intf == nil && other.Intf == nil ||
|
||||||
|
n.Intf != nil && other.Intf != nil && n.Intf.Index == other.Intf.Index)
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns a string representation of the nexthop.
|
||||||
|
func (n Nexthop) String() string {
|
||||||
|
if n.Intf == nil {
|
||||||
|
return n.IP.String()
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s @ %d (%s)", n.IP.String(), n.Intf.Index, n.Intf.Name)
|
||||||
|
}
|
||||||
|
|
||||||
type ExclusionCounter = refcounter.Counter[netip.Prefix, struct{}, Nexthop]
|
type ExclusionCounter = refcounter.Counter[netip.Prefix, struct{}, Nexthop]
|
||||||
|
|
||||||
type SysOps struct {
|
type SysOps struct {
|
||||||
|
@ -33,8 +33,7 @@ type RouteUpdateType int
|
|||||||
type RouteUpdate struct {
|
type RouteUpdate struct {
|
||||||
Type RouteUpdateType
|
Type RouteUpdateType
|
||||||
Destination netip.Prefix
|
Destination netip.Prefix
|
||||||
NextHop netip.Addr
|
NextHop Nexthop
|
||||||
Interface *net.Interface
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RouteMonitor provides a way to monitor changes in the routing table.
|
// RouteMonitor provides a way to monitor changes in the routing table.
|
||||||
@ -231,15 +230,15 @@ func (rm *RouteMonitor) parseUpdate(row *MIB_IPFORWARD_ROW2, notificationType MI
|
|||||||
intf, err := net.InterfaceByIndex(idx)
|
intf, err := net.InterfaceByIndex(idx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to get interface name for index %d: %v", idx, err)
|
log.Warnf("failed to get interface name for index %d: %v", idx, err)
|
||||||
update.Interface = &net.Interface{
|
update.NextHop.Intf = &net.Interface{
|
||||||
Index: idx,
|
Index: idx,
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
update.Interface = intf
|
update.NextHop.Intf = intf
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Tracef("Received route update with destination %v, next hop %v, interface %v", row.DestinationPrefix, row.NextHop, update.Interface)
|
log.Tracef("Received route update with destination %v, next hop %v, interface %v", row.DestinationPrefix, row.NextHop, update.NextHop.Intf)
|
||||||
dest := parseIPPrefix(row.DestinationPrefix, idx)
|
dest := parseIPPrefix(row.DestinationPrefix, idx)
|
||||||
if !dest.Addr().IsValid() {
|
if !dest.Addr().IsValid() {
|
||||||
return RouteUpdate{}, fmt.Errorf("invalid destination: %v", row)
|
return RouteUpdate{}, fmt.Errorf("invalid destination: %v", row)
|
||||||
@ -262,7 +261,7 @@ func (rm *RouteMonitor) parseUpdate(row *MIB_IPFORWARD_ROW2, notificationType MI
|
|||||||
|
|
||||||
update.Type = updateType
|
update.Type = updateType
|
||||||
update.Destination = dest
|
update.Destination = dest
|
||||||
update.NextHop = nexthop
|
update.NextHop.IP = nexthop
|
||||||
|
|
||||||
return update, nil
|
return update, nil
|
||||||
}
|
}
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -94,7 +94,7 @@ message LoginRequest {
|
|||||||
|
|
||||||
bytes customDNSAddress = 7;
|
bytes customDNSAddress = 7;
|
||||||
|
|
||||||
bool isLinuxDesktopClient = 8;
|
bool isUnixDesktopClient = 8;
|
||||||
|
|
||||||
string hostname = 9;
|
string hostname = 9;
|
||||||
|
|
||||||
@ -134,6 +134,7 @@ message LoginRequest {
|
|||||||
// omits initialized empty slices due to omitempty tags
|
// omits initialized empty slices due to omitempty tags
|
||||||
bool cleanDNSLabels = 27;
|
bool cleanDNSLabels = 27;
|
||||||
|
|
||||||
|
optional bool lazyConnectionEnabled = 28;
|
||||||
}
|
}
|
||||||
|
|
||||||
message LoginResponse {
|
message LoginResponse {
|
||||||
@ -274,6 +275,8 @@ message FullStatus {
|
|||||||
int32 NumberOfForwardingRules = 8;
|
int32 NumberOfForwardingRules = 8;
|
||||||
|
|
||||||
repeated SystemEvent events = 7;
|
repeated SystemEvent events = 7;
|
||||||
|
|
||||||
|
bool lazyConnectionEnabled = 9;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Networks
|
// Networks
|
||||||
|
@ -139,6 +139,7 @@ func (s *Server) Start() error {
|
|||||||
|
|
||||||
s.statusRecorder.UpdateManagementAddress(config.ManagementURL.String())
|
s.statusRecorder.UpdateManagementAddress(config.ManagementURL.String())
|
||||||
s.statusRecorder.UpdateRosenpass(config.RosenpassEnabled, config.RosenpassPermissive)
|
s.statusRecorder.UpdateRosenpass(config.RosenpassEnabled, config.RosenpassPermissive)
|
||||||
|
s.statusRecorder.UpdateLazyConnection(config.LazyConnectionEnabled)
|
||||||
|
|
||||||
if s.sessionWatcher == nil {
|
if s.sessionWatcher == nil {
|
||||||
s.sessionWatcher = internal.NewSessionWatcher(s.rootCtx, s.statusRecorder)
|
s.sessionWatcher = internal.NewSessionWatcher(s.rootCtx, s.statusRecorder)
|
||||||
@ -417,6 +418,11 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
|||||||
s.latestConfigInput.DisableNotifications = msg.DisableNotifications
|
s.latestConfigInput.DisableNotifications = msg.DisableNotifications
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if msg.LazyConnectionEnabled != nil {
|
||||||
|
inputConfig.LazyConnectionEnabled = msg.LazyConnectionEnabled
|
||||||
|
s.latestConfigInput.LazyConnectionEnabled = msg.LazyConnectionEnabled
|
||||||
|
}
|
||||||
|
|
||||||
s.mutex.Unlock()
|
s.mutex.Unlock()
|
||||||
|
|
||||||
if msg.OptionalPreSharedKey != nil {
|
if msg.OptionalPreSharedKey != nil {
|
||||||
@ -446,7 +452,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
|||||||
state.Set(internal.StatusConnecting)
|
state.Set(internal.StatusConnecting)
|
||||||
|
|
||||||
if msg.SetupKey == "" {
|
if msg.SetupKey == "" {
|
||||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsLinuxDesktopClient)
|
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
state.Set(internal.StatusLoginFailed)
|
state.Set(internal.StatusLoginFailed)
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -804,6 +810,7 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
|
|||||||
pbFullStatus.LocalPeerState.RosenpassEnabled = fullStatus.RosenpassState.Enabled
|
pbFullStatus.LocalPeerState.RosenpassEnabled = fullStatus.RosenpassState.Enabled
|
||||||
pbFullStatus.LocalPeerState.Networks = maps.Keys(fullStatus.LocalPeerState.Routes)
|
pbFullStatus.LocalPeerState.Networks = maps.Keys(fullStatus.LocalPeerState.Routes)
|
||||||
pbFullStatus.NumberOfForwardingRules = int32(fullStatus.NumOfForwardingRules)
|
pbFullStatus.NumberOfForwardingRules = int32(fullStatus.NumOfForwardingRules)
|
||||||
|
pbFullStatus.LazyConnectionEnabled = fullStatus.LazyConnectionEnabled
|
||||||
|
|
||||||
for _, peerState := range fullStatus.Peers {
|
for _, peerState := range fullStatus.Peers {
|
||||||
pbPeerState := &proto.PeerState{
|
pbPeerState := &proto.PeerState{
|
||||||
|
@ -97,6 +97,7 @@ type OutputOverview struct {
|
|||||||
NumberOfForwardingRules int `json:"forwardingRules" yaml:"forwardingRules"`
|
NumberOfForwardingRules int `json:"forwardingRules" yaml:"forwardingRules"`
|
||||||
NSServerGroups []NsServerGroupStateOutput `json:"dnsServers" yaml:"dnsServers"`
|
NSServerGroups []NsServerGroupStateOutput `json:"dnsServers" yaml:"dnsServers"`
|
||||||
Events []SystemEventOutput `json:"events" yaml:"events"`
|
Events []SystemEventOutput `json:"events" yaml:"events"`
|
||||||
|
LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}) OutputOverview {
|
func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}) OutputOverview {
|
||||||
@ -136,6 +137,7 @@ func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, status
|
|||||||
NumberOfForwardingRules: int(pbFullStatus.GetNumberOfForwardingRules()),
|
NumberOfForwardingRules: int(pbFullStatus.GetNumberOfForwardingRules()),
|
||||||
NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()),
|
NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()),
|
||||||
Events: mapEvents(pbFullStatus.GetEvents()),
|
Events: mapEvents(pbFullStatus.GetEvents()),
|
||||||
|
LazyConnectionEnabled: pbFullStatus.GetLazyConnectionEnabled(),
|
||||||
}
|
}
|
||||||
|
|
||||||
if anon {
|
if anon {
|
||||||
@ -206,7 +208,7 @@ func mapPeers(
|
|||||||
transferSent := int64(0)
|
transferSent := int64(0)
|
||||||
|
|
||||||
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
|
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
|
||||||
if skipDetailByFilters(pbPeerState, isPeerConnected, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter) {
|
if skipDetailByFilters(pbPeerState, pbPeerState.ConnStatus, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if isPeerConnected {
|
if isPeerConnected {
|
||||||
@ -384,6 +386,11 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
lazyConnectionEnabledStatus := "false"
|
||||||
|
if overview.LazyConnectionEnabled {
|
||||||
|
lazyConnectionEnabledStatus = "true"
|
||||||
|
}
|
||||||
|
|
||||||
peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
|
peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
|
||||||
|
|
||||||
goos := runtime.GOOS
|
goos := runtime.GOOS
|
||||||
@ -405,6 +412,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
|
|||||||
"NetBird IP: %s\n"+
|
"NetBird IP: %s\n"+
|
||||||
"Interface type: %s\n"+
|
"Interface type: %s\n"+
|
||||||
"Quantum resistance: %s\n"+
|
"Quantum resistance: %s\n"+
|
||||||
|
"Lazy connection: %s\n"+
|
||||||
"Networks: %s\n"+
|
"Networks: %s\n"+
|
||||||
"Forwarding rules: %d\n"+
|
"Forwarding rules: %d\n"+
|
||||||
"Peers count: %s\n",
|
"Peers count: %s\n",
|
||||||
@ -419,6 +427,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
|
|||||||
interfaceIP,
|
interfaceIP,
|
||||||
interfaceTypeString,
|
interfaceTypeString,
|
||||||
rosenpassEnabledStatus,
|
rosenpassEnabledStatus,
|
||||||
|
lazyConnectionEnabledStatus,
|
||||||
networks,
|
networks,
|
||||||
overview.NumberOfForwardingRules,
|
overview.NumberOfForwardingRules,
|
||||||
peersCountString,
|
peersCountString,
|
||||||
@ -533,23 +542,13 @@ func parsePeers(peers PeersStateOutput, rosenpassEnabled, rosenpassPermissive bo
|
|||||||
return peersString
|
return peersString
|
||||||
}
|
}
|
||||||
|
|
||||||
func skipDetailByFilters(
|
func skipDetailByFilters(peerState *proto.PeerState, peerStatus string, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}) bool {
|
||||||
peerState *proto.PeerState,
|
|
||||||
isConnected bool,
|
|
||||||
statusFilter string,
|
|
||||||
prefixNamesFilter []string,
|
|
||||||
prefixNamesFilterMap map[string]struct{},
|
|
||||||
ipsFilter map[string]struct{},
|
|
||||||
) bool {
|
|
||||||
statusEval := false
|
statusEval := false
|
||||||
ipEval := false
|
ipEval := false
|
||||||
nameEval := true
|
nameEval := true
|
||||||
|
|
||||||
if statusFilter != "" {
|
if statusFilter != "" {
|
||||||
lowerStatusFilter := strings.ToLower(statusFilter)
|
if !strings.EqualFold(peerStatus, statusFilter) {
|
||||||
if lowerStatusFilter == "disconnected" && isConnected {
|
|
||||||
statusEval = true
|
|
||||||
} else if lowerStatusFilter == "connected" && !isConnected {
|
|
||||||
statusEval = true
|
statusEval = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -383,7 +383,8 @@ func TestParsingToJSON(t *testing.T) {
|
|||||||
"error": "timeout"
|
"error": "timeout"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"events": []
|
"events": [],
|
||||||
|
"lazyConnectionEnabled": false
|
||||||
}`
|
}`
|
||||||
// @formatter:on
|
// @formatter:on
|
||||||
|
|
||||||
@ -484,6 +485,7 @@ dnsServers:
|
|||||||
enabled: false
|
enabled: false
|
||||||
error: timeout
|
error: timeout
|
||||||
events: []
|
events: []
|
||||||
|
lazyConnectionEnabled: false
|
||||||
`
|
`
|
||||||
|
|
||||||
assert.Equal(t, expectedYAML, yaml)
|
assert.Equal(t, expectedYAML, yaml)
|
||||||
@ -548,6 +550,7 @@ FQDN: some-localhost.awesome-domain.com
|
|||||||
NetBird IP: 192.168.178.100/16
|
NetBird IP: 192.168.178.100/16
|
||||||
Interface type: Kernel
|
Interface type: Kernel
|
||||||
Quantum resistance: false
|
Quantum resistance: false
|
||||||
|
Lazy connection: false
|
||||||
Networks: 10.10.0.0/24
|
Networks: 10.10.0.0/24
|
||||||
Forwarding rules: 0
|
Forwarding rules: 0
|
||||||
Peers count: 2/2 Connected
|
Peers count: 2/2 Connected
|
||||||
@ -570,6 +573,7 @@ FQDN: some-localhost.awesome-domain.com
|
|||||||
NetBird IP: 192.168.178.100/16
|
NetBird IP: 192.168.178.100/16
|
||||||
Interface type: Kernel
|
Interface type: Kernel
|
||||||
Quantum resistance: false
|
Quantum resistance: false
|
||||||
|
Lazy connection: false
|
||||||
Networks: 10.10.0.0/24
|
Networks: 10.10.0.0/24
|
||||||
Forwarding rules: 0
|
Forwarding rules: 0
|
||||||
Peers count: 2/2 Connected
|
Peers count: 2/2 Connected
|
||||||
|
@ -62,6 +62,8 @@ func main() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
logFile = file
|
logFile = file
|
||||||
|
} else {
|
||||||
|
_ = util.InitLog("trace", "console")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create the Fyne application.
|
// Create the Fyne application.
|
||||||
@ -192,6 +194,7 @@ type serviceClient struct {
|
|||||||
mAllowSSH *systray.MenuItem
|
mAllowSSH *systray.MenuItem
|
||||||
mAutoConnect *systray.MenuItem
|
mAutoConnect *systray.MenuItem
|
||||||
mEnableRosenpass *systray.MenuItem
|
mEnableRosenpass *systray.MenuItem
|
||||||
|
mLazyConnEnabled *systray.MenuItem
|
||||||
mNotifications *systray.MenuItem
|
mNotifications *systray.MenuItem
|
||||||
mAdvancedSettings *systray.MenuItem
|
mAdvancedSettings *systray.MenuItem
|
||||||
mCreateDebugBundle *systray.MenuItem
|
mCreateDebugBundle *systray.MenuItem
|
||||||
@ -383,12 +386,12 @@ func (s *serviceClient) getSettingsForm() *widget.Form {
|
|||||||
s.adminURL = iAdminURL
|
s.adminURL = iAdminURL
|
||||||
|
|
||||||
loginRequest := proto.LoginRequest{
|
loginRequest := proto.LoginRequest{
|
||||||
ManagementUrl: iMngURL,
|
ManagementUrl: iMngURL,
|
||||||
AdminURL: iAdminURL,
|
AdminURL: iAdminURL,
|
||||||
IsLinuxDesktopClient: runtime.GOOS == "linux",
|
IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd",
|
||||||
RosenpassPermissive: &s.sRosenpassPermissive.Checked,
|
RosenpassPermissive: &s.sRosenpassPermissive.Checked,
|
||||||
InterfaceName: &s.iInterfaceName.Text,
|
InterfaceName: &s.iInterfaceName.Text,
|
||||||
WireguardPort: &port,
|
WireguardPort: &port,
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.iPreSharedKey.Text != censoredPreSharedKey {
|
if s.iPreSharedKey.Text != censoredPreSharedKey {
|
||||||
@ -415,7 +418,7 @@ func (s *serviceClient) login() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
loginResp, err := conn.Login(s.ctx, &proto.LoginRequest{
|
loginResp, err := conn.Login(s.ctx, &proto.LoginRequest{
|
||||||
IsLinuxDesktopClient: runtime.GOOS == "linux",
|
IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd",
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("login to management URL with: %v", err)
|
log.Errorf("login to management URL with: %v", err)
|
||||||
@ -631,6 +634,7 @@ func (s *serviceClient) onTrayReady() {
|
|||||||
s.mAllowSSH = s.mSettings.AddSubMenuItemCheckbox("Allow SSH", allowSSHMenuDescr, false)
|
s.mAllowSSH = s.mSettings.AddSubMenuItemCheckbox("Allow SSH", allowSSHMenuDescr, false)
|
||||||
s.mAutoConnect = s.mSettings.AddSubMenuItemCheckbox("Connect on Startup", autoConnectMenuDescr, false)
|
s.mAutoConnect = s.mSettings.AddSubMenuItemCheckbox("Connect on Startup", autoConnectMenuDescr, false)
|
||||||
s.mEnableRosenpass = s.mSettings.AddSubMenuItemCheckbox("Enable Quantum-Resistance", quantumResistanceMenuDescr, false)
|
s.mEnableRosenpass = s.mSettings.AddSubMenuItemCheckbox("Enable Quantum-Resistance", quantumResistanceMenuDescr, false)
|
||||||
|
s.mLazyConnEnabled = s.mSettings.AddSubMenuItemCheckbox("Enable lazy connection", lazyConnMenuDescr, false)
|
||||||
s.mNotifications = s.mSettings.AddSubMenuItemCheckbox("Notifications", notificationsMenuDescr, false)
|
s.mNotifications = s.mSettings.AddSubMenuItemCheckbox("Notifications", notificationsMenuDescr, false)
|
||||||
s.mAdvancedSettings = s.mSettings.AddSubMenuItem("Advanced Settings", advancedSettingsMenuDescr)
|
s.mAdvancedSettings = s.mSettings.AddSubMenuItem("Advanced Settings", advancedSettingsMenuDescr)
|
||||||
s.mCreateDebugBundle = s.mSettings.AddSubMenuItem("Create Debug Bundle", debugBundleMenuDescr)
|
s.mCreateDebugBundle = s.mSettings.AddSubMenuItem("Create Debug Bundle", debugBundleMenuDescr)
|
||||||
@ -693,104 +697,114 @@ func (s *serviceClient) onTrayReady() {
|
|||||||
|
|
||||||
go s.eventManager.Start(s.ctx)
|
go s.eventManager.Start(s.ctx)
|
||||||
|
|
||||||
go func() {
|
go s.listenEvents()
|
||||||
for {
|
}
|
||||||
select {
|
|
||||||
case <-s.mUp.ClickedCh:
|
|
||||||
s.mUp.Disable()
|
|
||||||
go func() {
|
|
||||||
defer s.mUp.Enable()
|
|
||||||
err := s.menuUpClick()
|
|
||||||
if err != nil {
|
|
||||||
s.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird service"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
case <-s.mDown.ClickedCh:
|
|
||||||
s.mDown.Disable()
|
|
||||||
go func() {
|
|
||||||
defer s.mDown.Enable()
|
|
||||||
err := s.menuDownClick()
|
|
||||||
if err != nil {
|
|
||||||
s.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird service"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
case <-s.mAllowSSH.ClickedCh:
|
|
||||||
if s.mAllowSSH.Checked() {
|
|
||||||
s.mAllowSSH.Uncheck()
|
|
||||||
} else {
|
|
||||||
s.mAllowSSH.Check()
|
|
||||||
}
|
|
||||||
if err := s.updateConfig(); err != nil {
|
|
||||||
log.Errorf("failed to update config: %v", err)
|
|
||||||
}
|
|
||||||
case <-s.mAutoConnect.ClickedCh:
|
|
||||||
if s.mAutoConnect.Checked() {
|
|
||||||
s.mAutoConnect.Uncheck()
|
|
||||||
} else {
|
|
||||||
s.mAutoConnect.Check()
|
|
||||||
}
|
|
||||||
if err := s.updateConfig(); err != nil {
|
|
||||||
log.Errorf("failed to update config: %v", err)
|
|
||||||
}
|
|
||||||
case <-s.mEnableRosenpass.ClickedCh:
|
|
||||||
if s.mEnableRosenpass.Checked() {
|
|
||||||
s.mEnableRosenpass.Uncheck()
|
|
||||||
} else {
|
|
||||||
s.mEnableRosenpass.Check()
|
|
||||||
}
|
|
||||||
if err := s.updateConfig(); err != nil {
|
|
||||||
log.Errorf("failed to update config: %v", err)
|
|
||||||
}
|
|
||||||
case <-s.mAdvancedSettings.ClickedCh:
|
|
||||||
s.mAdvancedSettings.Disable()
|
|
||||||
go func() {
|
|
||||||
defer s.mAdvancedSettings.Enable()
|
|
||||||
defer s.getSrvConfig()
|
|
||||||
s.runSelfCommand("settings", "true")
|
|
||||||
}()
|
|
||||||
case <-s.mCreateDebugBundle.ClickedCh:
|
|
||||||
s.mCreateDebugBundle.Disable()
|
|
||||||
go func() {
|
|
||||||
defer s.mCreateDebugBundle.Enable()
|
|
||||||
s.runSelfCommand("debug", "true")
|
|
||||||
}()
|
|
||||||
case <-s.mQuit.ClickedCh:
|
|
||||||
systray.Quit()
|
|
||||||
return
|
|
||||||
case <-s.mGitHub.ClickedCh:
|
|
||||||
err := openURL("https://github.com/netbirdio/netbird")
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("%s", err)
|
|
||||||
}
|
|
||||||
case <-s.mUpdate.ClickedCh:
|
|
||||||
err := openURL(version.DownloadUrl())
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("%s", err)
|
|
||||||
}
|
|
||||||
case <-s.mNetworks.ClickedCh:
|
|
||||||
s.mNetworks.Disable()
|
|
||||||
go func() {
|
|
||||||
defer s.mNetworks.Enable()
|
|
||||||
s.runSelfCommand("networks", "true")
|
|
||||||
}()
|
|
||||||
case <-s.mNotifications.ClickedCh:
|
|
||||||
if s.mNotifications.Checked() {
|
|
||||||
s.mNotifications.Uncheck()
|
|
||||||
} else {
|
|
||||||
s.mNotifications.Check()
|
|
||||||
}
|
|
||||||
if s.eventManager != nil {
|
|
||||||
s.eventManager.SetNotificationsEnabled(s.mNotifications.Checked())
|
|
||||||
}
|
|
||||||
if err := s.updateConfig(); err != nil {
|
|
||||||
log.Errorf("failed to update config: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
func (s *serviceClient) listenEvents() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-s.mUp.ClickedCh:
|
||||||
|
s.mUp.Disable()
|
||||||
|
go func() {
|
||||||
|
defer s.mUp.Enable()
|
||||||
|
err := s.menuUpClick()
|
||||||
|
if err != nil {
|
||||||
|
s.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird service"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
case <-s.mDown.ClickedCh:
|
||||||
|
s.mDown.Disable()
|
||||||
|
go func() {
|
||||||
|
defer s.mDown.Enable()
|
||||||
|
err := s.menuDownClick()
|
||||||
|
if err != nil {
|
||||||
|
s.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird service"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
case <-s.mAllowSSH.ClickedCh:
|
||||||
|
if s.mAllowSSH.Checked() {
|
||||||
|
s.mAllowSSH.Uncheck()
|
||||||
|
} else {
|
||||||
|
s.mAllowSSH.Check()
|
||||||
|
}
|
||||||
|
if err := s.updateConfig(); err != nil {
|
||||||
|
log.Errorf("failed to update config: %v", err)
|
||||||
|
}
|
||||||
|
case <-s.mAutoConnect.ClickedCh:
|
||||||
|
if s.mAutoConnect.Checked() {
|
||||||
|
s.mAutoConnect.Uncheck()
|
||||||
|
} else {
|
||||||
|
s.mAutoConnect.Check()
|
||||||
|
}
|
||||||
|
if err := s.updateConfig(); err != nil {
|
||||||
|
log.Errorf("failed to update config: %v", err)
|
||||||
|
}
|
||||||
|
case <-s.mEnableRosenpass.ClickedCh:
|
||||||
|
if s.mEnableRosenpass.Checked() {
|
||||||
|
s.mEnableRosenpass.Uncheck()
|
||||||
|
} else {
|
||||||
|
s.mEnableRosenpass.Check()
|
||||||
|
}
|
||||||
|
if err := s.updateConfig(); err != nil {
|
||||||
|
log.Errorf("failed to update config: %v", err)
|
||||||
|
}
|
||||||
|
case <-s.mLazyConnEnabled.ClickedCh:
|
||||||
|
if s.mLazyConnEnabled.Checked() {
|
||||||
|
s.mLazyConnEnabled.Uncheck()
|
||||||
|
} else {
|
||||||
|
s.mLazyConnEnabled.Check()
|
||||||
|
}
|
||||||
|
if err := s.updateConfig(); err != nil {
|
||||||
|
log.Errorf("failed to update config: %v", err)
|
||||||
|
}
|
||||||
|
case <-s.mAdvancedSettings.ClickedCh:
|
||||||
|
s.mAdvancedSettings.Disable()
|
||||||
|
go func() {
|
||||||
|
defer s.mAdvancedSettings.Enable()
|
||||||
|
defer s.getSrvConfig()
|
||||||
|
s.runSelfCommand("settings", "true")
|
||||||
|
}()
|
||||||
|
case <-s.mCreateDebugBundle.ClickedCh:
|
||||||
|
s.mCreateDebugBundle.Disable()
|
||||||
|
go func() {
|
||||||
|
defer s.mCreateDebugBundle.Enable()
|
||||||
|
s.runSelfCommand("debug", "true")
|
||||||
|
}()
|
||||||
|
case <-s.mQuit.ClickedCh:
|
||||||
|
systray.Quit()
|
||||||
|
return
|
||||||
|
case <-s.mGitHub.ClickedCh:
|
||||||
|
err := openURL("https://github.com/netbirdio/netbird")
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("%s", err)
|
||||||
|
}
|
||||||
|
case <-s.mUpdate.ClickedCh:
|
||||||
|
err := openURL(version.DownloadUrl())
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("%s", err)
|
||||||
|
}
|
||||||
|
case <-s.mNetworks.ClickedCh:
|
||||||
|
s.mNetworks.Disable()
|
||||||
|
go func() {
|
||||||
|
defer s.mNetworks.Enable()
|
||||||
|
s.runSelfCommand("networks", "true")
|
||||||
|
}()
|
||||||
|
case <-s.mNotifications.ClickedCh:
|
||||||
|
if s.mNotifications.Checked() {
|
||||||
|
s.mNotifications.Uncheck()
|
||||||
|
} else {
|
||||||
|
s.mNotifications.Check()
|
||||||
|
}
|
||||||
|
if s.eventManager != nil {
|
||||||
|
s.eventManager.SetNotificationsEnabled(s.mNotifications.Checked())
|
||||||
|
}
|
||||||
|
if err := s.updateConfig(); err != nil {
|
||||||
|
log.Errorf("failed to update config: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}()
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *serviceClient) runSelfCommand(command, arg string) {
|
func (s *serviceClient) runSelfCommand(command, arg string) {
|
||||||
@ -1022,13 +1036,15 @@ func (s *serviceClient) updateConfig() error {
|
|||||||
sshAllowed := s.mAllowSSH.Checked()
|
sshAllowed := s.mAllowSSH.Checked()
|
||||||
rosenpassEnabled := s.mEnableRosenpass.Checked()
|
rosenpassEnabled := s.mEnableRosenpass.Checked()
|
||||||
notificationsDisabled := !s.mNotifications.Checked()
|
notificationsDisabled := !s.mNotifications.Checked()
|
||||||
|
lazyConnectionEnabled := s.mLazyConnEnabled.Checked()
|
||||||
|
|
||||||
loginRequest := proto.LoginRequest{
|
loginRequest := proto.LoginRequest{
|
||||||
IsLinuxDesktopClient: runtime.GOOS == "linux",
|
IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd",
|
||||||
ServerSSHAllowed: &sshAllowed,
|
ServerSSHAllowed: &sshAllowed,
|
||||||
RosenpassEnabled: &rosenpassEnabled,
|
RosenpassEnabled: &rosenpassEnabled,
|
||||||
DisableAutoConnect: &disableAutoStart,
|
DisableAutoConnect: &disableAutoStart,
|
||||||
DisableNotifications: ¬ificationsDisabled,
|
DisableNotifications: ¬ificationsDisabled,
|
||||||
|
LazyConnectionEnabled: &lazyConnectionEnabled,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.restartClient(&loginRequest); err != nil {
|
if err := s.restartClient(&loginRequest); err != nil {
|
||||||
|
@ -6,6 +6,7 @@ const (
|
|||||||
allowSSHMenuDescr = "Allow SSH connections"
|
allowSSHMenuDescr = "Allow SSH connections"
|
||||||
autoConnectMenuDescr = "Connect automatically when the service starts"
|
autoConnectMenuDescr = "Connect automatically when the service starts"
|
||||||
quantumResistanceMenuDescr = "Enable post-quantum security via Rosenpass"
|
quantumResistanceMenuDescr = "Enable post-quantum security via Rosenpass"
|
||||||
|
lazyConnMenuDescr = "[Experimental] Enable lazy connect"
|
||||||
notificationsMenuDescr = "Enable notifications"
|
notificationsMenuDescr = "Enable notifications"
|
||||||
advancedSettingsMenuDescr = "Advanced settings of the application"
|
advancedSettingsMenuDescr = "Advanced settings of the application"
|
||||||
debugBundleMenuDescr = "Create and open debug information bundle"
|
debugBundleMenuDescr = "Create and open debug information bundle"
|
||||||
|
@ -66,17 +66,17 @@ func (s SimpleRecord) String() string {
|
|||||||
func (s SimpleRecord) Len() uint16 {
|
func (s SimpleRecord) Len() uint16 {
|
||||||
emptyString := s.RData == ""
|
emptyString := s.RData == ""
|
||||||
switch s.Type {
|
switch s.Type {
|
||||||
case 1:
|
case int(dns.TypeA):
|
||||||
if emptyString {
|
if emptyString {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
return net.IPv4len
|
return net.IPv4len
|
||||||
case 5:
|
case int(dns.TypeCNAME):
|
||||||
if emptyString || s.RData == "." {
|
if emptyString || s.RData == "." {
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
return uint16(len(s.RData) + 1)
|
return uint16(len(s.RData) + 1)
|
||||||
case 28:
|
case int(dns.TypeAAAA):
|
||||||
if emptyString {
|
if emptyString {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
4
go.mod
4
go.mod
@ -59,13 +59,12 @@ require (
|
|||||||
github.com/hashicorp/go-version v1.6.0
|
github.com/hashicorp/go-version v1.6.0
|
||||||
github.com/libdns/route53 v1.5.0
|
github.com/libdns/route53 v1.5.0
|
||||||
github.com/libp2p/go-netroute v0.2.1
|
github.com/libp2p/go-netroute v0.2.1
|
||||||
github.com/mattn/go-sqlite3 v1.14.22
|
|
||||||
github.com/mdlayher/socket v0.5.1
|
github.com/mdlayher/socket v0.5.1
|
||||||
github.com/miekg/dns v1.1.59
|
github.com/miekg/dns v1.1.59
|
||||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||||
github.com/nadoo/ipset v0.5.0
|
github.com/nadoo/ipset v0.5.0
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250330143713-7901e0a82203
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20250330143713-7901e0a82203
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb
|
||||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||||
github.com/oschwald/maxminddb-golang v1.12.0
|
github.com/oschwald/maxminddb-golang v1.12.0
|
||||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||||
@ -195,6 +194,7 @@ require (
|
|||||||
github.com/libdns/libdns v0.2.2 // indirect
|
github.com/libdns/libdns v0.2.2 // indirect
|
||||||
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae // indirect
|
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae // indirect
|
||||||
github.com/magiconair/properties v1.8.7 // indirect
|
github.com/magiconair/properties v1.8.7 // indirect
|
||||||
|
github.com/mattn/go-sqlite3 v1.14.22 // indirect
|
||||||
github.com/mdlayher/genetlink v1.3.2 // indirect
|
github.com/mdlayher/genetlink v1.3.2 // indirect
|
||||||
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 // indirect
|
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 // indirect
|
||||||
github.com/mholt/acmez/v2 v2.0.1 // indirect
|
github.com/mholt/acmez/v2 v2.0.1 // indirect
|
||||||
|
4
go.sum
4
go.sum
@ -507,8 +507,8 @@ github.com/netbirdio/management-integrations/integrations v0.0.0-20250330143713-
|
|||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250330143713-7901e0a82203/go.mod h1:2ZE6/tBBCKHQggPfO2UOQjyjXI7k+JDVl2ymorTOVQs=
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20250330143713-7901e0a82203/go.mod h1:2ZE6/tBBCKHQggPfO2UOQjyjXI7k+JDVl2ymorTOVQs=
|
||||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
||||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb h1:Cr6age+ePALqlSvtp7wc6lYY97XN7rkD1K4XEDmY+TU=
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
||||||
github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6 h1:X5h5QgP7uHAv78FWgHV8+WYLjHxK9v3ilkVXT1cpCrQ=
|
github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6 h1:X5h5QgP7uHAv78FWgHV8+WYLjHxK9v3ilkVXT1cpCrQ=
|
||||||
github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
||||||
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
|
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
|
||||||
|
@ -59,6 +59,7 @@ NETBIRD_TOKEN_SOURCE=${NETBIRD_TOKEN_SOURCE:-accessToken}
|
|||||||
NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS=${NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS:-"53000"}
|
NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS=${NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS:-"53000"}
|
||||||
NETBIRD_AUTH_PKCE_USE_ID_TOKEN=${NETBIRD_AUTH_PKCE_USE_ID_TOKEN:-false}
|
NETBIRD_AUTH_PKCE_USE_ID_TOKEN=${NETBIRD_AUTH_PKCE_USE_ID_TOKEN:-false}
|
||||||
NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN=${NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN:-false}
|
NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN=${NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN:-false}
|
||||||
|
NETBIRD_AUTH_PKCE_LOGIN_FLAG=${NETBIRD_AUTH_PKCE_LOGIN_FLAG:-1}
|
||||||
NETBIRD_AUTH_PKCE_AUDIENCE=$NETBIRD_AUTH_AUDIENCE
|
NETBIRD_AUTH_PKCE_AUDIENCE=$NETBIRD_AUTH_AUDIENCE
|
||||||
|
|
||||||
# Dashboard
|
# Dashboard
|
||||||
@ -122,6 +123,7 @@ export NETBIRD_AUTH_DEVICE_AUTH_USE_ID_TOKEN
|
|||||||
export NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT
|
export NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT
|
||||||
export NETBIRD_AUTH_PKCE_USE_ID_TOKEN
|
export NETBIRD_AUTH_PKCE_USE_ID_TOKEN
|
||||||
export NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN
|
export NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN
|
||||||
|
export NETBIRD_AUTH_PKCE_LOGIN_FLAG
|
||||||
export NETBIRD_AUTH_PKCE_AUDIENCE
|
export NETBIRD_AUTH_PKCE_AUDIENCE
|
||||||
export NETBIRD_DASH_AUTH_USE_AUDIENCE
|
export NETBIRD_DASH_AUTH_USE_AUDIENCE
|
||||||
export NETBIRD_DASH_AUTH_AUDIENCE
|
export NETBIRD_DASH_AUTH_AUDIENCE
|
||||||
|
@ -95,7 +95,8 @@
|
|||||||
"Scope": "$NETBIRD_AUTH_SUPPORTED_SCOPES",
|
"Scope": "$NETBIRD_AUTH_SUPPORTED_SCOPES",
|
||||||
"RedirectURLs": [$NETBIRD_AUTH_PKCE_REDIRECT_URLS],
|
"RedirectURLs": [$NETBIRD_AUTH_PKCE_REDIRECT_URLS],
|
||||||
"UseIDToken": $NETBIRD_AUTH_PKCE_USE_ID_TOKEN,
|
"UseIDToken": $NETBIRD_AUTH_PKCE_USE_ID_TOKEN,
|
||||||
"DisablePromptLogin": $NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN
|
"DisablePromptLogin": $NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN,
|
||||||
|
"LoginFlag": $NETBIRD_AUTH_PKCE_LOGIN_FLAG
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -28,3 +28,4 @@ NETBIRD_MGMT_IDP_SIGNKEY_REFRESH=$CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH
|
|||||||
NETBIRD_TURN_EXTERNAL_IP=1.2.3.4
|
NETBIRD_TURN_EXTERNAL_IP=1.2.3.4
|
||||||
NETBIRD_RELAY_PORT=33445
|
NETBIRD_RELAY_PORT=33445
|
||||||
NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN=true
|
NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN=true
|
||||||
|
NETBIRD_AUTH_PKCE_LOGIN_FLAG=0
|
||||||
|
19
management/client/common/types.go
Normal file
19
management/client/common/types.go
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
// LoginFlag introduces additional login flags to the PKCE authorization request
|
||||||
|
type LoginFlag uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
// LoginFlagPrompt adds prompt=login to the authorization request
|
||||||
|
LoginFlagPrompt LoginFlag = iota
|
||||||
|
// LoginFlagMaxAge0 adds max_age=0 to the authorization request
|
||||||
|
LoginFlagMaxAge0
|
||||||
|
)
|
||||||
|
|
||||||
|
func (l LoginFlag) IsPromptLogin() bool {
|
||||||
|
return l == LoginFlagPrompt
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l LoginFlag) IsMaxAge0Login() bool {
|
||||||
|
return l == LoginFlagMaxAge0
|
||||||
|
}
|
@ -16,11 +16,13 @@ type AccountsAPI struct {
|
|||||||
// List list all accounts, only returns one account always
|
// List list all accounts, only returns one account always
|
||||||
// See more: https://docs.netbird.io/api/resources/accounts#list-all-accounts
|
// See more: https://docs.netbird.io/api/resources/accounts#list-all-accounts
|
||||||
func (a *AccountsAPI) List(ctx context.Context) ([]api.Account, error) {
|
func (a *AccountsAPI) List(ctx context.Context) ([]api.Account, error) {
|
||||||
resp, err := a.c.newRequest(ctx, "GET", "/api/accounts", nil)
|
resp, err := a.c.NewRequest(ctx, "GET", "/api/accounts", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
if resp.Body != nil {
|
||||||
|
defer resp.Body.Close()
|
||||||
|
}
|
||||||
ret, err := parseResponse[[]api.Account](resp)
|
ret, err := parseResponse[[]api.Account](resp)
|
||||||
return ret, err
|
return ret, err
|
||||||
}
|
}
|
||||||
@ -32,11 +34,13 @@ func (a *AccountsAPI) Update(ctx context.Context, accountID string, request api.
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
resp, err := a.c.newRequest(ctx, "PUT", "/api/accounts/"+accountID, bytes.NewReader(requestBytes))
|
resp, err := a.c.NewRequest(ctx, "PUT", "/api/accounts/"+accountID, bytes.NewReader(requestBytes))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
if resp.Body != nil {
|
||||||
|
defer resp.Body.Close()
|
||||||
|
}
|
||||||
ret, err := parseResponse[api.Account](resp)
|
ret, err := parseResponse[api.Account](resp)
|
||||||
return &ret, err
|
return &ret, err
|
||||||
}
|
}
|
||||||
@ -44,11 +48,13 @@ func (a *AccountsAPI) Update(ctx context.Context, accountID string, request api.
|
|||||||
// Delete delete account
|
// Delete delete account
|
||||||
// See more: https://docs.netbird.io/api/resources/accounts#delete-an-account
|
// See more: https://docs.netbird.io/api/resources/accounts#delete-an-account
|
||||||
func (a *AccountsAPI) Delete(ctx context.Context, accountID string) error {
|
func (a *AccountsAPI) Delete(ctx context.Context, accountID string) error {
|
||||||
resp, err := a.c.newRequest(ctx, "DELETE", "/api/accounts/"+accountID, nil)
|
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/accounts/"+accountID, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
if resp.Body != nil {
|
||||||
|
defer resp.Body.Close()
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -14,6 +14,7 @@ import (
|
|||||||
type Client struct {
|
type Client struct {
|
||||||
managementURL string
|
managementURL string
|
||||||
authHeader string
|
authHeader string
|
||||||
|
httpClient HttpClient
|
||||||
|
|
||||||
// Accounts NetBird account APIs
|
// Accounts NetBird account APIs
|
||||||
// see more: https://docs.netbird.io/api/resources/accounts
|
// see more: https://docs.netbird.io/api/resources/accounts
|
||||||
@ -70,20 +71,29 @@ type Client struct {
|
|||||||
|
|
||||||
// New initialize new Client instance using PAT token
|
// New initialize new Client instance using PAT token
|
||||||
func New(managementURL, token string) *Client {
|
func New(managementURL, token string) *Client {
|
||||||
client := &Client{
|
return NewWithOptions(
|
||||||
managementURL: managementURL,
|
WithManagementURL(managementURL),
|
||||||
authHeader: "Token " + token,
|
WithPAT(token),
|
||||||
}
|
)
|
||||||
client.initialize()
|
|
||||||
return client
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewWithBearerToken initialize new Client instance using Bearer token type
|
// NewWithBearerToken initialize new Client instance using Bearer token type
|
||||||
func NewWithBearerToken(managementURL, token string) *Client {
|
func NewWithBearerToken(managementURL, token string) *Client {
|
||||||
|
return NewWithOptions(
|
||||||
|
WithManagementURL(managementURL),
|
||||||
|
WithBearerToken(token),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewWithOptions(opts ...option) *Client {
|
||||||
client := &Client{
|
client := &Client{
|
||||||
managementURL: managementURL,
|
httpClient: http.DefaultClient,
|
||||||
authHeader: "Bearer " + token,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, option := range opts {
|
||||||
|
option(client)
|
||||||
|
}
|
||||||
|
|
||||||
client.initialize()
|
client.initialize()
|
||||||
return client
|
return client
|
||||||
}
|
}
|
||||||
@ -104,7 +114,7 @@ func (c *Client) initialize() {
|
|||||||
c.Events = &EventsAPI{c}
|
c.Events = &EventsAPI{c}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) newRequest(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) {
|
func (c *Client) NewRequest(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) {
|
||||||
req, err := http.NewRequestWithContext(ctx, method, c.managementURL+path, body)
|
req, err := http.NewRequestWithContext(ctx, method, c.managementURL+path, body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -116,7 +126,7 @@ func (c *Client) newRequest(ctx context.Context, method, path string, body io.Re
|
|||||||
req.Header.Add("Content-Type", "application/json")
|
req.Header.Add("Content-Type", "application/json")
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := http.DefaultClient.Do(req)
|
resp, err := c.httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -16,11 +16,13 @@ type DNSAPI struct {
|
|||||||
// ListNameserverGroups list all nameserver groups
|
// ListNameserverGroups list all nameserver groups
|
||||||
// See more: https://docs.netbird.io/api/resources/dns#list-all-nameserver-groups
|
// See more: https://docs.netbird.io/api/resources/dns#list-all-nameserver-groups
|
||||||
func (a *DNSAPI) ListNameserverGroups(ctx context.Context) ([]api.NameserverGroup, error) {
|
func (a *DNSAPI) ListNameserverGroups(ctx context.Context) ([]api.NameserverGroup, error) {
|
||||||
resp, err := a.c.newRequest(ctx, "GET", "/api/dns/nameservers", nil)
|
resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/nameservers", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
if resp.Body != nil {
|
||||||
|
defer resp.Body.Close()
|
||||||
|
}
|
||||||
ret, err := parseResponse[[]api.NameserverGroup](resp)
|
ret, err := parseResponse[[]api.NameserverGroup](resp)
|
||||||
return ret, err
|
return ret, err
|
||||||
}
|
}
|
||||||
@ -28,11 +30,13 @@ func (a *DNSAPI) ListNameserverGroups(ctx context.Context) ([]api.NameserverGrou
|
|||||||
// GetNameserverGroup get nameserver group info
|
// GetNameserverGroup get nameserver group info
|
||||||
// See more: https://docs.netbird.io/api/resources/dns#retrieve-a-nameserver-group
|
// See more: https://docs.netbird.io/api/resources/dns#retrieve-a-nameserver-group
|
||||||
func (a *DNSAPI) GetNameserverGroup(ctx context.Context, nameserverGroupID string) (*api.NameserverGroup, error) {
|
func (a *DNSAPI) GetNameserverGroup(ctx context.Context, nameserverGroupID string) (*api.NameserverGroup, error) {
|
||||||
resp, err := a.c.newRequest(ctx, "GET", "/api/dns/nameservers/"+nameserverGroupID, nil)
|
resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/nameservers/"+nameserverGroupID, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
if resp.Body != nil {
|
||||||
|
defer resp.Body.Close()
|
||||||
|
}
|
||||||
ret, err := parseResponse[api.NameserverGroup](resp)
|
ret, err := parseResponse[api.NameserverGroup](resp)
|
||||||
return &ret, err
|
return &ret, err
|
||||||
}
|
}
|
||||||
@ -44,11 +48,13 @@ func (a *DNSAPI) CreateNameserverGroup(ctx context.Context, request api.PostApiD
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
resp, err := a.c.newRequest(ctx, "POST", "/api/dns/nameservers", bytes.NewReader(requestBytes))
|
resp, err := a.c.NewRequest(ctx, "POST", "/api/dns/nameservers", bytes.NewReader(requestBytes))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
if resp.Body != nil {
|
||||||
|
defer resp.Body.Close()
|
||||||
|
}
|
||||||
ret, err := parseResponse[api.NameserverGroup](resp)
|
ret, err := parseResponse[api.NameserverGroup](resp)
|
||||||
return &ret, err
|
return &ret, err
|
||||||
}
|
}
|
||||||
@ -60,11 +66,13 @@ func (a *DNSAPI) UpdateNameserverGroup(ctx context.Context, nameserverGroupID st
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
resp, err := a.c.newRequest(ctx, "PUT", "/api/dns/nameservers/"+nameserverGroupID, bytes.NewReader(requestBytes))
|
resp, err := a.c.NewRequest(ctx, "PUT", "/api/dns/nameservers/"+nameserverGroupID, bytes.NewReader(requestBytes))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
if resp.Body != nil {
|
||||||
|
defer resp.Body.Close()
|
||||||
|
}
|
||||||
ret, err := parseResponse[api.NameserverGroup](resp)
|
ret, err := parseResponse[api.NameserverGroup](resp)
|
||||||
return &ret, err
|
return &ret, err
|
||||||
}
|
}
|
||||||
@ -72,11 +80,13 @@ func (a *DNSAPI) UpdateNameserverGroup(ctx context.Context, nameserverGroupID st
|
|||||||
// DeleteNameserverGroup delete nameserver group
|
// DeleteNameserverGroup delete nameserver group
|
||||||
// See more: https://docs.netbird.io/api/resources/dns#delete-a-nameserver-group
|
// See more: https://docs.netbird.io/api/resources/dns#delete-a-nameserver-group
|
||||||
func (a *DNSAPI) DeleteNameserverGroup(ctx context.Context, nameserverGroupID string) error {
|
func (a *DNSAPI) DeleteNameserverGroup(ctx context.Context, nameserverGroupID string) error {
|
||||||
resp, err := a.c.newRequest(ctx, "DELETE", "/api/dns/nameservers/"+nameserverGroupID, nil)
|
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/dns/nameservers/"+nameserverGroupID, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
if resp.Body != nil {
|
||||||
|
defer resp.Body.Close()
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -84,11 +94,13 @@ func (a *DNSAPI) DeleteNameserverGroup(ctx context.Context, nameserverGroupID st
|
|||||||
// GetSettings get DNS settings
|
// GetSettings get DNS settings
|
||||||
// See more: https://docs.netbird.io/api/resources/dns#retrieve-dns-settings
|
// See more: https://docs.netbird.io/api/resources/dns#retrieve-dns-settings
|
||||||
func (a *DNSAPI) GetSettings(ctx context.Context) (*api.DNSSettings, error) {
|
func (a *DNSAPI) GetSettings(ctx context.Context) (*api.DNSSettings, error) {
|
||||||
resp, err := a.c.newRequest(ctx, "GET", "/api/dns/settings", nil)
|
resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/settings", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
if resp.Body != nil {
|
||||||
|
defer resp.Body.Close()
|
||||||
|
}
|
||||||
ret, err := parseResponse[api.DNSSettings](resp)
|
ret, err := parseResponse[api.DNSSettings](resp)
|
||||||
return &ret, err
|
return &ret, err
|
||||||
}
|
}
|
||||||
@ -100,11 +112,13 @@ func (a *DNSAPI) UpdateSettings(ctx context.Context, request api.PutApiDnsSettin
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
resp, err := a.c.newRequest(ctx, "PUT", "/api/dns/settings", bytes.NewReader(requestBytes))
|
resp, err := a.c.NewRequest(ctx, "PUT", "/api/dns/settings", bytes.NewReader(requestBytes))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
if resp.Body != nil {
|
||||||
|
defer resp.Body.Close()
|
||||||
|
}
|
||||||
ret, err := parseResponse[api.DNSSettings](resp)
|
ret, err := parseResponse[api.DNSSettings](resp)
|
||||||
return &ret, err
|
return &ret, err
|
||||||
}
|
}
|
||||||
|
@ -14,11 +14,13 @@ type EventsAPI struct {
|
|||||||
// List list all events
|
// List list all events
|
||||||
// See more: https://docs.netbird.io/api/resources/events#list-all-events
|
// See more: https://docs.netbird.io/api/resources/events#list-all-events
|
||||||
func (a *EventsAPI) List(ctx context.Context) ([]api.Event, error) {
|
func (a *EventsAPI) List(ctx context.Context) ([]api.Event, error) {
|
||||||
resp, err := a.c.newRequest(ctx, "GET", "/api/events", nil)
|
resp, err := a.c.NewRequest(ctx, "GET", "/api/events", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
if resp.Body != nil {
|
||||||
|
defer resp.Body.Close()
|
||||||
|
}
|
||||||
ret, err := parseResponse[[]api.Event](resp)
|
ret, err := parseResponse[[]api.Event](resp)
|
||||||
return ret, err
|
return ret, err
|
||||||
}
|
}
|
||||||
|
@ -14,11 +14,13 @@ type GeoLocationAPI struct {
|
|||||||
// ListCountries list all country codes
|
// ListCountries list all country codes
|
||||||
// See more: https://docs.netbird.io/api/resources/geo-locations#list-all-country-codes
|
// See more: https://docs.netbird.io/api/resources/geo-locations#list-all-country-codes
|
||||||
func (a *GeoLocationAPI) ListCountries(ctx context.Context) ([]api.Country, error) {
|
func (a *GeoLocationAPI) ListCountries(ctx context.Context) ([]api.Country, error) {
|
||||||
resp, err := a.c.newRequest(ctx, "GET", "/api/locations/countries", nil)
|
resp, err := a.c.NewRequest(ctx, "GET", "/api/locations/countries", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
if resp.Body != nil {
|
||||||
|
defer resp.Body.Close()
|
||||||
|
}
|
||||||
ret, err := parseResponse[[]api.Country](resp)
|
ret, err := parseResponse[[]api.Country](resp)
|
||||||
return ret, err
|
return ret, err
|
||||||
}
|
}
|
||||||
@ -26,11 +28,13 @@ func (a *GeoLocationAPI) ListCountries(ctx context.Context) ([]api.Country, erro
|
|||||||
// ListCountryCities Get a list of all English city names for a given country code
|
// ListCountryCities Get a list of all English city names for a given country code
|
||||||
// See more: https://docs.netbird.io/api/resources/geo-locations#list-all-city-names-by-country
|
// See more: https://docs.netbird.io/api/resources/geo-locations#list-all-city-names-by-country
|
||||||
func (a *GeoLocationAPI) ListCountryCities(ctx context.Context, countryCode string) ([]api.City, error) {
|
func (a *GeoLocationAPI) ListCountryCities(ctx context.Context, countryCode string) ([]api.City, error) {
|
||||||
resp, err := a.c.newRequest(ctx, "GET", "/api/locations/countries/"+countryCode+"/cities", nil)
|
resp, err := a.c.NewRequest(ctx, "GET", "/api/locations/countries/"+countryCode+"/cities", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
if resp.Body != nil {
|
||||||
|
defer resp.Body.Close()
|
||||||
|
}
|
||||||
ret, err := parseResponse[[]api.City](resp)
|
ret, err := parseResponse[[]api.City](resp)
|
||||||
return ret, err
|
return ret, err
|
||||||
}
|
}
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user