Feature/android route notification (#868)

Add new feature to notify the user when new client route has arrived.
Refactor the initial route handling. I move every route logic into the route
manager package.

* Add notification management for client rules
* Export the route notification for Android
* Compare the notification based on network range instead of id.
This commit is contained in:
Zoltan Papp 2023-05-31 18:25:24 +02:00 committed by GitHub
parent 6425eb6732
commit 45a6263adc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 264 additions and 111 deletions

View File

@ -8,6 +8,7 @@ import (
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter" "github.com/netbirdio/netbird/formatter"
@ -29,6 +30,11 @@ type IFaceDiscover interface {
stdnet.ExternalIFaceDiscover stdnet.ExternalIFaceDiscover
} }
// RouteListener export internal RouteListener for mobile
type RouteListener interface {
routemanager.RouteListener
}
func init() { func init() {
formatter.SetLogcatFormatter(log.StandardLogger()) formatter.SetLogcatFormatter(log.StandardLogger())
} }
@ -42,10 +48,11 @@ type Client struct {
ctxCancel context.CancelFunc ctxCancel context.CancelFunc
ctxCancelLock *sync.Mutex ctxCancelLock *sync.Mutex
deviceName string deviceName string
routeListener routemanager.RouteListener
} }
// NewClient instantiate a new Client // NewClient instantiate a new Client
func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover) *Client { func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, routeListener RouteListener) *Client {
lvl, _ := log.ParseLevel("trace") lvl, _ := log.ParseLevel("trace")
log.SetLevel(lvl) log.SetLevel(lvl)
@ -56,6 +63,7 @@ func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover
iFaceDiscover: iFaceDiscover, iFaceDiscover: iFaceDiscover,
recorder: peer.NewRecorder(""), recorder: peer.NewRecorder(""),
ctxCancelLock: &sync.Mutex{}, ctxCancelLock: &sync.Mutex{},
routeListener: routeListener,
} }
} }
@ -85,7 +93,7 @@ func (c *Client) Run(urlOpener URLOpener) error {
// todo do not throw error in case of cancelled context // todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx) ctx = internal.CtxInitState(ctx)
return internal.RunClient(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover) return internal.RunClient(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.routeListener)
} }
// Stop the internal client and free the resources // Stop the internal client and free the resources

View File

@ -104,7 +104,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
var cancel context.CancelFunc var cancel context.CancelFunc
ctx, cancel = context.WithCancel(ctx) ctx, cancel = context.WithCancel(ctx)
SetupCloseHandler(ctx, cancel) SetupCloseHandler(ctx, cancel)
return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()), nil, nil) return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()), nil, nil, nil)
} }
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error { func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {

View File

@ -13,6 +13,7 @@ import (
gstatus "google.golang.org/grpc/status" gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
@ -23,7 +24,7 @@ import (
) )
// RunClient with main logic. // RunClient with main logic.
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover) error { func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover, routeListener routemanager.RouteListener) error {
backOff := &backoff.ExponentialBackOff{ backOff := &backoff.ExponentialBackOff{
InitialInterval: time.Second, InitialInterval: time.Second,
RandomizationFactor: 1, RandomizationFactor: 1,
@ -150,10 +151,11 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
return wrapErr(err) return wrapErr(err)
} }
md, err := newMobileDependency(tunAdapter, iFaceDiscover, mgmClient) // in case of non Android os these variables will be nil
if err != nil { md := MobileDependency{
log.Error(err) TunAdapter: tunAdapter,
return wrapErr(err) IFaceDiscover: iFaceDiscover,
RouteListener: routeListener,
} }
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, md, statusRecorder) engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, md, statusRecorder)

View File

