Add route selection to iOS (#1944)

This commit is contained in:
pascal-fischer 2024-05-10 10:47:16 +02:00 committed by GitHub
parent 263abe4862
commit 272ade07a8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 397 additions and 110 deletions

View File

@ -101,7 +101,8 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
// todo do not throw error in case of cancelled context // todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx) ctx = internal.CtxInitState(ctx)
return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener) connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
} }
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot). // RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
@ -126,7 +127,8 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
// todo do not throw error in case of cancelled context // todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx) ctx = internal.CtxInitState(ctx)
return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener) connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
} }
// Stop the internal client and free the resources // Stop the internal client and free the resources

View File

@ -152,7 +152,9 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
var cancel context.CancelFunc var cancel context.CancelFunc
ctx, cancel = context.WithCancel(ctx) ctx, cancel = context.WithCancel(ctx)
SetupCloseHandler(ctx, cancel) SetupCloseHandler(ctx, cancel)
return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()))
connectClient := internal.NewConnectClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()))
return connectClient.Run()
} }
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error { func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {

View File

@ -7,6 +7,7 @@ import (
"runtime" "runtime"
"runtime/debug" "runtime/debug"
"strings" "strings"
"sync"
"time" "time"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
@ -29,30 +30,45 @@ import (
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )
// RunClient with main logic. type ConnectClient struct {
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) error { ctx context.Context
return runClient(ctx, config, statusRecorder, MobileDependency{}, nil, nil, nil, nil, nil) config *Config
statusRecorder *peer.Status
engine *Engine
engineMutex sync.Mutex
} }
// RunClientWithProbes runs the client's main logic with probes attached func NewConnectClient(
func RunClientWithProbes(
ctx context.Context, ctx context.Context,
config *Config, config *Config,
statusRecorder *peer.Status, statusRecorder *peer.Status,
) *ConnectClient {
return &ConnectClient{
ctx: ctx,
config: config,
statusRecorder: statusRecorder,
engineMutex: sync.Mutex{},
}
}
// Run with main logic.
func (c *ConnectClient) Run() error {
return c.run(MobileDependency{}, nil, nil, nil, nil)
}
// RunWithProbes runs the client's main logic with probes attached
func (c *ConnectClient) RunWithProbes(
mgmProbe *Probe, mgmProbe *Probe,
signalProbe *Probe, signalProbe *Probe,
relayProbe *Probe, relayProbe *Probe,
wgProbe *Probe, wgProbe *Probe,
engineChan chan<- *Engine,
) error { ) error {
return runClient(ctx, config, statusRecorder, MobileDependency{}, mgmProbe, signalProbe, relayProbe, wgProbe, engineChan) return c.run(MobileDependency{}, mgmProbe, signalProbe, relayProbe, wgProbe)
} }
// RunClientMobile with main logic on mobile system // RunOnAndroid with main logic on mobile system
func RunClientMobile( func (c *ConnectClient) RunOnAndroid(
ctx context.Context,
config *Config,
statusRecorder *peer.Status,
tunAdapter iface.TunAdapter, tunAdapter iface.TunAdapter,
iFaceDiscover stdnet.ExternalIFaceDiscover, iFaceDiscover stdnet.ExternalIFaceDiscover,
networkChangeListener listener.NetworkChangeListener, networkChangeListener listener.NetworkChangeListener,
@ -67,13 +83,10 @@ func RunClientMobile(
HostDNSAddresses: dnsAddresses, HostDNSAddresses: dnsAddresses,
DnsReadyListener: dnsReadyListener, DnsReadyListener: dnsReadyListener,
} }
return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil, nil) return c.run(mobileDependency, nil, nil, nil, nil)
} }
func RunClientiOS( func (c *ConnectClient) RunOniOS(
ctx context.Context,
config *Config,
statusRecorder *peer.Status,
fileDescriptor int32, fileDescriptor int32,
networkChangeListener listener.NetworkChangeListener, networkChangeListener listener.NetworkChangeListener,
dnsManager dns.IosDnsManager, dnsManager dns.IosDnsManager,
@ -83,19 +96,15 @@ func RunClientiOS(
NetworkChangeListener: networkChangeListener, NetworkChangeListener: networkChangeListener,
DnsManager: dnsManager, DnsManager: dnsManager,
} }
return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil, nil) return c.run(mobileDependency, nil, nil, nil, nil)
} }
func runClient( func (c *ConnectClient) run(
ctx context.Context,
config *Config,
statusRecorder *peer.Status,
mobileDependency MobileDependency, mobileDependency MobileDependency,
mgmProbe *Probe, mgmProbe *Probe,
signalProbe *Probe, signalProbe *Probe,
relayProbe *Probe, relayProbe *Probe,
wgProbe *Probe, wgProbe *Probe,
engineChan chan<- *Engine,
) error { ) error {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
@ -107,7 +116,7 @@ func runClient(
// Check if client was not shut down in a clean way and restore DNS config if required. // Check if client was not shut down in a clean way and restore DNS config if required.
// Otherwise, we might not be able to connect to the management server to retrieve new config. // Otherwise, we might not be able to connect to the management server to retrieve new config.
if err := dns.CheckUncleanShutdown(config.WgIface); err != nil { if err := dns.CheckUncleanShutdown(c.config.WgIface); err != nil {
log.Errorf("checking unclean shutdown error: %s", err) log.Errorf("checking unclean shutdown error: %s", err)
} }
@ -121,7 +130,7 @@ func runClient(
Clock: backoff.SystemClock, Clock: backoff.SystemClock,
} }
state := CtxGetState(ctx) state := CtxGetState(c.ctx)
defer func() { defer func() {
s, err := state.Status() s, err := state.Status()
if err != nil || s != StatusNeedsLogin { if err != nil || s != StatusNeedsLogin {
@ -130,49 +139,49 @@ func runClient(
}() }()
wrapErr := state.Wrap wrapErr := state.Wrap
myPrivateKey, err := wgtypes.ParseKey(config.PrivateKey) myPrivateKey, err := wgtypes.ParseKey(c.config.PrivateKey)
if err != nil { if err != nil {
log.Errorf("failed parsing Wireguard key %s: [%s]", config.PrivateKey, err.Error()) log.Errorf("failed parsing Wireguard key %s: [%s]", c.config.PrivateKey, err.Error())
return wrapErr(err) return wrapErr(err)
} }
var mgmTlsEnabled bool var mgmTlsEnabled bool
if config.ManagementURL.Scheme == "https" { if c.config.ManagementURL.Scheme == "https" {
mgmTlsEnabled = true mgmTlsEnabled = true
} }
publicSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey)) publicSSHKey, err := ssh.GeneratePublicKey([]byte(c.config.SSHKey))
if err != nil { if err != nil {
return err return err
} }
defer statusRecorder.ClientStop() defer c.statusRecorder.ClientStop()
operation := func() error { operation := func() error {
// if context cancelled we not start new backoff cycle // if context cancelled we not start new backoff cycle
select { select {
case <-ctx.Done(): case <-c.ctx.Done():
return nil return nil
default: default:
} }
state.Set(StatusConnecting) state.Set(StatusConnecting)
engineCtx, cancel := context.WithCancel(ctx) engineCtx, cancel := context.WithCancel(c.ctx)
defer func() { defer func() {
statusRecorder.MarkManagementDisconnected(state.err) c.statusRecorder.MarkManagementDisconnected(state.err)
statusRecorder.CleanLocalPeerState() c.statusRecorder.CleanLocalPeerState()
cancel() cancel()
}() }()
log.Debugf("connecting to the Management service %s", config.ManagementURL.Host) log.Debugf("connecting to the Management service %s", c.config.ManagementURL.Host)
mgmClient, err := mgm.NewClient(engineCtx, config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled) mgmClient, err := mgm.NewClient(engineCtx, c.config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
if err != nil { if err != nil {
return wrapErr(gstatus.Errorf(codes.FailedPrecondition, "failed connecting to Management Service : %s", err)) return wrapErr(gstatus.Errorf(codes.FailedPrecondition, "failed connecting to Management Service : %s", err))
} }
mgmNotifier := statusRecorderToMgmConnStateNotifier(statusRecorder) mgmNotifier := statusRecorderToMgmConnStateNotifier(c.statusRecorder)
mgmClient.SetConnStateListener(mgmNotifier) mgmClient.SetConnStateListener(mgmNotifier)
log.Debugf("connected to the Management service %s", config.ManagementURL.Host) log.Debugf("connected to the Management service %s", c.config.ManagementURL.Host)
defer func() { defer func() {
err = mgmClient.Close() err = mgmClient.Close()
if err != nil { if err != nil {
@ -190,7 +199,7 @@ func runClient(
} }
return wrapErr(err) return wrapErr(err)
} }
statusRecorder.MarkManagementConnected() c.statusRecorder.MarkManagementConnected()
localPeerState := peer.LocalPeerState{ localPeerState := peer.LocalPeerState{
IP: loginResp.GetPeerConfig().GetAddress(), IP: loginResp.GetPeerConfig().GetAddress(),
@ -199,18 +208,18 @@ func runClient(
FQDN: loginResp.GetPeerConfig().GetFqdn(), FQDN: loginResp.GetPeerConfig().GetFqdn(),
} }
statusRecorder.UpdateLocalPeerState(localPeerState) c.statusRecorder.UpdateLocalPeerState(localPeerState)
signalURL := fmt.Sprintf("%s://%s", signalURL := fmt.Sprintf("%s://%s",
strings.ToLower(loginResp.GetWiretrusteeConfig().GetSignal().GetProtocol().String()), strings.ToLower(loginResp.GetWiretrusteeConfig().GetSignal().GetProtocol().String()),
loginResp.GetWiretrusteeConfig().GetSignal().GetUri(), loginResp.GetWiretrusteeConfig().GetSignal().GetUri(),
) )
statusRecorder.UpdateSignalAddress(signalURL) c.statusRecorder.UpdateSignalAddress(signalURL)
statusRecorder.MarkSignalDisconnected(nil) c.statusRecorder.MarkSignalDisconnected(nil)
defer func() { defer func() {
statusRecorder.MarkSignalDisconnected(state.err) c.statusRecorder.MarkSignalDisconnected(state.err)
}() }()
// with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal // with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal
@ -226,42 +235,38 @@ func runClient(
} }
}() }()
signalNotifier := statusRecorderToSignalConnStateNotifier(statusRecorder) signalNotifier := statusRecorderToSignalConnStateNotifier(c.statusRecorder)
signalClient.SetConnStateListener(signalNotifier) signalClient.SetConnStateListener(signalNotifier)
statusRecorder.MarkSignalConnected() c.statusRecorder.MarkSignalConnected()
peerConfig := loginResp.GetPeerConfig() peerConfig := loginResp.GetPeerConfig()
engineConfig, err := createEngineConfig(myPrivateKey, config, peerConfig) engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return wrapErr(err) return wrapErr(err)
} }
engine := NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe) c.engineMutex.Lock()
err = engine.Start() c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, c.statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe)
c.engineMutex.Unlock()
err = c.engine.Start()
if err != nil { if err != nil {
log.Errorf("error while starting Netbird Connection Engine: %s", err) log.Errorf("error while starting Netbird Connection Engine: %s", err)
return wrapErr(err) return wrapErr(err)
} }
if engineChan != nil {
engineChan <- engine
}
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress()) log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
state.Set(StatusConnected) state.Set(StatusConnected)
<-engineCtx.Done() <-engineCtx.Done()
statusRecorder.ClientTeardown() c.statusRecorder.ClientTeardown()
backOff.Reset() backOff.Reset()
if engineChan != nil { err = c.engine.Stop()
engineChan <- nil
}
err = engine.Stop()
if err != nil { if err != nil {
log.Errorf("failed stopping engine %v", err) log.Errorf("failed stopping engine %v", err)
return wrapErr(err) return wrapErr(err)
@ -276,7 +281,7 @@ func runClient(
return nil return nil
} }
statusRecorder.ClientStart() c.statusRecorder.ClientStart()
err = backoff.Retry(operation, backOff) err = backoff.Retry(operation, backOff)
if err != nil { if err != nil {
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err) log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)
@ -288,6 +293,14 @@ func runClient(
return nil return nil
} }
func (c *ConnectClient) Engine() *Engine {
var e *Engine
c.engineMutex.Lock()
e = c.engine
c.engineMutex.Unlock()
return e
}
// createEngineConfig converts configuration received from Management Service to EngineConfig // createEngineConfig converts configuration received from Management Service to EngineConfig
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) { func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
engineConf := &EngineConfig{ engineConf := &EngineConfig{

View File

@ -174,6 +174,9 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
defer m.mux.Unlock() defer m.mux.Unlock()
networks = m.routeSelector.FilterSelected(networks) networks = m.routeSelector.FilterSelected(networks)
m.notifier.onNewRoutes(networks)
m.stopObsoleteClients(networks) m.stopObsoleteClients(networks)
for id, routes := range networks { for id, routes := range networks {

View File

@ -1,6 +1,7 @@
package routemanager package routemanager
import ( import (
"runtime"
"sort" "sort"
"strings" "strings"
"sync" "sync"
@ -45,8 +46,15 @@ func (n *notifier) onNewRoutes(idMap route.HAMap) {
} }
sort.Strings(newNets) sort.Strings(newNets)
if !n.hasDiff(n.initialRouteRanges, newNets) { switch runtime.GOOS {
return case "android":
if !n.hasDiff(n.initialRouteRanges, newNets) {
return
}
default:
if !n.hasDiff(n.routeRanges, newNets) {
return
}
} }
n.routeRanges = newNets n.routeRanges = newNets

View File

@ -1,6 +1,7 @@
package stdnet package stdnet
import ( import (
"runtime"
"strings" "strings"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -19,7 +20,7 @@ func InterfaceFilter(disallowList []string) func(string) bool {
} }
for _, s := range disallowList { for _, s := range disallowList {
if strings.HasPrefix(iFace, s) { if strings.HasPrefix(iFace, s) && runtime.GOOS != "ios" {
log.Tracef("ignoring interface %s - it is not allowed", iFace) log.Tracef("ignoring interface %s - it is not allowed", iFace)
return false return false
} }

View File

@ -2,10 +2,15 @@ package NetBirdSDK
import ( import (
"context" "context"
"fmt"
"net/netip"
"sort"
"strings"
"sync" "sync"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth" "github.com/netbirdio/netbird/client/internal/auth"
@ -14,6 +19,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter" "github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/route"
) )
// ConnectionListener export internal Listener for mobile // ConnectionListener export internal Listener for mobile
@ -38,6 +44,12 @@ type CustomLogger interface {
Error(message string) Error(message string)
} }
type selectRoute struct {
NetID string
Network netip.Prefix
Selected bool
}
func init() { func init() {
formatter.SetLogcatFormatter(log.StandardLogger()) formatter.SetLogcatFormatter(log.StandardLogger())
} }
@ -55,6 +67,7 @@ type Client struct {
onHostDnsFn func([]string) onHostDnsFn func([]string)
dnsManager dns.IosDnsManager dnsManager dns.IosDnsManager
loginComplete bool loginComplete bool
connectClient *internal.ConnectClient
} }
// NewClient instantiate a new Client // NewClient instantiate a new Client
@ -107,7 +120,9 @@ func (c *Client) Run(fd int32, interfaceName string) error {
ctx = internal.CtxInitState(ctx) ctx = internal.CtxInitState(ctx)
c.onHostDnsFn = func([]string) {} c.onHostDnsFn = func([]string) {}
cfg.WgIface = interfaceName cfg.WgIface = interfaceName
return internal.RunClientiOS(ctx, cfg, c.recorder, fd, c.networkChangeListener, c.dnsManager)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager)
} }
// Stop the internal client and free the resources // Stop the internal client and free the resources
@ -133,10 +148,29 @@ func (c *Client) GetStatusDetails() *StatusDetails {
peerInfos := make([]PeerInfo, len(fullStatus.Peers)) peerInfos := make([]PeerInfo, len(fullStatus.Peers))
for n, p := range fullStatus.Peers { for n, p := range fullStatus.Peers {
var routes = RoutesDetails{}
for r := range p.GetRoutes() {
routeInfo := RoutesInfo{r}
routes.items = append(routes.items, routeInfo)
}
pi := PeerInfo{ pi := PeerInfo{
p.IP, IP: p.IP,
p.FQDN, FQDN: p.FQDN,
p.ConnStatus.String(), LocalIceCandidateEndpoint: p.LocalIceCandidateEndpoint,
RemoteIceCandidateEndpoint: p.RemoteIceCandidateEndpoint,
LocalIceCandidateType: p.LocalIceCandidateType,
RemoteIceCandidateType: p.RemoteIceCandidateType,
PubKey: p.PubKey,
Latency: formatDuration(p.Latency),
BytesRx: p.BytesRx,
BytesTx: p.BytesTx,
ConnStatus: p.ConnStatus.String(),
ConnStatusUpdate: p.ConnStatusUpdate.Format("2006-01-02 15:04:05"),
Direct: p.Direct,
LastWireguardHandshake: p.LastWireguardHandshake.String(),
Relayed: p.Relayed,
RosenpassEnabled: p.RosenpassEnabled,
Routes: routes,
} }
peerInfos[n] = pi peerInfos[n] = pi
} }
@ -223,3 +257,142 @@ func (c *Client) IsLoginComplete() bool {
func (c *Client) ClearLoginComplete() { func (c *Client) ClearLoginComplete() {
c.loginComplete = false c.loginComplete = false
} }
func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
if c.connectClient == nil {
return nil, fmt.Errorf("not connected")
}
engine := c.connectClient.Engine()
if engine == nil {
return nil, fmt.Errorf("not connected")
}
routesMap := engine.GetClientRoutesWithNetID()
routeSelector := engine.GetRouteManager().GetRouteSelector()
var routes []*selectRoute
for id, rt := range routesMap {
if len(rt) == 0 {
continue
}
route := &selectRoute{
NetID: string(id),
Network: rt[0].Network,
Selected: routeSelector.IsSelected(id),
}
routes = append(routes, route)
}
sort.Slice(routes, func(i, j int) bool {
iPrefix := routes[i].Network.Bits()
jPrefix := routes[j].Network.Bits()
if iPrefix == jPrefix {
iAddr := routes[i].Network.Addr()
jAddr := routes[j].Network.Addr()
if iAddr == jAddr {
return routes[i].NetID < routes[j].NetID
}
return iAddr.String() < jAddr.String()
}
return iPrefix < jPrefix
})
var routeSelection []RoutesSelectionInfo
for _, r := range routes {
routeSelection = append(routeSelection, RoutesSelectionInfo{
ID: r.NetID,
Network: r.Network.String(),
Selected: r.Selected,
})
}
routeSelectionDetails := RoutesSelectionDetails{items: routeSelection}
return &routeSelectionDetails, nil
}
func (c *Client) SelectRoute(id string) error {
if c.connectClient == nil {
return fmt.Errorf("not connected")
}
engine := c.connectClient.Engine()
if engine == nil {
return fmt.Errorf("not connected")
}
routeManager := engine.GetRouteManager()
routeSelector := routeManager.GetRouteSelector()
if id == "All" {
log.Debugf("select all routes")
routeSelector.SelectAllRoutes()
} else {
log.Debugf("select route with id: %s", id)
routes := toNetIDs([]string{id})
if err := routeSelector.SelectRoutes(routes, true, maps.Keys(engine.GetClientRoutesWithNetID())); err != nil {
log.Debugf("error when selecting routes: %s", err)
return fmt.Errorf("select routes: %w", err)
}
}
routeManager.TriggerSelection(engine.GetClientRoutes())
return nil
}
func (c *Client) DeselectRoute(id string) error {
if c.connectClient == nil {
return fmt.Errorf("not connected")
}
engine := c.connectClient.Engine()
if engine == nil {
return fmt.Errorf("not connected")
}
routeManager := engine.GetRouteManager()
routeSelector := routeManager.GetRouteSelector()
if id == "All" {
log.Debugf("deselect all routes")
routeSelector.DeselectAllRoutes()
} else {
log.Debugf("deselect route with id: %s", id)
routes := toNetIDs([]string{id})
if err := routeSelector.DeselectRoutes(routes, maps.Keys(engine.GetClientRoutesWithNetID())); err != nil {
log.Debugf("error when deselecting routes: %s", err)
return fmt.Errorf("deselect routes: %w", err)
}
}
routeManager.TriggerSelection(engine.GetClientRoutes())
return nil
}
func formatDuration(d time.Duration) string {
ds := d.String()
dotIndex := strings.Index(ds, ".")
if dotIndex != -1 {
// Determine end of numeric part, ensuring we stop at two decimal places or the actual end if fewer
endIndex := dotIndex + 3
if endIndex > len(ds) {
endIndex = len(ds)
}
// Find where the numeric part ends by finding the first non-digit character after the dot
unitStart := endIndex
for unitStart < len(ds) && (ds[unitStart] >= '0' && ds[unitStart] <= '9') {
unitStart++
}
// Ensures that we only take the unit characters after the numerical part
if unitStart < len(ds) {
return ds[:endIndex] + ds[unitStart:]
}
return ds[:endIndex] // In case no units are found after the digits
}
return ds
}
func toNetIDs(routes []string) []route.NetID {
var netIDs []route.NetID
for _, rt := range routes {
netIDs = append(netIDs, route.NetID(rt))
}
return netIDs
}

View File

@ -2,9 +2,28 @@ package NetBirdSDK
// PeerInfo describe information about the peers. It designed for the UI usage // PeerInfo describe information about the peers. It designed for the UI usage
type PeerInfo struct { type PeerInfo struct {
IP string IP string
FQDN string FQDN string
ConnStatus string // Todo replace to enum LocalIceCandidateEndpoint string
RemoteIceCandidateEndpoint string
LocalIceCandidateType string
RemoteIceCandidateType string
PubKey string
Latency string
BytesRx int64
BytesTx int64
ConnStatus string
ConnStatusUpdate string
Direct bool
LastWireguardHandshake string
Relayed bool
RosenpassEnabled bool
Routes RoutesDetails
}
// GetRoutes return with RouteDetails
func (p PeerInfo) GetRouteDetails() *RoutesDetails {
return &p.Routes
} }
// PeerInfoCollection made for Java layer to get non default types as collection // PeerInfoCollection made for Java layer to get non default types as collection
@ -16,6 +35,21 @@ type PeerInfoCollection interface {
GetIP() string GetIP() string
} }
// RoutesInfoCollection made for Java layer to get non default types as collection
type RoutesInfoCollection interface {
Add(s string) RoutesInfoCollection
Get(i int) string
Size() int
}
type RoutesDetails struct {
items []RoutesInfo
}
type RoutesInfo struct {
Route string
}
// StatusDetails is the implementation of the PeerInfoCollection // StatusDetails is the implementation of the PeerInfoCollection
type StatusDetails struct { type StatusDetails struct {
items []PeerInfo items []PeerInfo
@ -23,6 +57,22 @@ type StatusDetails struct {
ip string ip string
} }
// Add new PeerInfo to the collection
func (array RoutesDetails) Add(s RoutesInfo) RoutesDetails {
array.items = append(array.items, s)
return array
}
// Get return an element of the collection
func (array RoutesDetails) Get(i int) *RoutesInfo {
return &array.items[i]
}
// Size return with the size of the collection
func (array RoutesDetails) Size() int {
return len(array.items)
}
// Add new PeerInfo to the collection // Add new PeerInfo to the collection
func (array StatusDetails) Add(s PeerInfo) StatusDetails { func (array StatusDetails) Add(s PeerInfo) StatusDetails {
array.items = append(array.items, s) array.items = append(array.items, s)

View File

@ -0,0 +1,36 @@
package NetBirdSDK
// RoutesSelectionInfoCollection made for Java layer to get non default types as collection
type RoutesSelectionInfoCollection interface {
Add(s string) RoutesSelectionInfoCollection
Get(i int) string
Size() int
}
type RoutesSelectionDetails struct {
All bool
Append bool
items []RoutesSelectionInfo
}
type RoutesSelectionInfo struct {
ID string
Network string
Selected bool
}
// Add new PeerInfo to the collection
func (array RoutesSelectionDetails) Add(s RoutesSelectionInfo) RoutesSelectionDetails {
array.items = append(array.items, s)
return array
}
// Get return an element of the collection
func (array RoutesSelectionDetails) Get(i int) *RoutesSelectionInfo {
return &array.items[i]
}
// Size return with the size of the collection
func (array RoutesSelectionDetails) Size() int {
return len(array.items)
}

View File

@ -23,12 +23,17 @@ func (s *Server) ListRoutes(ctx context.Context, req *proto.ListRoutesRequest) (
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
if s.engine == nil { if s.connectClient == nil {
return nil, fmt.Errorf("not connected") return nil, fmt.Errorf("not connected")
} }
routesMap := s.engine.GetClientRoutesWithNetID() engine := s.connectClient.Engine()
routeSelector := s.engine.GetRouteManager().GetRouteSelector() if engine == nil {
return nil, fmt.Errorf("not connected")
}
routesMap := engine.GetClientRoutesWithNetID()
routeSelector := engine.GetRouteManager().GetRouteSelector()
var routes []*selectRoute var routes []*selectRoute
for id, rt := range routesMap { for id, rt := range routesMap {
@ -77,17 +82,26 @@ func (s *Server) SelectRoutes(_ context.Context, req *proto.SelectRoutesRequest)
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
routeManager := s.engine.GetRouteManager() if s.connectClient == nil {
return nil, fmt.Errorf("not connected")
}
engine := s.connectClient.Engine()
if engine == nil {
return nil, fmt.Errorf("not connected")
}
routeManager := engine.GetRouteManager()
routeSelector := routeManager.GetRouteSelector() routeSelector := routeManager.GetRouteSelector()
if req.GetAll() { if req.GetAll() {
routeSelector.SelectAllRoutes() routeSelector.SelectAllRoutes()
} else { } else {
routes := toNetIDs(req.GetRouteIDs()) routes := toNetIDs(req.GetRouteIDs())
if err := routeSelector.SelectRoutes(routes, req.GetAppend(), maps.Keys(s.engine.GetClientRoutesWithNetID())); err != nil { if err := routeSelector.SelectRoutes(routes, req.GetAppend(), maps.Keys(engine.GetClientRoutesWithNetID())); err != nil {
return nil, fmt.Errorf("select routes: %w", err) return nil, fmt.Errorf("select routes: %w", err)
} }
} }
routeManager.TriggerSelection(s.engine.GetClientRoutes()) routeManager.TriggerSelection(engine.GetClientRoutes())
return &proto.SelectRoutesResponse{}, nil return &proto.SelectRoutesResponse{}, nil
} }
@ -97,17 +111,26 @@ func (s *Server) DeselectRoutes(_ context.Context, req *proto.SelectRoutesReques
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
routeManager := s.engine.GetRouteManager() if s.connectClient == nil {
return nil, fmt.Errorf("not connected")
}
engine := s.connectClient.Engine()
if engine == nil {
return nil, fmt.Errorf("not connected")
}
routeManager := engine.GetRouteManager()
routeSelector := routeManager.GetRouteSelector() routeSelector := routeManager.GetRouteSelector()
if req.GetAll() { if req.GetAll() {
routeSelector.DeselectAllRoutes() routeSelector.DeselectAllRoutes()
} else { } else {
routes := toNetIDs(req.GetRouteIDs()) routes := toNetIDs(req.GetRouteIDs())
if err := routeSelector.DeselectRoutes(routes, maps.Keys(s.engine.GetClientRoutesWithNetID())); err != nil { if err := routeSelector.DeselectRoutes(routes, maps.Keys(engine.GetClientRoutesWithNetID())); err != nil {
return nil, fmt.Errorf("deselect routes: %w", err) return nil, fmt.Errorf("deselect routes: %w", err)
} }
} }
routeManager.TriggerSelection(s.engine.GetClientRoutes()) routeManager.TriggerSelection(engine.GetClientRoutes())
return &proto.SelectRoutesResponse{}, nil return &proto.SelectRoutesResponse{}, nil
} }

View File

@ -57,7 +57,7 @@ type Server struct {
config *internal.Config config *internal.Config
proto.UnimplementedDaemonServiceServer proto.UnimplementedDaemonServiceServer
engine *internal.Engine connectClient *internal.ConnectClient
statusRecorder *peer.Status statusRecorder *peer.Status
sessionWatcher *internal.SessionWatcher sessionWatcher *internal.SessionWatcher
@ -143,11 +143,8 @@ func (s *Server) Start() error {
s.sessionWatcher.SetOnExpireListener(s.onSessionExpire) s.sessionWatcher.SetOnExpireListener(s.onSessionExpire)
} }
engineChan := make(chan *internal.Engine, 1)
go s.watchEngine(ctx, engineChan)
if !config.DisableAutoConnect { if !config.DisableAutoConnect {
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe, engineChan) go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe)
} }
return nil return nil
@ -158,7 +155,6 @@ func (s *Server) Start() error {
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured. // we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Config, statusRecorder *peer.Status, func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Config, statusRecorder *peer.Status,
mgmProbe *internal.Probe, signalProbe *internal.Probe, relayProbe *internal.Probe, wgProbe *internal.Probe, mgmProbe *internal.Probe, signalProbe *internal.Probe, relayProbe *internal.Probe, wgProbe *internal.Probe,
engineChan chan<- *internal.Engine,
) { ) {
backOff := getConnectWithBackoff(ctx) backOff := getConnectWithBackoff(ctx)
retryStarted := false retryStarted := false
@ -188,7 +184,8 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Conf
runOperation := func() error { runOperation := func() error {
log.Tracef("running client connection") log.Tracef("running client connection")
err := internal.RunClientWithProbes(ctx, config, statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe, engineChan) s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder)
err := s.connectClient.RunWithProbes(mgmProbe, signalProbe, relayProbe, wgProbe)
if err != nil { if err != nil {
log.Debugf("run client connection exited with error: %v. Will retry in the background", err) log.Debugf("run client connection exited with error: %v. Will retry in the background", err)
} }
@ -573,10 +570,7 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String()) s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive) s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
engineChan := make(chan *internal.Engine, 1) go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe)
go s.watchEngine(ctx, engineChan)
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe, engineChan)
return &proto.UpResponse{}, nil return &proto.UpResponse{}, nil
} }
@ -593,8 +587,6 @@ func (s *Server) Down(_ context.Context, _ *proto.DownRequest) (*proto.DownRespo
state := internal.CtxGetState(s.rootCtx) state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusIdle) state.Set(internal.StatusIdle)
s.engine = nil
return &proto.DownResponse{}, nil return &proto.DownResponse{}, nil
} }
@ -688,22 +680,6 @@ func (s *Server) onSessionExpire() {
} }
} }
// watchEngine watches the engine channel and updates the engine state
func (s *Server) watchEngine(ctx context.Context, engineChan chan *internal.Engine) {
log.Tracef("Started watching engine")
for {
select {
case <-ctx.Done():
s.engine = nil
log.Tracef("Stopped watching engine")
return
case engine := <-engineChan:
log.Tracef("Received engine from watcher")
s.engine = engine
}
}
}
func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus { func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
pbFullStatus := proto.FullStatus{ pbFullStatus := proto.FullStatus{
ManagementState: &proto.ManagementState{}, ManagementState: &proto.ManagementState{},

View File

@ -70,7 +70,7 @@ func TestConnectWithRetryRuns(t *testing.T) {
t.Setenv(maxRetryTimeVar, "5s") t.Setenv(maxRetryTimeVar, "5s")
t.Setenv(retryMultiplierVar, "1") t.Setenv(retryMultiplierVar, "1")
s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe, nil) s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe)
if counter < 3 { if counter < 3 {
t.Fatalf("expected counter > 2, got %d", counter) t.Fatalf("expected counter > 2, got %d", counter)
} }