mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-25 09:33:24 +01:00
Monitor network changes and restart engine on detection (#1904)
This commit is contained in:
parent
2e0047daea
commit
920877964f
@ -32,6 +32,7 @@ const (
|
|||||||
preSharedKeyFlag = "preshared-key"
|
preSharedKeyFlag = "preshared-key"
|
||||||
interfaceNameFlag = "interface-name"
|
interfaceNameFlag = "interface-name"
|
||||||
wireguardPortFlag = "wireguard-port"
|
wireguardPortFlag = "wireguard-port"
|
||||||
|
networkMonitorFlag = "network-monitor"
|
||||||
disableAutoConnectFlag = "disable-auto-connect"
|
disableAutoConnectFlag = "disable-auto-connect"
|
||||||
serverSSHAllowedFlag = "allow-server-ssh"
|
serverSSHAllowedFlag = "allow-server-ssh"
|
||||||
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
||||||
@ -62,6 +63,7 @@ var (
|
|||||||
serverSSHAllowed bool
|
serverSSHAllowed bool
|
||||||
interfaceName string
|
interfaceName string
|
||||||
wireguardPort uint16
|
wireguardPort uint16
|
||||||
|
networkMonitor bool
|
||||||
serviceName string
|
serviceName string
|
||||||
autoConnectDisabled bool
|
autoConnectDisabled bool
|
||||||
extraIFaceBlackList []string
|
extraIFaceBlackList []string
|
||||||
|
@ -40,6 +40,7 @@ func init() {
|
|||||||
upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground")
|
upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground")
|
||||||
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
|
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
|
||||||
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
|
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
|
||||||
|
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", false, "Enable network monitoring")
|
||||||
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
|
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -116,6 +117,10 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
ic.WireguardPort = &p
|
ic.WireguardPort = &p
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(networkMonitorFlag).Changed {
|
||||||
|
ic.NetworkMonitor = &networkMonitor
|
||||||
|
}
|
||||||
|
|
||||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||||
ic.PreSharedKey = &preSharedKey
|
ic.PreSharedKey = &preSharedKey
|
||||||
}
|
}
|
||||||
@ -226,6 +231,10 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
loginRequest.WireguardPort = &wp
|
loginRequest.WireguardPort = &wp
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(networkMonitorFlag).Changed {
|
||||||
|
loginRequest.NetworkMonitor = &networkMonitor
|
||||||
|
}
|
||||||
|
|
||||||
var loginErr error
|
var loginErr error
|
||||||
|
|
||||||
var loginResp *proto.LoginResponse
|
var loginResp *proto.LoginResponse
|
||||||
|
@ -48,6 +48,7 @@ type ConfigInput struct {
|
|||||||
RosenpassPermissive *bool
|
RosenpassPermissive *bool
|
||||||
InterfaceName *string
|
InterfaceName *string
|
||||||
WireguardPort *int
|
WireguardPort *int
|
||||||
|
NetworkMonitor *bool
|
||||||
DisableAutoConnect *bool
|
DisableAutoConnect *bool
|
||||||
ExtraIFaceBlackList []string
|
ExtraIFaceBlackList []string
|
||||||
}
|
}
|
||||||
@ -61,6 +62,7 @@ type Config struct {
|
|||||||
AdminURL *url.URL
|
AdminURL *url.URL
|
||||||
WgIface string
|
WgIface string
|
||||||
WgPort int
|
WgPort int
|
||||||
|
NetworkMonitor bool
|
||||||
IFaceBlackList []string
|
IFaceBlackList []string
|
||||||
DisableIPv6Discovery bool
|
DisableIPv6Discovery bool
|
||||||
RosenpassEnabled bool
|
RosenpassEnabled bool
|
||||||
@ -188,6 +190,10 @@ func createNewConfig(input ConfigInput) (*Config, error) {
|
|||||||
config.WgPort = *input.WireguardPort
|
config.WgPort = *input.WireguardPort
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if input.NetworkMonitor != nil {
|
||||||
|
config.NetworkMonitor = *input.NetworkMonitor
|
||||||
|
}
|
||||||
|
|
||||||
config.WgIface = iface.WgInterfaceDefault
|
config.WgIface = iface.WgInterfaceDefault
|
||||||
if input.InterfaceName != nil {
|
if input.InterfaceName != nil {
|
||||||
config.WgIface = *input.InterfaceName
|
config.WgIface = *input.InterfaceName
|
||||||
@ -279,6 +285,11 @@ func update(input ConfigInput) (*Config, error) {
|
|||||||
refresh = true
|
refresh = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if input.NetworkMonitor != nil {
|
||||||
|
config.NetworkMonitor = *input.NetworkMonitor
|
||||||
|
refresh = true
|
||||||
|
}
|
||||||
|
|
||||||
if input.WireguardPort != nil {
|
if input.WireguardPort != nil {
|
||||||
config.WgPort = *input.WireguardPort
|
config.WgPort = *input.WireguardPort
|
||||||
refresh = true
|
refresh = true
|
||||||
|
@ -249,7 +249,7 @@ func runClient(
|
|||||||
engineChan <- engine
|
engineChan <- engine
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Print("Netbird engine started, my IP is: ", peerConfig.Address)
|
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
|
||||||
state.Set(StatusConnected)
|
state.Set(StatusConnected)
|
||||||
|
|
||||||
<-engineCtx.Done()
|
<-engineCtx.Done()
|
||||||
@ -297,6 +297,7 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
|
|||||||
DisableIPv6Discovery: config.DisableIPv6Discovery,
|
DisableIPv6Discovery: config.DisableIPv6Discovery,
|
||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
WgPort: config.WgPort,
|
WgPort: config.WgPort,
|
||||||
|
NetworkMonitor: config.NetworkMonitor,
|
||||||
SSHKey: []byte(config.SSHKey),
|
SSHKey: []byte(config.SSHKey),
|
||||||
NATExternalIPs: config.NATExternalIPs,
|
NATExternalIPs: config.NATExternalIPs,
|
||||||
CustomDNSAddress: config.CustomDNSAddress,
|
CustomDNSAddress: config.CustomDNSAddress,
|
||||||
|
@ -2,6 +2,7 @@ package internal
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
@ -21,6 +22,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl"
|
"github.com/netbirdio/netbird/client/internal/acl"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/relay"
|
"github.com/netbirdio/netbird/client/internal/relay"
|
||||||
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
||||||
@ -60,6 +62,9 @@ type EngineConfig struct {
|
|||||||
// WgPrivateKey is a Wireguard private key of our peer (it MUST never leave the machine)
|
// WgPrivateKey is a Wireguard private key of our peer (it MUST never leave the machine)
|
||||||
WgPrivateKey wgtypes.Key
|
WgPrivateKey wgtypes.Key
|
||||||
|
|
||||||
|
// NetworkMonitor is a flag to enable network monitoring
|
||||||
|
NetworkMonitor bool
|
||||||
|
|
||||||
// IFaceBlackList is a list of network interfaces to ignore when discovering connection candidates (ICE related)
|
// IFaceBlackList is a list of network interfaces to ignore when discovering connection candidates (ICE related)
|
||||||
IFaceBlackList []string
|
IFaceBlackList []string
|
||||||
DisableIPv6Discovery bool
|
DisableIPv6Discovery bool
|
||||||
@ -114,9 +119,11 @@ type Engine struct {
|
|||||||
// clientRoutes is the most recent list of clientRoutes received from the Management Service
|
// clientRoutes is the most recent list of clientRoutes received from the Management Service
|
||||||
clientRoutes route.HAMap
|
clientRoutes route.HAMap
|
||||||
|
|
||||||
cancel context.CancelFunc
|
clientCtx context.Context
|
||||||
|
clientCancel context.CancelFunc
|
||||||
|
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
|
||||||
wgInterface *iface.WGIface
|
wgInterface *iface.WGIface
|
||||||
wgProxyFactory *wgproxy.Factory
|
wgProxyFactory *wgproxy.Factory
|
||||||
@ -126,6 +133,8 @@ type Engine struct {
|
|||||||
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
|
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
|
||||||
networkSerial uint64
|
networkSerial uint64
|
||||||
|
|
||||||
|
networkWatcher *networkmonitor.NetworkWatcher
|
||||||
|
|
||||||
sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error)
|
sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error)
|
||||||
sshServer nbssh.Server
|
sshServer nbssh.Server
|
||||||
|
|
||||||
@ -151,8 +160,8 @@ type Peer struct {
|
|||||||
|
|
||||||
// NewEngine creates a new Connection Engine
|
// NewEngine creates a new Connection Engine
|
||||||
func NewEngine(
|
func NewEngine(
|
||||||
ctx context.Context,
|
clientCtx context.Context,
|
||||||
cancel context.CancelFunc,
|
clientCancel context.CancelFunc,
|
||||||
signalClient signal.Client,
|
signalClient signal.Client,
|
||||||
mgmClient mgm.Client,
|
mgmClient mgm.Client,
|
||||||
config *EngineConfig,
|
config *EngineConfig,
|
||||||
@ -160,8 +169,8 @@ func NewEngine(
|
|||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
) *Engine {
|
) *Engine {
|
||||||
return NewEngineWithProbes(
|
return NewEngineWithProbes(
|
||||||
ctx,
|
clientCtx,
|
||||||
cancel,
|
clientCancel,
|
||||||
signalClient,
|
signalClient,
|
||||||
mgmClient,
|
mgmClient,
|
||||||
config,
|
config,
|
||||||
@ -176,8 +185,8 @@ func NewEngine(
|
|||||||
|
|
||||||
// NewEngineWithProbes creates a new Connection Engine with probes attached
|
// NewEngineWithProbes creates a new Connection Engine with probes attached
|
||||||
func NewEngineWithProbes(
|
func NewEngineWithProbes(
|
||||||
ctx context.Context,
|
clientCtx context.Context,
|
||||||
cancel context.CancelFunc,
|
clientCancel context.CancelFunc,
|
||||||
signalClient signal.Client,
|
signalClient signal.Client,
|
||||||
mgmClient mgm.Client,
|
mgmClient mgm.Client,
|
||||||
config *EngineConfig,
|
config *EngineConfig,
|
||||||
@ -188,9 +197,10 @@ func NewEngineWithProbes(
|
|||||||
relayProbe *Probe,
|
relayProbe *Probe,
|
||||||
wgProbe *Probe,
|
wgProbe *Probe,
|
||||||
) *Engine {
|
) *Engine {
|
||||||
|
|
||||||
return &Engine{
|
return &Engine{
|
||||||
ctx: ctx,
|
clientCtx: clientCtx,
|
||||||
cancel: cancel,
|
clientCancel: clientCancel,
|
||||||
signal: signalClient,
|
signal: signalClient,
|
||||||
mgmClient: mgmClient,
|
mgmClient: mgmClient,
|
||||||
peerConns: make(map[string]*peer.Conn),
|
peerConns: make(map[string]*peer.Conn),
|
||||||
@ -202,7 +212,7 @@ func NewEngineWithProbes(
|
|||||||
networkSerial: 0,
|
networkSerial: 0,
|
||||||
sshServerFunc: nbssh.DefaultSSHServer,
|
sshServerFunc: nbssh.DefaultSSHServer,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
wgProxyFactory: wgproxy.NewFactory(config.WgPort),
|
networkWatcher: networkmonitor.New(),
|
||||||
mgmProbe: mgmProbe,
|
mgmProbe: mgmProbe,
|
||||||
signalProbe: signalProbe,
|
signalProbe: signalProbe,
|
||||||
relayProbe: relayProbe,
|
relayProbe: relayProbe,
|
||||||
@ -214,6 +224,13 @@ func (e *Engine) Stop() error {
|
|||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
|
if e.cancel != nil {
|
||||||
|
e.cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
// stopping network monitor first to avoid starting the engine again
|
||||||
|
e.networkWatcher.Stop()
|
||||||
|
|
||||||
err := e.removeAllPeers()
|
err := e.removeAllPeers()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -222,7 +239,7 @@ func (e *Engine) Stop() error {
|
|||||||
e.clientRoutes = nil
|
e.clientRoutes = nil
|
||||||
|
|
||||||
// very ugly but we want to remove peers from the WireGuard interface first before removing interface.
|
// very ugly but we want to remove peers from the WireGuard interface first before removing interface.
|
||||||
// Removing peers happens in the conn.CLose() asynchronously
|
// Removing peers happens in the conn.Close() asynchronously
|
||||||
time.Sleep(500 * time.Millisecond)
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
|
||||||
e.close()
|
e.close()
|
||||||
@ -237,6 +254,13 @@ func (e *Engine) Start() error {
|
|||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
|
if e.cancel != nil {
|
||||||
|
e.cancel()
|
||||||
|
}
|
||||||
|
e.ctx, e.cancel = context.WithCancel(e.clientCtx)
|
||||||
|
|
||||||
|
e.wgProxyFactory = wgproxy.NewFactory(e.clientCtx, e.config.WgPort)
|
||||||
|
|
||||||
wgIface, err := e.newWgIface()
|
wgIface, err := e.newWgIface()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed creating wireguard interface instance %s: [%s]", e.config.WgIfaceName, err)
|
log.Errorf("failed creating wireguard interface instance %s: [%s]", e.config.WgIfaceName, err)
|
||||||
@ -320,6 +344,21 @@ func (e *Engine) Start() error {
|
|||||||
e.receiveManagementEvents()
|
e.receiveManagementEvents()
|
||||||
e.receiveProbeEvents()
|
e.receiveProbeEvents()
|
||||||
|
|
||||||
|
if e.config.NetworkMonitor {
|
||||||
|
// starting network monitor at the very last to avoid disruptions
|
||||||
|
go e.networkWatcher.Start(e.ctx, func() {
|
||||||
|
log.Infof("Network monitor detected network change, restarting engine")
|
||||||
|
if err := e.Stop(); err != nil {
|
||||||
|
log.Errorf("Failed to stop engine: %v", err)
|
||||||
|
}
|
||||||
|
if err := e.Start(); err != nil {
|
||||||
|
log.Errorf("Failed to start engine: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
log.Infof("Network monitor is disabled, not starting")
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -588,12 +627,12 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
|||||||
// E.g. when a new peer has been registered and we are allowed to connect to it.
|
// E.g. when a new peer has been registered and we are allowed to connect to it.
|
||||||
func (e *Engine) receiveManagementEvents() {
|
func (e *Engine) receiveManagementEvents() {
|
||||||
go func() {
|
go func() {
|
||||||
err := e.mgmClient.Sync(e.handleSync)
|
err := e.mgmClient.Sync(e.ctx, e.handleSync)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// happens if management is unavailable for a long time.
|
// happens if management is unavailable for a long time.
|
||||||
// We want to cancel the operation of the whole client
|
// We want to cancel the operation of the whole client
|
||||||
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
|
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
|
||||||
e.cancel()
|
e.clientCancel()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Debugf("stopped receiving updates from Management Service")
|
log.Debugf("stopped receiving updates from Management Service")
|
||||||
@ -869,11 +908,12 @@ func (e *Engine) connWorker(conn *peer.Conn, peerKey string) {
|
|||||||
conn.UpdateStunTurn(append(e.STUNs, e.TURNs...))
|
conn.UpdateStunTurn(append(e.STUNs, e.TURNs...))
|
||||||
e.syncMsgMux.Unlock()
|
e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
err := conn.Open()
|
err := conn.Open(e.ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("connection to peer %s failed: %v", peerKey, err)
|
log.Debugf("connection to peer %s failed: %v", peerKey, err)
|
||||||
switch err.(type) {
|
var connectionClosedError *peer.ConnectionClosedError
|
||||||
case *peer.ConnectionClosedError:
|
switch {
|
||||||
|
case errors.As(err, &connectionClosedError):
|
||||||
// conn has been forced to close, so we exit the loop
|
// conn has been forced to close, so we exit the loop
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
@ -984,7 +1024,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
|
|||||||
func (e *Engine) receiveSignalEvents() {
|
func (e *Engine) receiveSignalEvents() {
|
||||||
go func() {
|
go func() {
|
||||||
// connect to a stream of messages coming from the signal server
|
// connect to a stream of messages coming from the signal server
|
||||||
err := e.signal.Receive(func(msg *sProto.Message) error {
|
err := e.signal.Receive(e.ctx, func(msg *sProto.Message) error {
|
||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
@ -1058,7 +1098,7 @@ func (e *Engine) receiveSignalEvents() {
|
|||||||
// happens if signal is unavailable for a long time.
|
// happens if signal is unavailable for a long time.
|
||||||
// We want to cancel the operation of the whole client
|
// We want to cancel the operation of the whole client
|
||||||
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
|
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
|
||||||
e.cancel()
|
e.clientCancel()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@ -1119,13 +1159,16 @@ func (e *Engine) parseNATExternalIPMappings() []string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) close() {
|
func (e *Engine) close() {
|
||||||
if err := e.wgProxyFactory.Free(); err != nil {
|
if e.wgProxyFactory != nil {
|
||||||
log.Errorf("failed closing ebpf proxy: %s", err)
|
if err := e.wgProxyFactory.Free(); err != nil {
|
||||||
|
log.Errorf("failed closing ebpf proxy: %s", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
|
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
|
||||||
if e.dnsServer != nil {
|
if e.dnsServer != nil {
|
||||||
e.dnsServer.Stop()
|
e.dnsServer.Stop()
|
||||||
|
e.dnsServer = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if e.routeManager != nil {
|
if e.routeManager != nil {
|
||||||
|
@ -392,7 +392,7 @@ func TestEngine_Sync(t *testing.T) {
|
|||||||
// feed updates to Engine via mocked Management client
|
// feed updates to Engine via mocked Management client
|
||||||
updates := make(chan *mgmtProto.SyncResponse)
|
updates := make(chan *mgmtProto.SyncResponse)
|
||||||
defer close(updates)
|
defer close(updates)
|
||||||
syncFunc := func(msgHandler func(msg *mgmtProto.SyncResponse) error) error {
|
syncFunc := func(ctx context.Context, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
|
||||||
for msg := range updates {
|
for msg := range updates {
|
||||||
err := msgHandler(msg)
|
err := msgHandler(msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -1 +0,0 @@
|
|||||||
package internal
|
|
15
client/internal/networkmonitor/monitor.go
Normal file
15
client/internal/networkmonitor/monitor.go
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
package networkmonitor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NetworkWatcher watches for changes in network configuration.
|
||||||
|
type NetworkWatcher struct {
|
||||||
|
cancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new network monitor.
|
||||||
|
func New() *NetworkWatcher {
|
||||||
|
return &NetworkWatcher{}
|
||||||
|
}
|
133
client/internal/networkmonitor/monitor_bsd.go
Normal file
133
client/internal/networkmonitor/monitor_bsd.go
Normal file
@ -0,0 +1,133 @@
|
|||||||
|
//go:build (darwin && !ios) || dragonfly || freebsd || netbsd || openbsd
|
||||||
|
|
||||||
|
package networkmonitor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"syscall"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/net/route"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
|
)
|
||||||
|
|
||||||
|
func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interface, nexthopv6 netip.Addr, intfv6 *net.Interface, callback func()) error {
|
||||||
|
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to open routing socket: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := unix.Close(fd); err != nil {
|
||||||
|
log.Errorf("Network monitor: failed to close routing socket: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
default:
|
||||||
|
buf := make([]byte, 2048)
|
||||||
|
n, err := unix.Read(fd, buf)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Network monitor: failed to read from routing socket: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if n < unix.SizeofRtMsghdr {
|
||||||
|
log.Errorf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
|
||||||
|
|
||||||
|
switch msg.Type {
|
||||||
|
|
||||||
|
// handle interface state changes
|
||||||
|
case unix.RTM_IFINFO:
|
||||||
|
ifinfo, err := parseInterfaceMessage(buf[:n])
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Network monitor: error parsing interface message: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if msg.Flags&unix.IFF_UP != 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if (intfv4 == nil || ifinfo.Index != intfv4.Index) && (intfv6 == nil || ifinfo.Index != intfv6.Index) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("Network monitor: monitored interface (%s) is down.", ifinfo.Name)
|
||||||
|
callback()
|
||||||
|
|
||||||
|
// handle route changes
|
||||||
|
case unix.RTM_ADD, syscall.RTM_DELETE:
|
||||||
|
route, err := parseRouteMessage(buf[:n])
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Network monitor: error parsing routing message: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !route.Dst.Addr().IsUnspecified() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
intf := "<nil>"
|
||||||
|
if route.Interface != nil {
|
||||||
|
intf = route.Interface.Name
|
||||||
|
}
|
||||||
|
switch msg.Type {
|
||||||
|
case unix.RTM_ADD:
|
||||||
|
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
|
||||||
|
callback()
|
||||||
|
case unix.RTM_DELETE:
|
||||||
|
if intfv4 != nil && route.Gw.Compare(nexthopv4) == 0 || intfv6 != nil && route.Gw.Compare(nexthopv6) == 0 {
|
||||||
|
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
|
||||||
|
callback()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseInterfaceMessage(buf []byte) (*route.InterfaceMessage, error) {
|
||||||
|
msgs, err := route.ParseRIB(route.RIBTypeInterface, buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse RIB: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(msgs) != 1 {
|
||||||
|
return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
msg, ok := msgs[0].(*route.InterfaceMessage)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
return msg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseRouteMessage(buf []byte) (*routemanager.Route, error) {
|
||||||
|
msgs, err := route.ParseRIB(route.RIBTypeRoute, buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse RIB: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(msgs) != 1 {
|
||||||
|
return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
msg, ok := msgs[0].(*route.RouteMessage)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
return routemanager.MsgToRoute(msg)
|
||||||
|
}
|
82
client/internal/networkmonitor/monitor_generic.go
Normal file
82
client/internal/networkmonitor/monitor_generic.go
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
//go:build !ios && !android
|
||||||
|
|
||||||
|
package networkmonitor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"runtime/debug"
|
||||||
|
|
||||||
|
"github.com/cenkalti/backoff/v4"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Start begins watching for network changes and calls the callback function and stops when a change is detected.
|
||||||
|
func (nw *NetworkWatcher) Start(ctx context.Context, callback func()) {
|
||||||
|
if nw.cancel != nil {
|
||||||
|
log.Warn("Network monitor: already running, stopping previous watcher")
|
||||||
|
nw.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
log.Info("Network monitor: not starting, context is already cancelled")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, nw.cancel = context.WithCancel(ctx)
|
||||||
|
defer nw.Stop()
|
||||||
|
|
||||||
|
var nexthop4, nexthop6 netip.Addr
|
||||||
|
var intf4, intf6 *net.Interface
|
||||||
|
|
||||||
|
operation := func() error {
|
||||||
|
var errv4, errv6 error
|
||||||
|
nexthop4, intf4, errv4 = routemanager.GetNextHop(netip.IPv4Unspecified())
|
||||||
|
nexthop6, intf6, errv6 = routemanager.GetNextHop(netip.IPv6Unspecified())
|
||||||
|
|
||||||
|
if errv4 != nil && errv6 != nil {
|
||||||
|
return errors.New("failed to get default next hops")
|
||||||
|
}
|
||||||
|
|
||||||
|
if errv4 == nil {
|
||||||
|
log.Debugf("Network monitor: IPv4 default route: %s, interface: %s", nexthop4, intf4.Name)
|
||||||
|
}
|
||||||
|
if errv6 == nil {
|
||||||
|
log.Debugf("Network monitor: IPv6 default route: %s, interface: %s", nexthop6, intf6.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// continue if either route was found
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
expBackOff := backoff.WithContext(backoff.NewExponentialBackOff(), ctx)
|
||||||
|
|
||||||
|
if err := backoff.Retry(operation, expBackOff); err != nil {
|
||||||
|
log.Errorf("Network monitor: failed to get default next hops: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// recover in case sys ops panic
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
log.Errorf("Network monitor: panic occurred: %v, stack trace: %s", r, string(debug.Stack()))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err := checkChange(ctx, nexthop4, intf4, nexthop6, intf6, callback); err != nil && !errors.Is(err, context.Canceled) {
|
||||||
|
log.Errorf("Network monitor: failed to start: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops the network monitor.
|
||||||
|
func (nw *NetworkWatcher) Stop() {
|
||||||
|
if nw.cancel != nil {
|
||||||
|
nw.cancel()
|
||||||
|
nw.cancel = nil
|
||||||
|
log.Info("Network monitor: stopped")
|
||||||
|
}
|
||||||
|
}
|
81
client/internal/networkmonitor/monitor_linux.go
Normal file
81
client/internal/networkmonitor/monitor_linux.go
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package networkmonitor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/vishvananda/netlink"
|
||||||
|
)
|
||||||
|
|
||||||
|
func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interface, nexthop6 netip.Addr, intfv6 *net.Interface, callback func()) error {
|
||||||
|
if intfv4 == nil && intfv6 == nil {
|
||||||
|
return errors.New("no interfaces available")
|
||||||
|
}
|
||||||
|
|
||||||
|
linkChan := make(chan netlink.LinkUpdate)
|
||||||
|
done := make(chan struct{})
|
||||||
|
defer close(done)
|
||||||
|
|
||||||
|
if err := netlink.LinkSubscribe(linkChan, done); err != nil {
|
||||||
|
return fmt.Errorf("subscribe to link updates: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
routeChan := make(chan netlink.RouteUpdate)
|
||||||
|
if err := netlink.RouteSubscribe(routeChan, done); err != nil {
|
||||||
|
return fmt.Errorf("subscribe to route updates: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("Network monitor: started")
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
|
||||||
|
// handle interface state changes
|
||||||
|
case update := <-linkChan:
|
||||||
|
if (intfv4 == nil || update.Index != int32(intfv4.Index)) && (intfv6 == nil || update.Index != int32(intfv6.Index)) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
switch update.Header.Type {
|
||||||
|
case syscall.RTM_DELLINK:
|
||||||
|
log.Infof("Network monitor: monitored interface (%s) is gone", update.Link.Attrs().Name)
|
||||||
|
callback()
|
||||||
|
return nil
|
||||||
|
case syscall.RTM_NEWLINK:
|
||||||
|
if (update.IfInfomsg.Flags&syscall.IFF_RUNNING) == 0 && update.Link.Attrs().OperState == netlink.OperDown {
|
||||||
|
log.Infof("Network monitor: monitored interface (%s) is down.", update.Link.Attrs().Name)
|
||||||
|
callback()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handle route changes
|
||||||
|
case route := <-routeChan:
|
||||||
|
// default route and main table
|
||||||
|
if route.Dst != nil || route.Table != syscall.RT_TABLE_MAIN {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch route.Type {
|
||||||
|
// triggered on added/replaced routes
|
||||||
|
case syscall.RTM_NEWROUTE:
|
||||||
|
log.Infof("Network monitor: default route changed: via %s, interface %d", route.Gw, route.LinkIndex)
|
||||||
|
callback()
|
||||||
|
return nil
|
||||||
|
case syscall.RTM_DELROUTE:
|
||||||
|
if intfv4 != nil && route.Gw.Equal(nexthopv4.AsSlice()) || intfv6 != nil && route.Gw.Equal(nexthop6.AsSlice()) {
|
||||||
|
log.Infof("Network monitor: default route removed: via %s, interface %d", route.Gw, route.LinkIndex)
|
||||||
|
callback()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
11
client/internal/networkmonitor/monitor_mobile.go
Normal file
11
client/internal/networkmonitor/monitor_mobile.go
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
//go:build ios || android
|
||||||
|
|
||||||
|
package networkmonitor
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
func (nw *NetworkWatcher) Start(context.Context, func()) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (nw *NetworkWatcher) Stop() {
|
||||||
|
}
|
215
client/internal/networkmonitor/monitor_windows.go
Normal file
215
client/internal/networkmonitor/monitor_windows.go
Normal file
@ -0,0 +1,215 @@
|
|||||||
|
package networkmonitor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
unreachable = 0
|
||||||
|
incomplete = 1
|
||||||
|
probe = 2
|
||||||
|
delay = 3
|
||||||
|
stale = 4
|
||||||
|
reachable = 5
|
||||||
|
permanent = 6
|
||||||
|
tbd = 7
|
||||||
|
)
|
||||||
|
|
||||||
|
const interval = 10 * time.Second
|
||||||
|
|
||||||
|
func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interface, nexthopv6 netip.Addr, intfv6 *net.Interface, callback func()) error {
|
||||||
|
var neighborv4, neighborv6 *routemanager.Neighbor
|
||||||
|
{
|
||||||
|
initialNeighbors, err := getNeighbors()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get neighbors: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if n, ok := initialNeighbors[nexthopv4]; ok {
|
||||||
|
neighborv4 = &n
|
||||||
|
}
|
||||||
|
if n, ok := initialNeighbors[nexthopv6]; ok {
|
||||||
|
neighborv6 = &n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log.Debugf("Network monitor: initial IPv4 neighbor: %v, IPv6 neighbor: %v", neighborv4, neighborv6)
|
||||||
|
|
||||||
|
ticker := time.NewTicker(interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case <-ticker.C:
|
||||||
|
if changed(nexthopv4, intfv4, neighborv4, nexthopv6, intfv6, neighborv6) {
|
||||||
|
callback()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func changed(
|
||||||
|
nexthopv4 netip.Addr,
|
||||||
|
intfv4 *net.Interface,
|
||||||
|
neighborv4 *routemanager.Neighbor,
|
||||||
|
nexthopv6 netip.Addr,
|
||||||
|
intfv6 *net.Interface,
|
||||||
|
neighborv6 *routemanager.Neighbor,
|
||||||
|
) bool {
|
||||||
|
neighbors, err := getNeighbors()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("network monitor: error fetching current neighbors: %v", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if neighborChanged(nexthopv4, neighborv4, neighbors) || neighborChanged(nexthopv6, neighborv6, neighbors) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
routes, err := getRoutes()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("network monitor: error fetching current routes: %v", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if routeChanged(nexthopv4, intfv4, routes) || routeChanged(nexthopv6, intfv6, routes) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// routeChanged checks if the default routes still point to our nexthop/interface
|
||||||
|
func routeChanged(nexthop netip.Addr, intf *net.Interface, routes map[netip.Prefix]routemanager.Route) bool {
|
||||||
|
if !nexthop.IsValid() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
var unspec netip.Prefix
|
||||||
|
if nexthop.Is6() {
|
||||||
|
unspec = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
|
||||||
|
} else {
|
||||||
|
unspec = netip.PrefixFrom(netip.IPv4Unspecified(), 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r, ok := routes[unspec]; ok {
|
||||||
|
if r.Nexthop != nexthop || compareIntf(r.Interface, intf) != 0 {
|
||||||
|
intf := "<nil>"
|
||||||
|
if r.Interface != nil {
|
||||||
|
intf = r.Interface.Name
|
||||||
|
}
|
||||||
|
log.Infof("network monitor: default route changed: %s via %s (%s)", r.Destination, r.Nexthop, intf)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.Infof("network monitor: default route is gone")
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func neighborChanged(nexthop netip.Addr, neighbor *routemanager.Neighbor, neighbors map[netip.Addr]routemanager.Neighbor) bool {
|
||||||
|
if neighbor == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: consider non-local nexthops, e.g. on point-to-point interfaces
|
||||||
|
if n, ok := neighbors[nexthop]; ok {
|
||||||
|
if n.State != reachable && n.State != permanent {
|
||||||
|
log.Infof("network monitor: neighbor %s (%s) is not reachable: %s", neighbor.IPAddress, neighbor.LinkLayerAddress, stateFromInt(n.State))
|
||||||
|
return true
|
||||||
|
} else if n.InterfaceIndex != neighbor.InterfaceIndex {
|
||||||
|
log.Infof(
|
||||||
|
"network monitor: neighbor %s (%s) changed interface from '%s' (%d) to '%s' (%d): %s",
|
||||||
|
neighbor.IPAddress,
|
||||||
|
neighbor.LinkLayerAddress,
|
||||||
|
neighbor.InterfaceAlias,
|
||||||
|
neighbor.InterfaceIndex,
|
||||||
|
n.InterfaceAlias,
|
||||||
|
n.InterfaceIndex,
|
||||||
|
stateFromInt(n.State),
|
||||||
|
)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.Infof("network monitor: neighbor %s (%s) is gone", neighbor.IPAddress, neighbor.LinkLayerAddress)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func getNeighbors() (map[netip.Addr]routemanager.Neighbor, error) {
|
||||||
|
entries, err := routemanager.GetNeighbors()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get neighbors: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
neighbours := make(map[netip.Addr]routemanager.Neighbor, len(entries))
|
||||||
|
for _, entry := range entries {
|
||||||
|
neighbours[entry.IPAddress] = entry
|
||||||
|
}
|
||||||
|
|
||||||
|
return neighbours, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getRoutes() (map[netip.Prefix]routemanager.Route, error) {
|
||||||
|
entries, err := routemanager.GetRoutes()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get routes: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
routes := make(map[netip.Prefix]routemanager.Route, len(entries))
|
||||||
|
for _, entry := range entries {
|
||||||
|
routes[entry.Destination] = entry
|
||||||
|
}
|
||||||
|
|
||||||
|
return routes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func stateFromInt(state uint8) string {
|
||||||
|
switch state {
|
||||||
|
case unreachable:
|
||||||
|
return "unreachable"
|
||||||
|
case incomplete:
|
||||||
|
return "incomplete"
|
||||||
|
case probe:
|
||||||
|
return "probe"
|
||||||
|
case delay:
|
||||||
|
return "delay"
|
||||||
|
case stale:
|
||||||
|
return "stale"
|
||||||
|
case reachable:
|
||||||
|
return "reachable"
|
||||||
|
case permanent:
|
||||||
|
return "permanent"
|
||||||
|
case tbd:
|
||||||
|
return "tbd"
|
||||||
|
default:
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func compareIntf(a, b *net.Interface) int {
|
||||||
|
if a == nil && b == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if a == nil {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
if b == nil {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return a.Index - b.Index
|
||||||
|
}
|
@ -276,7 +276,7 @@ func (conn *Conn) candidateTypes() []ice.CandidateType {
|
|||||||
// Open opens connection to the remote peer starting ICE candidate gathering process.
|
// Open opens connection to the remote peer starting ICE candidate gathering process.
|
||||||
// Blocks until connection has been closed or connection timeout.
|
// Blocks until connection has been closed or connection timeout.
|
||||||
// ConnStatus will be set accordingly
|
// ConnStatus will be set accordingly
|
||||||
func (conn *Conn) Open() error {
|
func (conn *Conn) Open(ctx context.Context) error {
|
||||||
log.Debugf("trying to connect to peer %s", conn.config.Key)
|
log.Debugf("trying to connect to peer %s", conn.config.Key)
|
||||||
|
|
||||||
peerState := State{
|
peerState := State{
|
||||||
@ -336,7 +336,7 @@ func (conn *Conn) Open() error {
|
|||||||
// at this point we received offer/answer and we are ready to gather candidates
|
// at this point we received offer/answer and we are ready to gather candidates
|
||||||
conn.mu.Lock()
|
conn.mu.Lock()
|
||||||
conn.status = StatusConnecting
|
conn.status = StatusConnecting
|
||||||
conn.ctx, conn.notifyDisconnected = context.WithCancel(context.Background())
|
conn.ctx, conn.notifyDisconnected = context.WithCancel(ctx)
|
||||||
defer conn.notifyDisconnected()
|
defer conn.notifyDisconnected()
|
||||||
conn.mu.Unlock()
|
conn.mu.Unlock()
|
||||||
|
|
||||||
@ -423,7 +423,7 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
|
|||||||
var endpoint net.Addr
|
var endpoint net.Addr
|
||||||
if isRelayCandidate(pair.Local) {
|
if isRelayCandidate(pair.Local) {
|
||||||
log.Debugf("setup relay connection")
|
log.Debugf("setup relay connection")
|
||||||
conn.wgProxy = conn.wgProxyFactory.GetProxy()
|
conn.wgProxy = conn.wgProxyFactory.GetProxy(conn.ctx)
|
||||||
endpoint, err = conn.wgProxy.AddTurnConn(remoteConn)
|
endpoint, err = conn.wgProxy.AddTurnConn(remoteConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -448,9 +448,11 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
|
|||||||
err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpointUdpAddr, conn.config.WgConfig.PreSharedKey)
|
err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpointUdpAddr, conn.config.WgConfig.PreSharedKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if conn.wgProxy != nil {
|
if conn.wgProxy != nil {
|
||||||
_ = conn.wgProxy.CloseConn()
|
if err := conn.wgProxy.CloseConn(); err != nil {
|
||||||
|
log.Warnf("Failed to close turn connection: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, fmt.Errorf("update peer: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.status = StatusConnected
|
conn.status = StatusConnected
|
||||||
@ -730,7 +732,7 @@ func (conn *Conn) Close() error {
|
|||||||
// before conn.Open() another update from management arrives with peers: [1,2,3,4,5]
|
// before conn.Open() another update from management arrives with peers: [1,2,3,4,5]
|
||||||
// engine adds a new Conn for 4 and 5
|
// engine adds a new Conn for 4 and 5
|
||||||
// therefore peer 4 has 2 Conn objects
|
// therefore peer 4 has 2 Conn objects
|
||||||
log.Warnf("connection has been already closed or attempted closing not started coonection %s", conn.config.Key)
|
log.Warnf("Connection has been already closed or attempted closing not started connection %s", conn.config.Key)
|
||||||
return NewConnectionAlreadyClosed(conn.config.Key)
|
return NewConnectionAlreadyClosed(conn.config.Key)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package peer
|
package peer
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@ -35,7 +36,7 @@ func TestNewConn_interfaceFilter(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestConn_GetKey(t *testing.T) {
|
func TestConn_GetKey(t *testing.T) {
|
||||||
wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
|
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort)
|
||||||
defer func() {
|
defer func() {
|
||||||
_ = wgProxyFactory.Free()
|
_ = wgProxyFactory.Free()
|
||||||
}()
|
}()
|
||||||
@ -50,7 +51,7 @@ func TestConn_GetKey(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestConn_OnRemoteOffer(t *testing.T) {
|
func TestConn_OnRemoteOffer(t *testing.T) {
|
||||||
wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
|
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort)
|
||||||
defer func() {
|
defer func() {
|
||||||
_ = wgProxyFactory.Free()
|
_ = wgProxyFactory.Free()
|
||||||
}()
|
}()
|
||||||
@ -87,7 +88,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestConn_OnRemoteAnswer(t *testing.T) {
|
func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||||
wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
|
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort)
|
||||||
defer func() {
|
defer func() {
|
||||||
_ = wgProxyFactory.Free()
|
_ = wgProxyFactory.Free()
|
||||||
}()
|
}()
|
||||||
@ -123,7 +124,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
|
|||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
func TestConn_Status(t *testing.T) {
|
func TestConn_Status(t *testing.T) {
|
||||||
wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
|
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort)
|
||||||
defer func() {
|
defer func() {
|
||||||
_ = wgProxyFactory.Free()
|
_ = wgProxyFactory.Free()
|
||||||
}()
|
}()
|
||||||
@ -153,7 +154,7 @@ func TestConn_Status(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestConn_Close(t *testing.T) {
|
func TestConn_Close(t *testing.T) {
|
||||||
wgProxyFactory := wgproxy.NewFactory(connConf.LocalWgPort)
|
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort)
|
||||||
defer func() {
|
defer func() {
|
||||||
_ = wgProxyFactory.Free()
|
_ = wgProxyFactory.Free()
|
||||||
}()
|
}()
|
||||||
|
@ -35,7 +35,7 @@ func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
|
|||||||
addr = netip.IPv6Unspecified()
|
addr = netip.IPv6Unspecified()
|
||||||
}
|
}
|
||||||
|
|
||||||
defaultGateway, _, err := getNextHop(addr)
|
defaultGateway, _, err := GetNextHop(addr)
|
||||||
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
||||||
return fmt.Errorf("get existing route gateway: %s", err)
|
return fmt.Errorf("get existing route gateway: %s", err)
|
||||||
}
|
}
|
||||||
@ -60,7 +60,7 @@ func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
gatewayHop, intf, err := getNextHop(defaultGateway)
|
gatewayHop, intf, err := GetNextHop(defaultGateway)
|
||||||
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
||||||
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
|
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
|
||||||
}
|
}
|
||||||
@ -69,14 +69,14 @@ func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
|
|||||||
return addToRouteTable(gatewayPrefix, gatewayHop, intf)
|
return addToRouteTable(gatewayPrefix, gatewayHop, intf)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) {
|
func GetNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) {
|
||||||
r, err := netroute.New()
|
r, err := netroute.New()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return netip.Addr{}, nil, fmt.Errorf("new netroute: %w", err)
|
return netip.Addr{}, nil, fmt.Errorf("new netroute: %w", err)
|
||||||
}
|
}
|
||||||
intf, gateway, preferredSrc, err := r.Route(ip.AsSlice())
|
intf, gateway, preferredSrc, err := r.Route(ip.AsSlice())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("Failed to get route for %s: %v", ip, err)
|
log.Debugf("Failed to get route for %s: %v", ip, err)
|
||||||
return netip.Addr{}, nil, ErrRouteNotFound
|
return netip.Addr{}, nil, ErrRouteNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -163,7 +163,7 @@ func addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIface, initialNe
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Determine the exit interface and next hop for the prefix, so we can add a specific route
|
// Determine the exit interface and next hop for the prefix, so we can add a specific route
|
||||||
nexthop, intf, err := getNextHop(addr)
|
nexthop, intf, err := GetNextHop(addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return netip.Addr{}, nil, fmt.Errorf("get next hop: %w", err)
|
return netip.Addr{}, nil, fmt.Errorf("get next hop: %w", err)
|
||||||
}
|
}
|
||||||
@ -319,11 +319,11 @@ func getPrefixFromIP(ip net.IP) (*netip.Prefix, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||||
initialNextHopV4, initialIntfV4, err := getNextHop(netip.IPv4Unspecified())
|
initialNextHopV4, initialIntfV4, err := GetNextHop(netip.IPv4Unspecified())
|
||||||
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
||||||
log.Errorf("Unable to get initial v4 default next hop: %v", err)
|
log.Errorf("Unable to get initial v4 default next hop: %v", err)
|
||||||
}
|
}
|
||||||
initialNextHopV6, initialIntfV6, err := getNextHop(netip.IPv6Unspecified())
|
initialNextHopV6, initialIntfV6, err := GetNextHop(netip.IPv6Unspecified())
|
||||||
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
||||||
log.Errorf("Unable to get initial v6 default next hop: %v", err)
|
log.Errorf("Unable to get initial v6 default next hop: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -3,10 +3,11 @@
|
|||||||
package routemanager
|
package routemanager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"errors"
|
"strconv"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -15,17 +16,22 @@ import (
|
|||||||
"golang.org/x/net/route"
|
"golang.org/x/net/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type Route struct {
|
||||||
|
Dst netip.Prefix
|
||||||
|
Gw netip.Addr
|
||||||
|
Interface *net.Interface
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: fix here with retry and backoff
|
|
||||||
func getRoutesFromTable() ([]netip.Prefix, error) {
|
func getRoutesFromTable() ([]netip.Prefix, error) {
|
||||||
tab, err := retryFetchRIB()
|
tab, err := retryFetchRIB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("fetch RIB: %v", err)
|
||||||
}
|
}
|
||||||
msgs, err := route.ParseRIB(route.RIBTypeRoute, tab)
|
msgs, err := route.ParseRIB(route.RIBTypeRoute, tab)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("parse RIB: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var prefixList []netip.Prefix
|
var prefixList []netip.Prefix
|
||||||
for _, msg := range msgs {
|
for _, msg := range msgs {
|
||||||
m := msg.(*route.RouteMessage)
|
m := msg.(*route.RouteMessage)
|
||||||
@ -33,7 +39,7 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
|
|||||||
if m.Version < 3 || m.Version > 5 {
|
if m.Version < 3 || m.Version > 5 {
|
||||||
return nil, fmt.Errorf("unexpected RIB message version: %d", m.Version)
|
return nil, fmt.Errorf("unexpected RIB message version: %d", m.Version)
|
||||||
}
|
}
|
||||||
if m.Type != 4 /* RTM_GET */ {
|
if m.Type != syscall.RTM_GET {
|
||||||
return nil, fmt.Errorf("unexpected RIB message type: %d", m.Type)
|
return nil, fmt.Errorf("unexpected RIB message type: %d", m.Type)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -42,28 +48,13 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(m.Addrs) < 3 {
|
route, err := MsgToRoute(m)
|
||||||
log.Warnf("Unexpected RIB message Addrs: %v", m.Addrs)
|
if err != nil {
|
||||||
|
log.Warnf("Failed to parse route message: %v", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if route.Dst.IsValid() {
|
||||||
addr, ok := toNetIPAddr(m.Addrs[0])
|
prefixList = append(prefixList, route.Dst)
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
cidr := 32
|
|
||||||
if mask := m.Addrs[2]; mask != nil {
|
|
||||||
cidr, ok = toCIDR(mask)
|
|
||||||
if !ok {
|
|
||||||
log.Debugf("Unexpected RIB message Addrs[2]: %v", mask)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
routePrefix := netip.PrefixFrom(addr, cidr)
|
|
||||||
if routePrefix.IsValid() {
|
|
||||||
prefixList = append(prefixList, routePrefix)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return prefixList, nil
|
return prefixList, nil
|
||||||
@ -75,7 +66,7 @@ func retryFetchRIB() ([]byte, error) {
|
|||||||
var err error
|
var err error
|
||||||
out, err = route.FetchRIB(syscall.AF_UNSPEC, route.RIBTypeRoute, 0)
|
out, err = route.FetchRIB(syscall.AF_UNSPEC, route.RIBTypeRoute, 0)
|
||||||
if errors.Is(err, syscall.ENOMEM) {
|
if errors.Is(err, syscall.ENOMEM) {
|
||||||
log.Debug("retrying fetchRIB due to 'cannot allocate memory' error")
|
log.Debug("~etrying fetchRIB due to 'cannot allocate memory' error")
|
||||||
return err
|
return err
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return backoff.Permanent(err)
|
return backoff.Permanent(err)
|
||||||
@ -95,22 +86,74 @@ func retryFetchRIB() ([]byte, error) {
|
|||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func toNetIPAddr(a route.Addr) (netip.Addr, bool) {
|
func toNetIP(a route.Addr) netip.Addr {
|
||||||
switch t := a.(type) {
|
switch t := a.(type) {
|
||||||
case *route.Inet4Addr:
|
case *route.Inet4Addr:
|
||||||
return netip.AddrFrom4(t.IP), true
|
return netip.AddrFrom4(t.IP)
|
||||||
|
case *route.Inet6Addr:
|
||||||
|
ip := netip.AddrFrom16(t.IP)
|
||||||
|
if t.ZoneID != 0 {
|
||||||
|
ip.WithZone(strconv.Itoa(t.ZoneID))
|
||||||
|
}
|
||||||
|
return ip
|
||||||
default:
|
default:
|
||||||
return netip.Addr{}, false
|
return netip.Addr{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func toCIDR(a route.Addr) (int, bool) {
|
func ones(a route.Addr) (int, error) {
|
||||||
switch t := a.(type) {
|
switch t := a.(type) {
|
||||||
case *route.Inet4Addr:
|
case *route.Inet4Addr:
|
||||||
mask := net.IPv4Mask(t.IP[0], t.IP[1], t.IP[2], t.IP[3])
|
mask, _ := net.IPMask(t.IP[:]).Size()
|
||||||
cidr, _ := mask.Size()
|
return mask, nil
|
||||||
return cidr, true
|
case *route.Inet6Addr:
|
||||||
|
mask, _ := net.IPMask(t.IP[:]).Size()
|
||||||
|
return mask, nil
|
||||||
default:
|
default:
|
||||||
return 0, false
|
return 0, fmt.Errorf("unexpected address type: %T", a)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func MsgToRoute(msg *route.RouteMessage) (*Route, error) {
|
||||||
|
dstIP, nexthop, dstMask := msg.Addrs[0], msg.Addrs[1], msg.Addrs[2]
|
||||||
|
|
||||||
|
addr := toNetIP(dstIP)
|
||||||
|
|
||||||
|
var nexthopAddr netip.Addr
|
||||||
|
var nexthopIntf *net.Interface
|
||||||
|
|
||||||
|
switch t := nexthop.(type) {
|
||||||
|
case *route.Inet4Addr, *route.Inet6Addr:
|
||||||
|
nexthopAddr = toNetIP(t)
|
||||||
|
case *route.LinkAddr:
|
||||||
|
nexthopIntf = &net.Interface{
|
||||||
|
Index: t.Index,
|
||||||
|
Name: t.Name,
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unexpected next hop type: %T", t)
|
||||||
|
}
|
||||||
|
|
||||||
|
var prefix netip.Prefix
|
||||||
|
|
||||||
|
if dstMask == nil {
|
||||||
|
if addr.Is4() {
|
||||||
|
prefix = netip.PrefixFrom(addr, 32)
|
||||||
|
} else {
|
||||||
|
prefix = netip.PrefixFrom(addr, 128)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
bits, err := ones(dstMask)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse mask: %v", dstMask)
|
||||||
|
}
|
||||||
|
prefix = netip.PrefixFrom(addr, bits)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Route{
|
||||||
|
Dst: prefix,
|
||||||
|
Gw: nexthopAddr,
|
||||||
|
Interface: nexthopIntf,
|
||||||
|
}, nil
|
||||||
|
|
||||||
|
}
|
||||||
|
57
client/internal/routemanager/systemops_bsd_test.go
Normal file
57
client/internal/routemanager/systemops_bsd_test.go
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
|
||||||
|
|
||||||
|
package routemanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"golang.org/x/net/route"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBits(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
addr route.Addr
|
||||||
|
want int
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "IPv4 all ones",
|
||||||
|
addr: &route.Inet4Addr{IP: [4]byte{255, 255, 255, 255}},
|
||||||
|
want: 32,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv4 normal mask",
|
||||||
|
addr: &route.Inet4Addr{IP: [4]byte{255, 255, 255, 0}},
|
||||||
|
want: 24,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 all ones",
|
||||||
|
addr: &route.Inet6Addr{IP: [16]byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}},
|
||||||
|
want: 128,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 normal mask",
|
||||||
|
addr: &route.Inet6Addr{IP: [16]byte{255, 255, 255, 255, 255, 255, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0}},
|
||||||
|
want: 64,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Unsupported type",
|
||||||
|
addr: &route.LinkAddr{},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := ones(tt.addr)
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.want, got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
@ -87,10 +87,10 @@ func TestAddRemoveRoutes(t *testing.T) {
|
|||||||
err = removeVPNRoute(testCase.prefix, intf)
|
err = removeVPNRoute(testCase.prefix, intf)
|
||||||
require.NoError(t, err, "genericRemoveVPNRoute should not return err")
|
require.NoError(t, err, "genericRemoveVPNRoute should not return err")
|
||||||
|
|
||||||
prefixGateway, _, err := getNextHop(testCase.prefix.Addr())
|
prefixGateway, _, err := GetNextHop(testCase.prefix.Addr())
|
||||||
require.NoError(t, err, "getNextHop should not return err")
|
require.NoError(t, err, "GetNextHop should not return err")
|
||||||
|
|
||||||
internetGateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0"))
|
internetGateway, _, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
if testCase.shouldBeRemoved {
|
if testCase.shouldBeRemoved {
|
||||||
@ -104,7 +104,7 @@ func TestAddRemoveRoutes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestGetNextHop(t *testing.T) {
|
func TestGetNextHop(t *testing.T) {
|
||||||
gateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0"))
|
gateway, _, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("shouldn't return error when fetching the gateway: ", err)
|
t.Fatal("shouldn't return error when fetching the gateway: ", err)
|
||||||
}
|
}
|
||||||
@ -130,7 +130,7 @@ func TestGetNextHop(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
localIP, _, err := getNextHop(testingPrefix.Addr())
|
localIP, _, err := GetNextHop(testingPrefix.Addr())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("shouldn't return error: ", err)
|
t.Fatal("shouldn't return error: ", err)
|
||||||
}
|
}
|
||||||
@ -146,7 +146,7 @@ func TestGetNextHop(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAddExistAndRemoveRoute(t *testing.T) {
|
func TestAddExistAndRemoveRoute(t *testing.T) {
|
||||||
defaultGateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0"))
|
defaultGateway, _, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
|
||||||
t.Log("defaultGateway: ", defaultGateway)
|
t.Log("defaultGateway: ", defaultGateway)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("shouldn't return error when fetching the gateway: ", err)
|
t.Fatal("shouldn't return error when fetching the gateway: ", err)
|
||||||
@ -410,8 +410,8 @@ func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIf
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
prefixGateway, _, err := getNextHop(prefix.Addr())
|
prefixGateway, _, err := GetNextHop(prefix.Addr())
|
||||||
require.NoError(t, err, "getNextHop should not return err")
|
require.NoError(t, err, "GetNextHop should not return err")
|
||||||
if invert {
|
if invert {
|
||||||
assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP")
|
assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP")
|
||||||
} else {
|
} else {
|
||||||
|
@ -20,9 +20,36 @@ import (
|
|||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Win32_IP4RouteTable struct {
|
type MSFT_NetRoute struct {
|
||||||
Destination string
|
DestinationPrefix string
|
||||||
Mask string
|
NextHop string
|
||||||
|
InterfaceIndex int32
|
||||||
|
InterfaceAlias string
|
||||||
|
AddressFamily uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
type Route struct {
|
||||||
|
Destination netip.Prefix
|
||||||
|
Nexthop netip.Addr
|
||||||
|
Interface *net.Interface
|
||||||
|
}
|
||||||
|
|
||||||
|
type MSFT_NetNeighbor struct {
|
||||||
|
IPAddress string
|
||||||
|
LinkLayerAddress string
|
||||||
|
State uint8
|
||||||
|
AddressFamily uint16
|
||||||
|
InterfaceIndex uint32
|
||||||
|
InterfaceAlias string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Neighbor struct {
|
||||||
|
IPAddress netip.Addr
|
||||||
|
LinkLayerAddress string
|
||||||
|
State uint8
|
||||||
|
AddressFamily uint16
|
||||||
|
InterfaceIndex uint32
|
||||||
|
InterfaceAlias string
|
||||||
}
|
}
|
||||||
|
|
||||||
var prefixList []netip.Prefix
|
var prefixList []netip.Prefix
|
||||||
@ -43,44 +70,92 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
|
|||||||
mux.Lock()
|
mux.Lock()
|
||||||
defer mux.Unlock()
|
defer mux.Unlock()
|
||||||
|
|
||||||
query := "SELECT Destination, Mask FROM Win32_IP4RouteTable"
|
|
||||||
|
|
||||||
// If many routes are added at the same time this might block for a long time (seconds to minutes), so we cache the result
|
// If many routes are added at the same time this might block for a long time (seconds to minutes), so we cache the result
|
||||||
if !isCacheDisabled() && time.Since(lastUpdate) < 2*time.Second {
|
if !isCacheDisabled() && time.Since(lastUpdate) < 2*time.Second {
|
||||||
return prefixList, nil
|
return prefixList, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var routes []Win32_IP4RouteTable
|
routes, err := GetRoutes()
|
||||||
err := wmi.Query(query, &routes)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get routes: %w", err)
|
return nil, fmt.Errorf("get routes: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
prefixList = nil
|
prefixList = nil
|
||||||
for _, route := range routes {
|
for _, route := range routes {
|
||||||
addr, err := netip.ParseAddr(route.Destination)
|
prefixList = append(prefixList, route.Destination)
|
||||||
if err != nil {
|
|
||||||
log.Warnf("Unable to parse route destination %s: %v", route.Destination, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
maskSlice := net.ParseIP(route.Mask).To4()
|
|
||||||
if maskSlice == nil {
|
|
||||||
log.Warnf("Unable to parse route mask %s", route.Mask)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
mask := net.IPv4Mask(maskSlice[0], maskSlice[1], maskSlice[2], maskSlice[3])
|
|
||||||
cidr, _ := mask.Size()
|
|
||||||
|
|
||||||
routePrefix := netip.PrefixFrom(addr, cidr)
|
|
||||||
if routePrefix.IsValid() && routePrefix.Addr().Is4() {
|
|
||||||
prefixList = append(prefixList, routePrefix)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
lastUpdate = time.Now()
|
lastUpdate = time.Now()
|
||||||
return prefixList, nil
|
return prefixList, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetRoutes() ([]Route, error) {
|
||||||
|
var entries []MSFT_NetRoute
|
||||||
|
|
||||||
|
query := `SELECT DestinationPrefix, NextHop, InterfaceIndex, InterfaceAlias, AddressFamily FROM MSFT_NetRoute`
|
||||||
|
if err := wmi.QueryNamespace(query, &entries, `ROOT\StandardCimv2`); err != nil {
|
||||||
|
return nil, fmt.Errorf("get routes: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var routes []Route
|
||||||
|
for _, entry := range entries {
|
||||||
|
dest, err := netip.ParsePrefix(entry.DestinationPrefix)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Unable to parse route destination %s: %v", entry.DestinationPrefix, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
nexthop, err := netip.ParseAddr(entry.NextHop)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Unable to parse route next hop %s: %v", entry.NextHop, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var intf *net.Interface
|
||||||
|
if entry.InterfaceIndex != 0 {
|
||||||
|
intf = &net.Interface{
|
||||||
|
Index: int(entry.InterfaceIndex),
|
||||||
|
Name: entry.InterfaceAlias,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
routes = append(routes, Route{
|
||||||
|
Destination: dest,
|
||||||
|
Nexthop: nexthop,
|
||||||
|
Interface: intf,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return routes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetNeighbors() ([]Neighbor, error) {
|
||||||
|
var entries []MSFT_NetNeighbor
|
||||||
|
query := `SELECT IPAddress, LinkLayerAddress, State, AddressFamily, InterfaceIndex, InterfaceAlias FROM MSFT_NetNeighbor`
|
||||||
|
if err := wmi.QueryNamespace(query, &entries, `ROOT\StandardCimv2`); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to query MSFT_NetNeighbor: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var neighbors []Neighbor
|
||||||
|
for _, entry := range entries {
|
||||||
|
addr, err := netip.ParseAddr(entry.IPAddress)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Unable to parse neighbor IP address %s: %v", entry.IPAddress, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
neighbors = append(neighbors, Neighbor{
|
||||||
|
IPAddress: addr,
|
||||||
|
LinkLayerAddress: entry.LinkLayerAddress,
|
||||||
|
State: entry.State,
|
||||||
|
AddressFamily: entry.AddressFamily,
|
||||||
|
InterfaceIndex: entry.InterfaceIndex,
|
||||||
|
InterfaceAlias: entry.InterfaceAlias,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return neighbors, nil
|
||||||
|
}
|
||||||
|
|
||||||
func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
|
func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
|
||||||
args := []string{"add", prefix.String()}
|
args := []string{"add", prefix.String()}
|
||||||
|
|
||||||
|
@ -1,15 +1,17 @@
|
|||||||
package wgproxy
|
package wgproxy
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
type Factory struct {
|
type Factory struct {
|
||||||
wgPort int
|
wgPort int
|
||||||
ebpfProxy Proxy
|
ebpfProxy Proxy
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *Factory) GetProxy() Proxy {
|
func (w *Factory) GetProxy(ctx context.Context) Proxy {
|
||||||
if w.ebpfProxy != nil {
|
if w.ebpfProxy != nil {
|
||||||
return w.ebpfProxy
|
return w.ebpfProxy
|
||||||
}
|
}
|
||||||
return NewWGUserSpaceProxy(w.wgPort)
|
return NewWGUserSpaceProxy(ctx, w.wgPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *Factory) Free() error {
|
func (w *Factory) Free() error {
|
||||||
|
@ -3,14 +3,16 @@
|
|||||||
package wgproxy
|
package wgproxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewFactory(wgPort int) *Factory {
|
func NewFactory(ctx context.Context, wgPort int) *Factory {
|
||||||
f := &Factory{wgPort: wgPort}
|
f := &Factory{wgPort: wgPort}
|
||||||
|
|
||||||
ebpfProxy := NewWGEBPFProxy(wgPort)
|
ebpfProxy := NewWGEBPFProxy(ctx, wgPort)
|
||||||
err := ebpfProxy.Listen()
|
err := ebpfProxy.listen()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err)
|
log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err)
|
||||||
return f
|
return f
|
||||||
|
@ -2,6 +2,8 @@
|
|||||||
|
|
||||||
package wgproxy
|
package wgproxy
|
||||||
|
|
||||||
func NewFactory(wgPort int) *Factory {
|
import "context"
|
||||||
|
|
||||||
|
func NewFactory(ctx context.Context, wgPort int) *Factory {
|
||||||
return &Factory{wgPort: wgPort}
|
return &Factory{wgPort: wgPort}
|
||||||
}
|
}
|
||||||
|
@ -6,7 +6,7 @@ import (
|
|||||||
|
|
||||||
// Proxy is a transfer layer between the Turn connection and the WireGuard
|
// Proxy is a transfer layer between the Turn connection and the WireGuard
|
||||||
type Proxy interface {
|
type Proxy interface {
|
||||||
AddTurnConn(urnConn net.Conn) (net.Addr, error)
|
AddTurnConn(turnConn net.Conn) (net.Addr, error)
|
||||||
CloseConn() error
|
CloseConn() error
|
||||||
Free() error
|
Free() error
|
||||||
}
|
}
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
package wgproxy
|
package wgproxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
@ -22,7 +23,11 @@ import (
|
|||||||
|
|
||||||
// WGEBPFProxy definition for proxy with EBPF support
|
// WGEBPFProxy definition for proxy with EBPF support
|
||||||
type WGEBPFProxy struct {
|
type WGEBPFProxy struct {
|
||||||
ebpfManager ebpfMgr.Manager
|
ebpfManager ebpfMgr.Manager
|
||||||
|
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
|
||||||
lastUsedPort uint16
|
lastUsedPort uint16
|
||||||
localWGListenPort int
|
localWGListenPort int
|
||||||
|
|
||||||
@ -34,7 +39,7 @@ type WGEBPFProxy struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewWGEBPFProxy create new WGEBPFProxy instance
|
// NewWGEBPFProxy create new WGEBPFProxy instance
|
||||||
func NewWGEBPFProxy(wgPort int) *WGEBPFProxy {
|
func NewWGEBPFProxy(ctx context.Context, wgPort int) *WGEBPFProxy {
|
||||||
log.Debugf("instantiate ebpf proxy")
|
log.Debugf("instantiate ebpf proxy")
|
||||||
wgProxy := &WGEBPFProxy{
|
wgProxy := &WGEBPFProxy{
|
||||||
localWGListenPort: wgPort,
|
localWGListenPort: wgPort,
|
||||||
@ -42,11 +47,13 @@ func NewWGEBPFProxy(wgPort int) *WGEBPFProxy {
|
|||||||
lastUsedPort: 0,
|
lastUsedPort: 0,
|
||||||
turnConnStore: make(map[uint16]net.Conn),
|
turnConnStore: make(map[uint16]net.Conn),
|
||||||
}
|
}
|
||||||
|
wgProxy.ctx, wgProxy.cancel = context.WithCancel(ctx)
|
||||||
|
|
||||||
return wgProxy
|
return wgProxy
|
||||||
}
|
}
|
||||||
|
|
||||||
// Listen load ebpf program and listen the proxy
|
// listen load ebpf program and listen the proxy
|
||||||
func (p *WGEBPFProxy) Listen() error {
|
func (p *WGEBPFProxy) listen() error {
|
||||||
pl := portLookup{}
|
pl := portLookup{}
|
||||||
wgPorxyPort, err := pl.searchFreePort()
|
wgPorxyPort, err := pl.searchFreePort()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -72,7 +79,7 @@ func (p *WGEBPFProxy) Listen() error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
cErr := p.Free()
|
cErr := p.Free()
|
||||||
if cErr != nil {
|
if cErr != nil {
|
||||||
log.Errorf("failed to close the wgproxy: %s", cErr)
|
log.Errorf("Failed to close the wgproxy: %s", cErr)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -102,6 +109,7 @@ func (p *WGEBPFProxy) AddTurnConn(turnConn net.Conn) (net.Addr, error) {
|
|||||||
|
|
||||||
// CloseConn doing nothing because this type of proxy implementation does not store the connection
|
// CloseConn doing nothing because this type of proxy implementation does not store the connection
|
||||||
func (p *WGEBPFProxy) CloseConn() error {
|
func (p *WGEBPFProxy) CloseConn() error {
|
||||||
|
p.cancel()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -131,19 +139,27 @@ func (p *WGEBPFProxy) Free() error {
|
|||||||
|
|
||||||
func (p *WGEBPFProxy) proxyToLocal(endpointPort uint16, remoteConn net.Conn) {
|
func (p *WGEBPFProxy) proxyToLocal(endpointPort uint16, remoteConn net.Conn) {
|
||||||
buf := make([]byte, 1500)
|
buf := make([]byte, 1500)
|
||||||
|
var err error
|
||||||
|
defer func() {
|
||||||
|
p.removeTurnConn(endpointPort)
|
||||||
|
}()
|
||||||
for {
|
for {
|
||||||
n, err := remoteConn.Read(buf)
|
select {
|
||||||
if err != nil {
|
case <-p.ctx.Done():
|
||||||
if err != io.EOF {
|
|
||||||
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
|
|
||||||
}
|
|
||||||
p.removeTurnConn(endpointPort)
|
|
||||||
log.Infof("stop forward turn packages to port: %d. error: %s", endpointPort, err)
|
|
||||||
return
|
return
|
||||||
}
|
default:
|
||||||
err = p.sendPkg(buf[:n], endpointPort)
|
var n int
|
||||||
if err != nil {
|
n, err = remoteConn.Read(buf)
|
||||||
log.Errorf("failed to write out turn pkg to local conn: %v", err)
|
if err != nil {
|
||||||
|
if err != io.EOF {
|
||||||
|
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = p.sendPkg(buf[:n], endpointPort)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to write out turn pkg to local conn: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -152,23 +168,28 @@ func (p *WGEBPFProxy) proxyToLocal(endpointPort uint16, remoteConn net.Conn) {
|
|||||||
func (p *WGEBPFProxy) proxyToRemote() {
|
func (p *WGEBPFProxy) proxyToRemote() {
|
||||||
buf := make([]byte, 1500)
|
buf := make([]byte, 1500)
|
||||||
for {
|
for {
|
||||||
n, addr, err := p.conn.ReadFromUDP(buf)
|
select {
|
||||||
if err != nil {
|
case <-p.ctx.Done():
|
||||||
log.Errorf("failed to read UDP pkg from WG: %s", err)
|
|
||||||
return
|
return
|
||||||
}
|
default:
|
||||||
|
n, addr, err := p.conn.ReadFromUDP(buf)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to read UDP pkg from WG: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
p.turnConnMutex.Lock()
|
p.turnConnMutex.Lock()
|
||||||
conn, ok := p.turnConnStore[uint16(addr.Port)]
|
conn, ok := p.turnConnStore[uint16(addr.Port)]
|
||||||
p.turnConnMutex.Unlock()
|
p.turnConnMutex.Unlock()
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Infof("turn conn not found by port: %d", addr.Port)
|
log.Infof("turn conn not found by port: %d", addr.Port)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = conn.Write(buf[:n])
|
_, err = conn.Write(buf[:n])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("failed to forward local wg pkg (%d) to remote turn conn: %s", addr.Port, err)
|
log.Debugf("failed to forward local wg pkg (%d) to remote turn conn: %s", addr.Port, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -266,15 +287,17 @@ func (p *WGEBPFProxy) sendPkg(data []byte, port uint16) error {
|
|||||||
|
|
||||||
err := udpH.SetNetworkLayerForChecksum(ipH)
|
err := udpH.SetNetworkLayerForChecksum(ipH)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("set network layer for checksum: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
layerBuffer := gopacket.NewSerializeBuffer()
|
layerBuffer := gopacket.NewSerializeBuffer()
|
||||||
|
|
||||||
err = gopacket.SerializeLayers(layerBuffer, gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}, ipH, udpH, payload)
|
err = gopacket.SerializeLayers(layerBuffer, gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}, ipH, udpH, payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("serialize layers: %w", err)
|
||||||
}
|
}
|
||||||
_, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localhost})
|
if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localhost}); err != nil {
|
||||||
return err
|
return fmt.Errorf("write to raw conn: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -3,11 +3,12 @@
|
|||||||
package wgproxy
|
package wgproxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestWGEBPFProxy_connStore(t *testing.T) {
|
func TestWGEBPFProxy_connStore(t *testing.T) {
|
||||||
wgProxy := NewWGEBPFProxy(1)
|
wgProxy := NewWGEBPFProxy(context.Background(), 1)
|
||||||
|
|
||||||
p, _ := wgProxy.storeTurnConn(nil)
|
p, _ := wgProxy.storeTurnConn(nil)
|
||||||
if p != 1 {
|
if p != 1 {
|
||||||
@ -27,7 +28,7 @@ func TestWGEBPFProxy_connStore(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) {
|
func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) {
|
||||||
wgProxy := NewWGEBPFProxy(1)
|
wgProxy := NewWGEBPFProxy(context.Background(), 1)
|
||||||
|
|
||||||
_, _ = wgProxy.storeTurnConn(nil)
|
_, _ = wgProxy.storeTurnConn(nil)
|
||||||
wgProxy.lastUsedPort = 65535
|
wgProxy.lastUsedPort = 65535
|
||||||
@ -43,7 +44,7 @@ func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestWGEBPFProxy_portCalculation_maxConn(t *testing.T) {
|
func TestWGEBPFProxy_portCalculation_maxConn(t *testing.T) {
|
||||||
wgProxy := NewWGEBPFProxy(1)
|
wgProxy := NewWGEBPFProxy(context.Background(), 1)
|
||||||
|
|
||||||
for i := 0; i < 65535; i++ {
|
for i := 0; i < 65535; i++ {
|
||||||
_, _ = wgProxy.storeTurnConn(nil)
|
_, _ = wgProxy.storeTurnConn(nil)
|
||||||
|
@ -21,21 +21,21 @@ type WGUserSpaceProxy struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy
|
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy
|
||||||
func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy {
|
func NewWGUserSpaceProxy(ctx context.Context, wgPort int) *WGUserSpaceProxy {
|
||||||
log.Debugf("instantiate new userspace proxy")
|
log.Debugf("Initializing new user space proxy with port %d", wgPort)
|
||||||
p := &WGUserSpaceProxy{
|
p := &WGUserSpaceProxy{
|
||||||
localWGListenPort: wgPort,
|
localWGListenPort: wgPort,
|
||||||
}
|
}
|
||||||
p.ctx, p.cancel = context.WithCancel(context.Background())
|
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddTurnConn start the proxy with the given remote conn
|
// AddTurnConn start the proxy with the given remote conn
|
||||||
func (p *WGUserSpaceProxy) AddTurnConn(remoteConn net.Conn) (net.Addr, error) {
|
func (p *WGUserSpaceProxy) AddTurnConn(turnConn net.Conn) (net.Addr, error) {
|
||||||
p.remoteConn = remoteConn
|
p.remoteConn = turnConn
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
p.localConn, err = nbnet.NewDialer().Dial("udp", fmt.Sprintf(":%d", p.localWGListenPort))
|
p.localConn, err = nbnet.NewDialer().DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed dialing to local Wireguard port %s", err)
|
log.Errorf("failed dialing to local Wireguard port %s", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -120,6 +120,7 @@ type LoginRequest struct {
|
|||||||
ServerSSHAllowed *bool `protobuf:"varint,15,opt,name=serverSSHAllowed,proto3,oneof" json:"serverSSHAllowed,omitempty"`
|
ServerSSHAllowed *bool `protobuf:"varint,15,opt,name=serverSSHAllowed,proto3,oneof" json:"serverSSHAllowed,omitempty"`
|
||||||
RosenpassPermissive *bool `protobuf:"varint,16,opt,name=rosenpassPermissive,proto3,oneof" json:"rosenpassPermissive,omitempty"`
|
RosenpassPermissive *bool `protobuf:"varint,16,opt,name=rosenpassPermissive,proto3,oneof" json:"rosenpassPermissive,omitempty"`
|
||||||
ExtraIFaceBlacklist []string `protobuf:"bytes,17,rep,name=extraIFaceBlacklist,proto3" json:"extraIFaceBlacklist,omitempty"`
|
ExtraIFaceBlacklist []string `protobuf:"bytes,17,rep,name=extraIFaceBlacklist,proto3" json:"extraIFaceBlacklist,omitempty"`
|
||||||
|
NetworkMonitor *bool `protobuf:"varint,18,opt,name=networkMonitor,proto3,oneof" json:"networkMonitor,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (x *LoginRequest) Reset() {
|
func (x *LoginRequest) Reset() {
|
||||||
@ -274,6 +275,13 @@ func (x *LoginRequest) GetExtraIFaceBlacklist() []string {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (x *LoginRequest) GetNetworkMonitor() bool {
|
||||||
|
if x != nil && x.NetworkMonitor != nil {
|
||||||
|
return *x.NetworkMonitor
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
type LoginResponse struct {
|
type LoginResponse struct {
|
||||||
state protoimpl.MessageState
|
state protoimpl.MessageState
|
||||||
sizeCache protoimpl.SizeCache
|
sizeCache protoimpl.SizeCache
|
||||||
@ -1893,7 +1901,7 @@ var file_daemon_proto_rawDesc = []byte{
|
|||||||
0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74,
|
0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74,
|
||||||
0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1e, 0x67, 0x6f, 0x6f, 0x67, 0x6c,
|
0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1e, 0x67, 0x6f, 0x6f, 0x67, 0x6c,
|
||||||
0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x64, 0x75, 0x72, 0x61, 0x74,
|
0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x64, 0x75, 0x72, 0x61, 0x74,
|
||||||
0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x8f, 0x07, 0x0a, 0x0c, 0x4c, 0x6f,
|
0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xcf, 0x07, 0x0a, 0x0c, 0x4c, 0x6f,
|
||||||
0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x73, 0x65,
|
0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x73, 0x65,
|
||||||
0x74, 0x75, 0x70, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x73, 0x65,
|
0x74, 0x75, 0x70, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x73, 0x65,
|
||||||
0x74, 0x75, 0x70, 0x4b, 0x65, 0x79, 0x12, 0x26, 0x0a, 0x0c, 0x70, 0x72, 0x65, 0x53, 0x68, 0x61,
|
0x74, 0x75, 0x70, 0x4b, 0x65, 0x79, 0x12, 0x26, 0x0a, 0x0c, 0x70, 0x72, 0x65, 0x53, 0x68, 0x61,
|
||||||
@ -1941,16 +1949,20 @@ var file_daemon_proto_rawDesc = []byte{
|
|||||||
0x88, 0x01, 0x01, 0x12, 0x30, 0x0a, 0x13, 0x65, 0x78, 0x74, 0x72, 0x61, 0x49, 0x46, 0x61, 0x63,
|
0x88, 0x01, 0x01, 0x12, 0x30, 0x0a, 0x13, 0x65, 0x78, 0x74, 0x72, 0x61, 0x49, 0x46, 0x61, 0x63,
|
||||||
0x65, 0x42, 0x6c, 0x61, 0x63, 0x6b, 0x6c, 0x69, 0x73, 0x74, 0x18, 0x11, 0x20, 0x03, 0x28, 0x09,
|
0x65, 0x42, 0x6c, 0x61, 0x63, 0x6b, 0x6c, 0x69, 0x73, 0x74, 0x18, 0x11, 0x20, 0x03, 0x28, 0x09,
|
||||||
0x52, 0x13, 0x65, 0x78, 0x74, 0x72, 0x61, 0x49, 0x46, 0x61, 0x63, 0x65, 0x42, 0x6c, 0x61, 0x63,
|
0x52, 0x13, 0x65, 0x78, 0x74, 0x72, 0x61, 0x49, 0x46, 0x61, 0x63, 0x65, 0x42, 0x6c, 0x61, 0x63,
|
||||||
0x6b, 0x6c, 0x69, 0x73, 0x74, 0x42, 0x13, 0x0a, 0x11, 0x5f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70,
|
0x6b, 0x6c, 0x69, 0x73, 0x74, 0x12, 0x2b, 0x0a, 0x0e, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b,
|
||||||
0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x69,
|
0x4d, 0x6f, 0x6e, 0x69, 0x74, 0x6f, 0x72, 0x18, 0x12, 0x20, 0x01, 0x28, 0x08, 0x48, 0x07, 0x52,
|
||||||
0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x42, 0x10, 0x0a, 0x0e,
|
0x0e, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x6f, 0x6e, 0x69, 0x74, 0x6f, 0x72, 0x88,
|
||||||
0x5f, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x42, 0x17,
|
0x01, 0x01, 0x42, 0x13, 0x0a, 0x11, 0x5f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73,
|
||||||
0x0a, 0x15, 0x5f, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x50, 0x72, 0x65, 0x53, 0x68,
|
0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x69, 0x6e, 0x74, 0x65,
|
||||||
0x61, 0x72, 0x65, 0x64, 0x4b, 0x65, 0x79, 0x42, 0x15, 0x0a, 0x13, 0x5f, 0x64, 0x69, 0x73, 0x61,
|
0x72, 0x66, 0x61, 0x63, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x77, 0x69,
|
||||||
0x62, 0x6c, 0x65, 0x41, 0x75, 0x74, 0x6f, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x42, 0x13,
|
0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x42, 0x17, 0x0a, 0x15, 0x5f,
|
||||||
0x0a, 0x11, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x53, 0x53, 0x48, 0x41, 0x6c, 0x6c, 0x6f,
|
0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x50, 0x72, 0x65, 0x53, 0x68, 0x61, 0x72, 0x65,
|
||||||
0x77, 0x65, 0x64, 0x42, 0x16, 0x0a, 0x14, 0x5f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73,
|
0x64, 0x4b, 0x65, 0x79, 0x42, 0x15, 0x0a, 0x13, 0x5f, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65,
|
||||||
0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x22, 0xb5, 0x01, 0x0a, 0x0d,
|
0x41, 0x75, 0x74, 0x6f, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x42, 0x13, 0x0a, 0x11, 0x5f,
|
||||||
|
0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x53, 0x53, 0x48, 0x41, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64,
|
||||||
|
0x42, 0x16, 0x0a, 0x14, 0x5f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65,
|
||||||
|
0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x42, 0x11, 0x0a, 0x0f, 0x5f, 0x6e, 0x65, 0x74,
|
||||||
|
0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x6f, 0x6e, 0x69, 0x74, 0x6f, 0x72, 0x22, 0xb5, 0x01, 0x0a, 0x0d,
|
||||||
0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x24, 0x0a,
|
0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x24, 0x0a,
|
||||||
0x0d, 0x6e, 0x65, 0x65, 0x64, 0x73, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x18, 0x01,
|
0x0d, 0x6e, 0x65, 0x65, 0x64, 0x73, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x18, 0x01,
|
||||||
0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x6e, 0x65, 0x65, 0x64, 0x73, 0x53, 0x53, 0x4f, 0x4c, 0x6f,
|
0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x6e, 0x65, 0x65, 0x64, 0x73, 0x53, 0x53, 0x4f, 0x4c, 0x6f,
|
||||||
|
@ -87,6 +87,8 @@ message LoginRequest {
|
|||||||
optional bool rosenpassPermissive = 16;
|
optional bool rosenpassPermissive = 16;
|
||||||
|
|
||||||
repeated string extraIFaceBlacklist = 17;
|
repeated string extraIFaceBlacklist = 17;
|
||||||
|
|
||||||
|
optional bool networkMonitor = 18;
|
||||||
}
|
}
|
||||||
|
|
||||||
message LoginResponse {
|
message LoginResponse {
|
||||||
|
@ -358,6 +358,11 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
|||||||
s.latestConfigInput.WireguardPort = &port
|
s.latestConfigInput.WireguardPort = &port
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if msg.NetworkMonitor != nil {
|
||||||
|
inputConfig.NetworkMonitor = msg.NetworkMonitor
|
||||||
|
s.latestConfigInput.NetworkMonitor = msg.NetworkMonitor
|
||||||
|
}
|
||||||
|
|
||||||
if len(msg.ExtraIFaceBlacklist) > 0 {
|
if len(msg.ExtraIFaceBlacklist) > 0 {
|
||||||
inputConfig.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist
|
inputConfig.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist
|
||||||
s.latestConfigInput.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist
|
s.latestConfigInput.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist
|
||||||
|
@ -1,16 +1,18 @@
|
|||||||
package client
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
"github.com/netbirdio/netbird/management/proto"
|
"github.com/netbirdio/netbird/management/proto"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Client interface {
|
type Client interface {
|
||||||
io.Closer
|
io.Closer
|
||||||
Sync(msgHandler func(msg *proto.SyncResponse) error) error
|
Sync(ctx context.Context, msgHandler func(msg *proto.SyncResponse) error) error
|
||||||
GetServerPublicKey() (*wgtypes.Key, error)
|
GetServerPublicKey() (*wgtypes.Key, error)
|
||||||
Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte) (*proto.LoginResponse, error)
|
Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte) (*proto.LoginResponse, error)
|
||||||
Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte) (*proto.LoginResponse, error)
|
Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte) (*proto.LoginResponse, error)
|
||||||
|
@ -17,6 +17,7 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/encryption"
|
"github.com/netbirdio/netbird/encryption"
|
||||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||||
mgmt "github.com/netbirdio/netbird/management/server"
|
mgmt "github.com/netbirdio/netbird/management/server"
|
||||||
@ -255,7 +256,7 @@ func TestClient_Sync(t *testing.T) {
|
|||||||
ch := make(chan *mgmtProto.SyncResponse, 1)
|
ch := make(chan *mgmtProto.SyncResponse, 1)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
err = client.Sync(func(msg *mgmtProto.SyncResponse) error {
|
err = client.Sync(context.Background(), func(msg *mgmtProto.SyncResponse) error {
|
||||||
ch <- msg
|
ch <- msg
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
@ -113,8 +113,8 @@ func (c *GrpcClient) ready() bool {
|
|||||||
|
|
||||||
// Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages
|
// Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages
|
||||||
// Blocking request. The result will be sent via msgHandler callback function
|
// Blocking request. The result will be sent via msgHandler callback function
|
||||||
func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error {
|
func (c *GrpcClient) Sync(ctx context.Context, msgHandler func(msg *proto.SyncResponse) error) error {
|
||||||
backOff := defaultBackoff(c.ctx)
|
backOff := defaultBackoff(ctx)
|
||||||
|
|
||||||
operation := func() error {
|
operation := func() error {
|
||||||
log.Debugf("management connection state %v", c.conn.GetState())
|
log.Debugf("management connection state %v", c.conn.GetState())
|
||||||
@ -123,7 +123,7 @@ func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error
|
|||||||
if connState == connectivity.Shutdown {
|
if connState == connectivity.Shutdown {
|
||||||
return backoff.Permanent(fmt.Errorf("connection to management has been shut down"))
|
return backoff.Permanent(fmt.Errorf("connection to management has been shut down"))
|
||||||
} else if !(connState == connectivity.Ready || connState == connectivity.Idle) {
|
} else if !(connState == connectivity.Ready || connState == connectivity.Idle) {
|
||||||
c.conn.WaitForStateChange(c.ctx, connState)
|
c.conn.WaitForStateChange(ctx, connState)
|
||||||
return fmt.Errorf("connection to management is not ready and in %s state", connState)
|
return fmt.Errorf("connection to management is not ready and in %s state", connState)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -133,7 +133,7 @@ func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancelStream := context.WithCancel(c.ctx)
|
ctx, cancelStream := context.WithCancel(ctx)
|
||||||
defer cancelStream()
|
defer cancelStream()
|
||||||
stream, err := c.connectToStream(ctx, *serverPubKey)
|
stream, err := c.connectToStream(ctx, *serverPubKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
package client
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
@ -9,7 +11,7 @@ import (
|
|||||||
|
|
||||||
type MockClient struct {
|
type MockClient struct {
|
||||||
CloseFunc func() error
|
CloseFunc func() error
|
||||||
SyncFunc func(msgHandler func(msg *proto.SyncResponse) error) error
|
SyncFunc func(ctx context.Context, msgHandler func(msg *proto.SyncResponse) error) error
|
||||||
GetServerPublicKeyFunc func() (*wgtypes.Key, error)
|
GetServerPublicKeyFunc func() (*wgtypes.Key, error)
|
||||||
RegisterFunc func(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte) (*proto.LoginResponse, error)
|
RegisterFunc func(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte) (*proto.LoginResponse, error)
|
||||||
LoginFunc func(serverKey wgtypes.Key, info *system.Info, sshKey []byte) (*proto.LoginResponse, error)
|
LoginFunc func(serverKey wgtypes.Key, info *system.Info, sshKey []byte) (*proto.LoginResponse, error)
|
||||||
@ -28,11 +30,11 @@ func (m *MockClient) Close() error {
|
|||||||
return m.CloseFunc()
|
return m.CloseFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error {
|
func (m *MockClient) Sync(ctx context.Context, msgHandler func(msg *proto.SyncResponse) error) error {
|
||||||
if m.SyncFunc == nil {
|
if m.SyncFunc == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return m.SyncFunc(msgHandler)
|
return m.SyncFunc(ctx, msgHandler)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockClient) GetServerPublicKey() (*wgtypes.Key, error) {
|
func (m *MockClient) GetServerPublicKey() (*wgtypes.Key, error) {
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package client
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"strings"
|
"strings"
|
||||||
@ -33,7 +34,7 @@ type Client interface {
|
|||||||
io.Closer
|
io.Closer
|
||||||
StreamConnected() bool
|
StreamConnected() bool
|
||||||
GetStatus() Status
|
GetStatus() Status
|
||||||
Receive(msgHandler func(msg *proto.Message) error) error
|
Receive(ctx context.Context, msgHandler func(msg *proto.Message) error) error
|
||||||
Ready() bool
|
Ready() bool
|
||||||
IsHealthy() bool
|
IsHealthy() bool
|
||||||
WaitStreamConnected()
|
WaitStreamConnected()
|
||||||
|
@ -55,7 +55,7 @@ var _ = Describe("GrpcClient", func() {
|
|||||||
keyA, _ := wgtypes.GenerateKey()
|
keyA, _ := wgtypes.GenerateKey()
|
||||||
clientA := createSignalClient(addr, keyA)
|
clientA := createSignalClient(addr, keyA)
|
||||||
go func() {
|
go func() {
|
||||||
err := clientA.Receive(func(msg *sigProto.Message) error {
|
err := clientA.Receive(context.Background(), func(msg *sigProto.Message) error {
|
||||||
payloadReceivedOnA = msg.GetBody().GetPayload()
|
payloadReceivedOnA = msg.GetBody().GetPayload()
|
||||||
featuresSupportedReceivedOnA = msg.GetBody().GetFeaturesSupported()
|
featuresSupportedReceivedOnA = msg.GetBody().GetFeaturesSupported()
|
||||||
msgReceived.Done()
|
msgReceived.Done()
|
||||||
@ -72,7 +72,7 @@ var _ = Describe("GrpcClient", func() {
|
|||||||
clientB := createSignalClient(addr, keyB)
|
clientB := createSignalClient(addr, keyB)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
err := clientB.Receive(func(msg *sigProto.Message) error {
|
err := clientB.Receive(context.Background(), func(msg *sigProto.Message) error {
|
||||||
payloadReceivedOnB = msg.GetBody().GetPayload()
|
payloadReceivedOnB = msg.GetBody().GetPayload()
|
||||||
featuresSupportedReceivedOnB = msg.GetBody().GetFeaturesSupported()
|
featuresSupportedReceivedOnB = msg.GetBody().GetFeaturesSupported()
|
||||||
err := clientB.Send(&sigProto.Message{
|
err := clientB.Send(&sigProto.Message{
|
||||||
@ -122,7 +122,7 @@ var _ = Describe("GrpcClient", func() {
|
|||||||
key, _ := wgtypes.GenerateKey()
|
key, _ := wgtypes.GenerateKey()
|
||||||
client := createSignalClient(addr, key)
|
client := createSignalClient(addr, key)
|
||||||
go func() {
|
go func() {
|
||||||
err := client.Receive(func(msg *sigProto.Message) error {
|
err := client.Receive(context.Background(), func(msg *sigProto.Message) error {
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -126,9 +126,9 @@ func defaultBackoff(ctx context.Context) backoff.BackOff {
|
|||||||
// The messages will be handled by msgHandler function provided.
|
// The messages will be handled by msgHandler function provided.
|
||||||
// This function is blocking and reconnects to the Signal Exchange if errors occur (e.g. Exchange restart)
|
// This function is blocking and reconnects to the Signal Exchange if errors occur (e.g. Exchange restart)
|
||||||
// The connection retry logic will try to reconnect for 30 min and if wasn't successful will propagate the error to the function caller.
|
// The connection retry logic will try to reconnect for 30 min and if wasn't successful will propagate the error to the function caller.
|
||||||
func (c *GrpcClient) Receive(msgHandler func(msg *proto.Message) error) error {
|
func (c *GrpcClient) Receive(ctx context.Context, msgHandler func(msg *proto.Message) error) error {
|
||||||
|
|
||||||
var backOff = defaultBackoff(c.ctx)
|
var backOff = defaultBackoff(ctx)
|
||||||
|
|
||||||
operation := func() error {
|
operation := func() error {
|
||||||
|
|
||||||
@ -139,13 +139,13 @@ func (c *GrpcClient) Receive(msgHandler func(msg *proto.Message) error) error {
|
|||||||
if connState == connectivity.Shutdown {
|
if connState == connectivity.Shutdown {
|
||||||
return backoff.Permanent(fmt.Errorf("connection to signal has been shut down"))
|
return backoff.Permanent(fmt.Errorf("connection to signal has been shut down"))
|
||||||
} else if !(connState == connectivity.Ready || connState == connectivity.Idle) {
|
} else if !(connState == connectivity.Ready || connState == connectivity.Idle) {
|
||||||
c.signalConn.WaitForStateChange(c.ctx, connState)
|
c.signalConn.WaitForStateChange(ctx, connState)
|
||||||
return fmt.Errorf("connection to signal is not ready and in %s state", connState)
|
return fmt.Errorf("connection to signal is not ready and in %s state", connState)
|
||||||
}
|
}
|
||||||
|
|
||||||
// connect to Signal stream identifying ourselves with a public WireGuard key
|
// connect to Signal stream identifying ourselves with a public WireGuard key
|
||||||
// todo once the key rotation logic has been implemented, consider changing to some other identifier (received from management)
|
// todo once the key rotation logic has been implemented, consider changing to some other identifier (received from management)
|
||||||
ctx, cancelStream := context.WithCancel(c.ctx)
|
ctx, cancelStream := context.WithCancel(ctx)
|
||||||
defer cancelStream()
|
defer cancelStream()
|
||||||
stream, err := c.connect(ctx, c.key.PublicKey().String())
|
stream, err := c.connect(ctx, c.key.PublicKey().String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
package client
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/signal/proto"
|
"github.com/netbirdio/netbird/signal/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -10,7 +12,7 @@ type MockClient struct {
|
|||||||
StreamConnectedFunc func() bool
|
StreamConnectedFunc func() bool
|
||||||
ReadyFunc func() bool
|
ReadyFunc func() bool
|
||||||
WaitStreamConnectedFunc func()
|
WaitStreamConnectedFunc func()
|
||||||
ReceiveFunc func(msgHandler func(msg *proto.Message) error) error
|
ReceiveFunc func(ctx context.Context, msgHandler func(msg *proto.Message) error) error
|
||||||
SendToStreamFunc func(msg *proto.EncryptedMessage) error
|
SendToStreamFunc func(msg *proto.EncryptedMessage) error
|
||||||
SendFunc func(msg *proto.Message) error
|
SendFunc func(msg *proto.Message) error
|
||||||
}
|
}
|
||||||
@ -54,11 +56,11 @@ func (sm *MockClient) WaitStreamConnected() {
|
|||||||
sm.WaitStreamConnectedFunc()
|
sm.WaitStreamConnectedFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sm *MockClient) Receive(msgHandler func(msg *proto.Message) error) error {
|
func (sm *MockClient) Receive(ctx context.Context, msgHandler func(msg *proto.Message) error) error {
|
||||||
if sm.ReceiveFunc == nil {
|
if sm.ReceiveFunc == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return sm.ReceiveFunc(msgHandler)
|
return sm.ReceiveFunc(ctx, msgHandler)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sm *MockClient) SendToStream(msg *proto.EncryptedMessage) error {
|
func (sm *MockClient) SendToStream(msg *proto.EncryptedMessage) error {
|
||||||
|
Loading…
Reference in New Issue
Block a user