@ -208,7 +208,7 @@ func TestUpdateDNSServer(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), iface.DefaultMTU, nil, nil, newNet) wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), iface.DefaultMTU, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -182,12 +182,20 @@ func (e *Engine) Start() error {
if err != nil { if err != nil {
log.Errorf("failed to create pion's stdnet: %s", err) log.Errorf("failed to create pion's stdnet: %s", err)
} }
e.wgInterface, err = iface.NewWGIFace(wgIFaceName, wgAddr, iface.DefaultMTU, e.mobileDep.Routes, e.mobileDep.TunAdapter, transportNet)
e.wgInterface, err = iface.NewWGIFace(wgIFaceName, wgAddr, iface.DefaultMTU, e.mobileDep.TunAdapter, transportNet)
if err != nil { if err != nil {
log.Errorf("failed creating wireguard interface instance %s: [%s]", wgIFaceName, err.Error()) log.Errorf("failed creating wireguard interface instance %s: [%s]", wgIFaceName, err.Error())
return err return err
} }
routes, err := e.readInitialRoutes()
if err != nil {
return err
}
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, routes)
e.routeManager.SetRouteChangeListener(e.mobileDep.RouteListener)
err = e.wgInterface.Create() err = e.wgInterface.Create()
if err != nil { if err != nil {
log.Errorf("failed creating tunnel interface %s: [%s]", wgIFaceName, err.Error()) log.Errorf("failed creating tunnel interface %s: [%s]", wgIFaceName, err.Error())
@ -222,8 +230,6 @@ func (e *Engine) Start() error {
e.udpMux = mux e.udpMux = mux
} }
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder)
if acl, err := acl.Create(e.wgInterface); err != nil { if acl, err := acl.Create(e.wgInterface); err != nil {
log.Errorf("failed to create ACL manager, policy will not work: %s", err.Error()) log.Errorf("failed to create ACL manager, policy will not work: %s", err.Error())
} else { } else {
@ -1021,6 +1027,19 @@ func (e *Engine) close() {
} }
} }
func (e *Engine) readInitialRoutes() ([]*route.Route, error) {
if runtime.GOOS != "android" {
return nil, nil
}
routesResp, err := e.mgmClient.GetRoutes()
if err != nil {
return nil, err
}
return toRoutes(routesResp), nil
}
func findIPFromInterfaceName(ifaceName string) (net.IP, error) { func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
iface, err := net.InterfaceByName(ifaceName) iface, err := net.InterfaceByName(ifaceName)
if err != nil { if err != nil {

View File

@ -213,11 +213,11 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU, nil, nil, newNet) engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder) engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder, nil)
engine.dnsServer = &dns.MockServer{ engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
} }
@ -567,7 +567,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil, nil, newNet) engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil, newNet)
assert.NoError(t, err, "shouldn't return error") assert.NoError(t, err, "shouldn't return error")
input := struct { input := struct {
inputSerial uint64 inputSerial uint64
@ -736,7 +736,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil, nil, newNet) engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil, newNet)
assert.NoError(t, err, "shouldn't return error") assert.NoError(t, err, "shouldn't return error")
mockRouteManager := &routemanager.MockManager{ mockRouteManager := &routemanager.MockManager{

View File

@ -1,6 +1,7 @@
package internal package internal
import ( import (
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
) )
@ -9,5 +10,5 @@ import (
type MobileDependency struct { type MobileDependency struct {
TunAdapter iface.TunAdapter TunAdapter iface.TunAdapter
IFaceDiscover stdnet.ExternalIFaceDiscover IFaceDiscover stdnet.ExternalIFaceDiscover
Routes []string RouteListener routemanager.RouteListener
} }

View File

@ -1,29 +0,0 @@
package internal
import (
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/iface"
mgm "github.com/netbirdio/netbird/management/client"
)
func newMobileDependency(tunAdapter iface.TunAdapter, ifaceDiscover stdnet.ExternalIFaceDiscover, mgmClient *mgm.GrpcClient) (MobileDependency, error) {
md := MobileDependency{
TunAdapter: tunAdapter,
IFaceDiscover: ifaceDiscover,
}
err := md.readMap(mgmClient)
return md, err
}
func (d *MobileDependency) readMap(mgmClient *mgm.GrpcClient) error {
routes, err := mgmClient.GetRoutes()
if err != nil {
return err
}
d.Routes = make([]string, len(routes))
for i, r := range routes {
d.Routes[i] = r.GetNetwork()
}
return nil
}

View File

@ -1,13 +0,0 @@
//go:build !android
package internal
import (
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/iface"
mgm "github.com/netbirdio/netbird/management/client"
)
func newMobileDependency(tunAdapter iface.TunAdapter, ifaceDiscover stdnet.ExternalIFaceDiscover, mgmClient *mgm.GrpcClient) (MobileDependency, error) {
return MobileDependency{}, nil
}

View File

@ -16,6 +16,7 @@ import (
// Manager is a route manager interface // Manager is a route manager interface
type Manager interface { type Manager interface {
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error
SetRouteChangeListener(listener RouteListener)
Stop() Stop()
} }
@ -29,12 +30,14 @@ type DefaultManager struct {
statusRecorder *peer.Status statusRecorder *peer.Status
wgInterface *iface.WGIface wgInterface *iface.WGIface
pubKey string pubKey string
notifier *notifier
} }
// NewManager returns a new route manager // NewManager returns a new route manager
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status) *DefaultManager { func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status, initialRoutes []*route.Route) *DefaultManager {
mCTX, cancel := context.WithCancel(ctx) mCTX, cancel := context.WithCancel(ctx)
return &DefaultManager{
dm := &DefaultManager{
ctx: mCTX, ctx: mCTX,
stop: cancel, stop: cancel,
clientNetworks: make(map[string]*clientNetwork), clientNetworks: make(map[string]*clientNetwork),
@ -42,13 +45,25 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
wgInterface: wgInterface, wgInterface: wgInterface,
pubKey: pubKey, pubKey: pubKey,
notifier: newNotifier(),
} }
if runtime.GOOS == "android" {
cr := dm.clientRoutes(initialRoutes)
dm.notifier.setInitialClientRoutes(cr)
networks := readRouteNetworks(cr)
// make sense to call before create interface
wgInterface.SetInitialRoutes(networks)
}
return dm
} }
// Stop stops the manager watchers and clean firewall rules // Stop stops the manager watchers and clean firewall rules
func (m *DefaultManager) Stop() { func (m *DefaultManager) Stop() {
m.stop() m.stop()
m.serverRouter.cleanUp() m.serverRouter.cleanUp()
m.ctx = nil
} }
// UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps // UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps
@ -61,39 +76,10 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
m.mux.Lock() m.mux.Lock()
defer m.mux.Unlock() defer m.mux.Unlock()
newClientRoutesIDMap := make(map[string][]*route.Route) newServerRoutesMap, newClientRoutesIDMap := m.classifiesRoutes(newRoutes)
newServerRoutesMap := make(map[string]*route.Route)
ownNetworkIDs := make(map[string]bool)
for _, newRoute := range newRoutes {
networkID := route.GetHAUniqueID(newRoute)
if newRoute.Peer == m.pubKey {
ownNetworkIDs[networkID] = true
// only linux is supported for now
if runtime.GOOS != "linux" {
log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS)
continue
}
newServerRoutesMap[newRoute.ID] = newRoute
}
}
for _, newRoute := range newRoutes {
networkID := route.GetHAUniqueID(newRoute)
if !ownNetworkIDs[networkID] {
// if prefix is too small, lets assume is a possible default route which is not yet supported
// we skip this route management
if newRoute.Network.Bits() < 7 {
log.Errorf("this agent version: %s, doesn't support default routes, received %s, skiping this route",
version.NetbirdVersion(), newRoute.Network)
continue
}
newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute)
}
}
m.updateClientNetworks(updateSerial, newClientRoutesIDMap) m.updateClientNetworks(updateSerial, newClientRoutesIDMap)
m.notifier.onNewRoutes(newClientRoutesIDMap)
err := m.serverRouter.updateRoutes(newServerRoutesMap) err := m.serverRouter.updateRoutes(newServerRoutesMap)
if err != nil { if err != nil {
return err return err
@ -103,6 +89,11 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
} }
} }
// SetRouteChangeListener set RouteListener for route change notifier
func (m *DefaultManager) SetRouteChangeListener(listener RouteListener) {
m.notifier.setListener(listener)
}
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) { func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) {
// removing routes that do not exist as per the update from the Management service. // removing routes that do not exist as per the update from the Management service.
for id, client := range m.clientNetworks { for id, client := range m.clientNetworks {
@ -128,3 +119,55 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(update) clientNetworkWatcher.sendUpdateToClientNetworkWatcher(update)
} }
} }
func (m *DefaultManager) classifiesRoutes(newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route) {
newClientRoutesIDMap := make(map[string][]*route.Route)
newServerRoutesMap := make(map[string]*route.Route)
ownNetworkIDs := make(map[string]bool)
for _, newRoute := range newRoutes {
networkID := route.GetHAUniqueID(newRoute)
if newRoute.Peer == m.pubKey {
ownNetworkIDs[networkID] = true
// only linux is supported for now
if runtime.GOOS != "linux" {
log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS)
continue
}
newServerRoutesMap[newRoute.ID] = newRoute
}
}
for _, newRoute := range newRoutes {
networkID := route.GetHAUniqueID(newRoute)
if !ownNetworkIDs[networkID] {
// if prefix is too small, lets assume is a possible default route which is not yet supported
// we skip this route management
if newRoute.Network.Bits() < 7 {
log.Errorf("this agent version: %s, doesn't support default routes, received %s, skiping this route",
version.NetbirdVersion(), newRoute.Network)
continue
}
newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute)
}
}
return newServerRoutesMap, newClientRoutesIDMap
}
func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Route {
_, crMap := m.classifiesRoutes(initialRoutes)
rs := make([]*route.Route, 0)
for _, routes := range crMap {
rs = append(rs, routes...)
}
return rs
}
func readRouteNetworks(cr []*route.Route) []string {
routesNetworks := make([]string, 0)
for _, r := range cr {
routesNetworks = append(routesNetworks, r.Network.String())
}
return routesNetworks
}

View File

@ -397,7 +397,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", iface.DefaultMTU, nil, nil, newNet) wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", iface.DefaultMTU, nil, newNet)
require.NoError(t, err, "should create testing WGIface interface") require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close() defer wgInterface.Close()
@ -406,7 +406,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
statusRecorder := peer.NewRecorder("https://mgm") statusRecorder := peer.NewRecorder("https://mgm")
ctx := context.TODO() ctx := context.TODO()
routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder) routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil)
defer routeManager.Stop() defer routeManager.Stop()
if len(testCase.inputInitRoutes) > 0 { if len(testCase.inputInitRoutes) > 0 {

View File

@ -1,7 +1,10 @@
package routemanager package routemanager
import ( import (
"context"
"fmt" "fmt"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@ -11,6 +14,11 @@ type MockManager struct {
StopFunc func() StopFunc func()
} }
// InitialClientRoutesNetworks mock implementation of InitialClientRoutesNetworks from Manager interface
func (m *MockManager) InitialClientRoutesNetworks() []string {
return nil
}
// UpdateRoutes mock implementation of UpdateRoutes from Manager interface // UpdateRoutes mock implementation of UpdateRoutes from Manager interface
func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error { func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
if m.UpdateRoutesFunc != nil { if m.UpdateRoutesFunc != nil {
@ -19,6 +27,15 @@ func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route
return fmt.Errorf("method UpdateRoutes is not implemented") return fmt.Errorf("method UpdateRoutes is not implemented")
} }
// Start mock implementation of Start from Manager interface
func (m *MockManager) Start(ctx context.Context, iface *iface.WGIface) {
}
// SetRouteChangeListener mock implementation of SetRouteChangeListener from Manager interface
func (m *MockManager) SetRouteChangeListener(listener RouteListener) {
}
// Stop mock implementation of Stop from Manager interface // Stop mock implementation of Stop from Manager interface
func (m *MockManager) Stop() { func (m *MockManager) Stop() {
if m.StopFunc != nil { if m.StopFunc != nil {

View File

@ -0,0 +1,86 @@
package routemanager
import (
"sort"
"sync"
"github.com/netbirdio/netbird/route"
)
// RouteListener is a callback interface for mobile system
type RouteListener interface {
// OnNewRouteSetting invoke when new route setting has been arrived
OnNewRouteSetting()
}
type notifier struct {
initialRouteRangers []string
routeRangers []string
routeListener RouteListener
routeListenerMux sync.Mutex
}
func newNotifier() *notifier {
return &notifier{}
}
func (n *notifier) setListener(listener RouteListener) {
n.routeListenerMux.Lock()
defer n.routeListenerMux.Unlock()
n.routeListener = listener
}
func (n *notifier) setInitialClientRoutes(clientRoutes []*route.Route) {
nets := make([]string, 0)
for _, r := range clientRoutes {
nets = append(nets, r.Network.String())
}
sort.Strings(nets)
n.initialRouteRangers = nets
}
func (n *notifier) onNewRoutes(idMap map[string][]*route.Route) {
newNets := make([]string, 0)
for _, routes := range idMap {
for _, r := range routes {
newNets = append(newNets, r.Network.String())
}
}
sort.Strings(newNets)
if !n.hasDiff(n.routeRangers, newNets) {
return
}
n.routeRangers = newNets
if !n.hasDiff(n.initialRouteRangers, newNets) {
return
}
n.notify()
}
func (n *notifier) notify() {
n.routeListenerMux.Lock()
defer n.routeListenerMux.Unlock()
if n.routeListener == nil {
return
}
go func(l RouteListener) {
l.OnNewRouteSetting()
}(n.routeListener)
}
func (n *notifier) hasDiff(a []string, b []string) bool {
if len(a) != len(b) {
return true
}
for i, v := range a {
if v != b[i] {
return true
}
}
return false
}

View File

@ -37,7 +37,7 @@ func TestAddRemoveRoutes(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", iface.DefaultMTU, nil, nil, newNet) wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", iface.DefaultMTU, nil, newNet)
require.NoError(t, err, "should create testing WGIface interface") require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close() defer wgInterface.Close()

View File

@ -102,7 +102,7 @@ func (s *Server) Start() error {
} }
go func() { go func() {
if err := internal.RunClient(ctx, config, s.statusRecorder, nil, nil); err != nil { if err := internal.RunClient(ctx, config, s.statusRecorder, nil, nil, nil); err != nil {
log.Errorf("init connections: %v", err) log.Errorf("init connections: %v", err)
} }
}() }()
@ -391,7 +391,7 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
} }
go func() { go func() {
if err := internal.RunClient(ctx, s.config, s.statusRecorder, nil, nil); err != nil { if err := internal.RunClient(ctx, s.config, s.statusRecorder, nil, nil, nil); err != nil {
log.Errorf("run client connection: %v", err) log.Errorf("run client connection: %v", err)
return return
} }

View File

@ -7,7 +7,7 @@ import (
) )
// NewWGIFace Creates a new WireGuard interface instance // NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(ifaceName string, address string, mtu int, routes []string, tunAdapter TunAdapter, transportNet transport.Net) (*WGIface, error) { func NewWGIFace(ifaceName string, address string, mtu int, tunAdapter TunAdapter, transportNet transport.Net) (*WGIface, error) {
wgIFace := &WGIface{ wgIFace := &WGIface{
mu: sync.Mutex{}, mu: sync.Mutex{},
} }
@ -17,7 +17,7 @@ func NewWGIFace(ifaceName string, address string, mtu int, routes []string, tunA
return wgIFace, err return wgIFace, err
} }
tun := newTunDevice(wgAddress, mtu, routes, tunAdapter, transportNet) tun := newTunDevice(wgAddress, mtu, tunAdapter, transportNet)
wgIFace.tun = tun wgIFace.tun = tun
wgIFace.configurer = newWGConfigurer(tun) wgIFace.configurer = newWGConfigurer(tun)
@ -26,3 +26,8 @@ func NewWGIFace(ifaceName string, address string, mtu int, routes []string, tunA
return wgIFace, nil return wgIFace, nil
} }
// SetInitialRoutes store the given routes and on the tun creation will be used
func (w *WGIface) SetInitialRoutes(routes []string) {
w.tun.SetRoutes(routes)
}

View File

@ -9,7 +9,7 @@ import (
) )
// NewWGIFace Creates a new WireGuard interface instance // NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, mtu int, routes []string, tunAdapter TunAdapter, transportNet transport.Net) (*WGIface, error) { func NewWGIFace(iFaceName string, address string, mtu int, tunAdapter TunAdapter, transportNet transport.Net) (*WGIface, error) {
wgIFace := &WGIface{ wgIFace := &WGIface{
mu: sync.Mutex{}, mu: sync.Mutex{},
} }
@ -25,3 +25,8 @@ func NewWGIFace(iFaceName string, address string, mtu int, routes []string, tunA
wgIFace.userspaceBind = !WireGuardModuleIsLoaded() wgIFace.userspaceBind = !WireGuardModuleIsLoaded()
return wgIFace, nil return wgIFace, nil
} }
// SetInitialRoutes unused function on non Android
func (w *WGIface) SetInitialRoutes(routes []string) {
}

View File

@ -39,7 +39,7 @@ func TestWGIface_UpdateAddr(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
iface, err := NewWGIFace(ifaceName, addr, DefaultMTU, nil, nil, newNet) iface, err := NewWGIFace(ifaceName, addr, DefaultMTU, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -103,7 +103,7 @@ func Test_CreateInterface(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, nil, newNet) iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -136,7 +136,7 @@ func Test_Close(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, nil, newNet) iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -168,7 +168,7 @@ func Test_ConfigureInterface(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, nil, newNet) iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -219,7 +219,7 @@ func Test_UpdatePeer(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, nil, newNet) iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -282,7 +282,7 @@ func Test_RemovePeer(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, nil, newNet) iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -335,7 +335,7 @@ func Test_ConnectPeers(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, DefaultMTU, nil, nil, newNet) iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, DefaultMTU, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -356,7 +356,7 @@ func Test_ConnectPeers(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
iface2, err := NewWGIFace(peer2ifaceName, peer2wgIP, DefaultMTU, nil, nil, newNet) iface2, err := NewWGIFace(peer2ifaceName, peer2wgIP, DefaultMTU, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -25,16 +25,19 @@ type tunDevice struct {
wrapper *DeviceWrapper wrapper *DeviceWrapper
} }
func newTunDevice(address WGAddress, mtu int, routes []string, tunAdapter TunAdapter, transportNet transport.Net) *tunDevice { func newTunDevice(address WGAddress, mtu int, tunAdapter TunAdapter, transportNet transport.Net) *tunDevice {
return &tunDevice{ return &tunDevice{
address: address, address: address,
mtu: mtu, mtu: mtu,
routes: routes,
tunAdapter: tunAdapter, tunAdapter: tunAdapter,
iceBind: bind.NewICEBind(transportNet), iceBind: bind.NewICEBind(transportNet),
} }
} }
func (t *tunDevice) SetRoutes(routes []string) {
t.routes = routes
}
func (t *tunDevice) Create() error { func (t *tunDevice) Create() error {
var err error var err error
routesString := t.routesToString() routesString := t.routesToString()

View File

@ -15,4 +15,5 @@ type Client interface {
Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte) (*proto.LoginResponse, error) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte) (*proto.LoginResponse, error)
Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte) (*proto.LoginResponse, error) Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte) (*proto.LoginResponse, error)
GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
GetRoutes() ([]*proto.Route, error)
} }

View File

@ -56,3 +56,8 @@ func (m *MockClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.D
} }
return m.GetDeviceAuthorizationFlowFunc(serverKey) return m.GetDeviceAuthorizationFlowFunc(serverKey)
} }
// GetRoutes mock implementation of GetRoutes from mgm.Client interface
func (m *MockClient) GetRoutes() ([]*proto.Route, error) {
return nil, nil
}

View File

@ -366,7 +366,7 @@ func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient,
} else if err != nil { } else if err != nil {
return err return err
} }
log.Debugf("received a new message from Peer [fingerprint: %s]", msg.Key) log.Tracef("received a new message from Peer [fingerprint: %s]", msg.Key)
decryptedMessage, err := c.decryptMessage(msg) decryptedMessage, err := c.decryptMessage(msg)
if err != nil { if err != nil {