diff --git a/client/android/client.go b/client/android/client.go index 1d7aa2ba5..e5dd9db9f 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -8,6 +8,7 @@ import ( "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/formatter" @@ -29,6 +30,11 @@ type IFaceDiscover interface { stdnet.ExternalIFaceDiscover } +// RouteListener export internal RouteListener for mobile +type RouteListener interface { + routemanager.RouteListener +} + func init() { formatter.SetLogcatFormatter(log.StandardLogger()) } @@ -42,10 +48,11 @@ type Client struct { ctxCancel context.CancelFunc ctxCancelLock *sync.Mutex deviceName string + routeListener routemanager.RouteListener } // NewClient instantiate a new Client -func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover) *Client { +func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, routeListener RouteListener) *Client { lvl, _ := log.ParseLevel("trace") log.SetLevel(lvl) @@ -56,6 +63,7 @@ func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover iFaceDiscover: iFaceDiscover, recorder: peer.NewRecorder(""), ctxCancelLock: &sync.Mutex{}, + routeListener: routeListener, } } @@ -85,7 +93,7 @@ func (c *Client) Run(urlOpener URLOpener) error { // todo do not throw error in case of cancelled context ctx = internal.CtxInitState(ctx) - return internal.RunClient(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover) + return internal.RunClient(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.routeListener) } // Stop the internal client and free the resources diff --git a/client/cmd/up.go b/client/cmd/up.go index 375832a78..3ebe1ce4b 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -104,7 +104,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error { var cancel context.CancelFunc ctx, cancel = context.WithCancel(ctx) SetupCloseHandler(ctx, cancel) - return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()), nil, nil) + return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()), nil, nil, nil) } func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error { diff --git a/client/internal/connect.go b/client/internal/connect.go index 04f92bfac..91c0e5d76 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -13,6 +13,7 @@ import ( gstatus "google.golang.org/grpc/status" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" @@ -23,7 +24,7 @@ import ( ) // RunClient with main logic. -func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover) error { +func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover, routeListener routemanager.RouteListener) error { backOff := &backoff.ExponentialBackOff{ InitialInterval: time.Second, RandomizationFactor: 1, @@ -150,10 +151,11 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status, return wrapErr(err) } - md, err := newMobileDependency(tunAdapter, iFaceDiscover, mgmClient) - if err != nil { - log.Error(err) - return wrapErr(err) + // in case of non Android os these variables will be nil + md := MobileDependency{ + TunAdapter: tunAdapter, + IFaceDiscover: iFaceDiscover, + RouteListener: routeListener, } engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, md, statusRecorder) diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 9227bf0c2..260ac7d67 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -208,7 +208,7 @@ func TestUpdateDNSServer(t *testing.T) { if err != nil { t.Fatal(err) } - wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), iface.DefaultMTU, nil, nil, newNet) + wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), iface.DefaultMTU, nil, newNet) if err != nil { t.Fatal(err) } diff --git a/client/internal/engine.go b/client/internal/engine.go index 5adb645d3..e0e067dd2 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -182,12 +182,20 @@ func (e *Engine) Start() error { if err != nil { log.Errorf("failed to create pion's stdnet: %s", err) } - e.wgInterface, err = iface.NewWGIFace(wgIFaceName, wgAddr, iface.DefaultMTU, e.mobileDep.Routes, e.mobileDep.TunAdapter, transportNet) + + e.wgInterface, err = iface.NewWGIFace(wgIFaceName, wgAddr, iface.DefaultMTU, e.mobileDep.TunAdapter, transportNet) if err != nil { log.Errorf("failed creating wireguard interface instance %s: [%s]", wgIFaceName, err.Error()) return err } + routes, err := e.readInitialRoutes() + if err != nil { + return err + } + e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, routes) + e.routeManager.SetRouteChangeListener(e.mobileDep.RouteListener) + err = e.wgInterface.Create() if err != nil { log.Errorf("failed creating tunnel interface %s: [%s]", wgIFaceName, err.Error()) @@ -222,8 +230,6 @@ func (e *Engine) Start() error { e.udpMux = mux } - e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder) - if acl, err := acl.Create(e.wgInterface); err != nil { log.Errorf("failed to create ACL manager, policy will not work: %s", err.Error()) } else { @@ -1021,6 +1027,19 @@ func (e *Engine) close() { } } +func (e *Engine) readInitialRoutes() ([]*route.Route, error) { + if runtime.GOOS != "android" { + return nil, nil + } + + routesResp, err := e.mgmClient.GetRoutes() + if err != nil { + return nil, err + } + return toRoutes(routesResp), nil + +} + func findIPFromInterfaceName(ifaceName string) (net.IP, error) { iface, err := net.InterfaceByName(ifaceName) if err != nil { diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index c21e15cd9..4d67968f6 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -213,11 +213,11 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { if err != nil { t.Fatal(err) } - engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU, nil, nil, newNet) + engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU, nil, newNet) if err != nil { t.Fatal(err) } - engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder) + engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder, nil) engine.dnsServer = &dns.MockServer{ UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, } @@ -567,7 +567,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { if err != nil { t.Fatal(err) } - engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil, nil, newNet) + engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil, newNet) assert.NoError(t, err, "shouldn't return error") input := struct { inputSerial uint64 @@ -736,7 +736,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) { if err != nil { t.Fatal(err) } - engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil, nil, newNet) + engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil, newNet) assert.NoError(t, err, "shouldn't return error") mockRouteManager := &routemanager.MockManager{ diff --git a/client/internal/mobile_dependency.go b/client/internal/mobile_dependency.go index 19a86edd6..18742b4cc 100644 --- a/client/internal/mobile_dependency.go +++ b/client/internal/mobile_dependency.go @@ -1,6 +1,7 @@ package internal import ( + "github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/iface" ) @@ -9,5 +10,5 @@ import ( type MobileDependency struct { TunAdapter iface.TunAdapter IFaceDiscover stdnet.ExternalIFaceDiscover - Routes []string + RouteListener routemanager.RouteListener } diff --git a/client/internal/mobile_dependency_android.go b/client/internal/mobile_dependency_android.go deleted file mode 100644 index 5a43a243f..000000000 --- a/client/internal/mobile_dependency_android.go +++ /dev/null @@ -1,29 +0,0 @@ -package internal - -import ( - "github.com/netbirdio/netbird/client/internal/stdnet" - "github.com/netbirdio/netbird/iface" - mgm "github.com/netbirdio/netbird/management/client" -) - -func newMobileDependency(tunAdapter iface.TunAdapter, ifaceDiscover stdnet.ExternalIFaceDiscover, mgmClient *mgm.GrpcClient) (MobileDependency, error) { - md := MobileDependency{ - TunAdapter: tunAdapter, - IFaceDiscover: ifaceDiscover, - } - err := md.readMap(mgmClient) - return md, err -} - -func (d *MobileDependency) readMap(mgmClient *mgm.GrpcClient) error { - routes, err := mgmClient.GetRoutes() - if err != nil { - return err - } - - d.Routes = make([]string, len(routes)) - for i, r := range routes { - d.Routes[i] = r.GetNetwork() - } - return nil -} diff --git a/client/internal/mobile_dependency_nonandroid.go b/client/internal/mobile_dependency_nonandroid.go deleted file mode 100644 index d7ed84262..000000000 --- a/client/internal/mobile_dependency_nonandroid.go +++ /dev/null @@ -1,13 +0,0 @@ -//go:build !android - -package internal - -import ( - "github.com/netbirdio/netbird/client/internal/stdnet" - "github.com/netbirdio/netbird/iface" - mgm "github.com/netbirdio/netbird/management/client" -) - -func newMobileDependency(tunAdapter iface.TunAdapter, ifaceDiscover stdnet.ExternalIFaceDiscover, mgmClient *mgm.GrpcClient) (MobileDependency, error) { - return MobileDependency{}, nil -} diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 3c0342191..840d74269 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -16,6 +16,7 @@ import ( // Manager is a route manager interface type Manager interface { UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error + SetRouteChangeListener(listener RouteListener) Stop() } @@ -29,12 +30,14 @@ type DefaultManager struct { statusRecorder *peer.Status wgInterface *iface.WGIface pubKey string + notifier *notifier } // NewManager returns a new route manager -func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status) *DefaultManager { +func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status, initialRoutes []*route.Route) *DefaultManager { mCTX, cancel := context.WithCancel(ctx) - return &DefaultManager{ + + dm := &DefaultManager{ ctx: mCTX, stop: cancel, clientNetworks: make(map[string]*clientNetwork), @@ -42,13 +45,25 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder: statusRecorder, wgInterface: wgInterface, pubKey: pubKey, + notifier: newNotifier(), } + + if runtime.GOOS == "android" { + cr := dm.clientRoutes(initialRoutes) + dm.notifier.setInitialClientRoutes(cr) + networks := readRouteNetworks(cr) + + // make sense to call before create interface + wgInterface.SetInitialRoutes(networks) + } + return dm } // Stop stops the manager watchers and clean firewall rules func (m *DefaultManager) Stop() { m.stop() m.serverRouter.cleanUp() + m.ctx = nil } // UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps @@ -61,39 +76,10 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro m.mux.Lock() defer m.mux.Unlock() - newClientRoutesIDMap := make(map[string][]*route.Route) - newServerRoutesMap := make(map[string]*route.Route) - ownNetworkIDs := make(map[string]bool) - - for _, newRoute := range newRoutes { - networkID := route.GetHAUniqueID(newRoute) - if newRoute.Peer == m.pubKey { - ownNetworkIDs[networkID] = true - // only linux is supported for now - if runtime.GOOS != "linux" { - log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS) - continue - } - newServerRoutesMap[newRoute.ID] = newRoute - } - } - - for _, newRoute := range newRoutes { - networkID := route.GetHAUniqueID(newRoute) - if !ownNetworkIDs[networkID] { - // if prefix is too small, lets assume is a possible default route which is not yet supported - // we skip this route management - if newRoute.Network.Bits() < 7 { - log.Errorf("this agent version: %s, doesn't support default routes, received %s, skiping this route", - version.NetbirdVersion(), newRoute.Network) - continue - } - newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute) - } - } + newServerRoutesMap, newClientRoutesIDMap := m.classifiesRoutes(newRoutes) m.updateClientNetworks(updateSerial, newClientRoutesIDMap) - + m.notifier.onNewRoutes(newClientRoutesIDMap) err := m.serverRouter.updateRoutes(newServerRoutesMap) if err != nil { return err @@ -103,6 +89,11 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro } } +// SetRouteChangeListener set RouteListener for route change notifier +func (m *DefaultManager) SetRouteChangeListener(listener RouteListener) { + m.notifier.setListener(listener) +} + func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) { // removing routes that do not exist as per the update from the Management service. for id, client := range m.clientNetworks { @@ -128,3 +119,55 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[ clientNetworkWatcher.sendUpdateToClientNetworkWatcher(update) } } + +func (m *DefaultManager) classifiesRoutes(newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route) { + newClientRoutesIDMap := make(map[string][]*route.Route) + newServerRoutesMap := make(map[string]*route.Route) + ownNetworkIDs := make(map[string]bool) + + for _, newRoute := range newRoutes { + networkID := route.GetHAUniqueID(newRoute) + if newRoute.Peer == m.pubKey { + ownNetworkIDs[networkID] = true + // only linux is supported for now + if runtime.GOOS != "linux" { + log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS) + continue + } + newServerRoutesMap[newRoute.ID] = newRoute + } + } + + for _, newRoute := range newRoutes { + networkID := route.GetHAUniqueID(newRoute) + if !ownNetworkIDs[networkID] { + // if prefix is too small, lets assume is a possible default route which is not yet supported + // we skip this route management + if newRoute.Network.Bits() < 7 { + log.Errorf("this agent version: %s, doesn't support default routes, received %s, skiping this route", + version.NetbirdVersion(), newRoute.Network) + continue + } + newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute) + } + } + + return newServerRoutesMap, newClientRoutesIDMap +} + +func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Route { + _, crMap := m.classifiesRoutes(initialRoutes) + rs := make([]*route.Route, 0) + for _, routes := range crMap { + rs = append(rs, routes...) + } + return rs +} + +func readRouteNetworks(cr []*route.Route) []string { + routesNetworks := make([]string, 0) + for _, r := range cr { + routesNetworks = append(routesNetworks, r.Network.String()) + } + return routesNetworks +} diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index ca42cc40c..6291b4996 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -397,7 +397,7 @@ func TestManagerUpdateRoutes(t *testing.T) { if err != nil { t.Fatal(err) } - wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", iface.DefaultMTU, nil, nil, newNet) + wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", iface.DefaultMTU, nil, newNet) require.NoError(t, err, "should create testing WGIface interface") defer wgInterface.Close() @@ -406,7 +406,7 @@ func TestManagerUpdateRoutes(t *testing.T) { statusRecorder := peer.NewRecorder("https://mgm") ctx := context.TODO() - routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder) + routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil) defer routeManager.Stop() if len(testCase.inputInitRoutes) > 0 { diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index 4d9a714d3..bd619a1c8 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -1,7 +1,10 @@ package routemanager import ( + "context" "fmt" + + "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/route" ) @@ -11,6 +14,11 @@ type MockManager struct { StopFunc func() } +// InitialClientRoutesNetworks mock implementation of InitialClientRoutesNetworks from Manager interface +func (m *MockManager) InitialClientRoutesNetworks() []string { + return nil +} + // UpdateRoutes mock implementation of UpdateRoutes from Manager interface func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error { if m.UpdateRoutesFunc != nil { @@ -19,6 +27,15 @@ func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route return fmt.Errorf("method UpdateRoutes is not implemented") } +// Start mock implementation of Start from Manager interface +func (m *MockManager) Start(ctx context.Context, iface *iface.WGIface) { +} + +// SetRouteChangeListener mock implementation of SetRouteChangeListener from Manager interface +func (m *MockManager) SetRouteChangeListener(listener RouteListener) { + +} + // Stop mock implementation of Stop from Manager interface func (m *MockManager) Stop() { if m.StopFunc != nil { diff --git a/client/internal/routemanager/notifier.go b/client/internal/routemanager/notifier.go new file mode 100644 index 000000000..2d1afa055 --- /dev/null +++ b/client/internal/routemanager/notifier.go @@ -0,0 +1,86 @@ +package routemanager + +import ( + "sort" + "sync" + + "github.com/netbirdio/netbird/route" +) + +// RouteListener is a callback interface for mobile system +type RouteListener interface { + // OnNewRouteSetting invoke when new route setting has been arrived + OnNewRouteSetting() +} + +type notifier struct { + initialRouteRangers []string + routeRangers []string + + routeListener RouteListener + routeListenerMux sync.Mutex +} + +func newNotifier() *notifier { + return ¬ifier{} +} + +func (n *notifier) setListener(listener RouteListener) { + n.routeListenerMux.Lock() + defer n.routeListenerMux.Unlock() + n.routeListener = listener +} + +func (n *notifier) setInitialClientRoutes(clientRoutes []*route.Route) { + nets := make([]string, 0) + for _, r := range clientRoutes { + nets = append(nets, r.Network.String()) + } + sort.Strings(nets) + n.initialRouteRangers = nets +} + +func (n *notifier) onNewRoutes(idMap map[string][]*route.Route) { + newNets := make([]string, 0) + for _, routes := range idMap { + for _, r := range routes { + newNets = append(newNets, r.Network.String()) + } + } + + sort.Strings(newNets) + if !n.hasDiff(n.routeRangers, newNets) { + return + } + + n.routeRangers = newNets + + if !n.hasDiff(n.initialRouteRangers, newNets) { + return + } + n.notify() +} + +func (n *notifier) notify() { + n.routeListenerMux.Lock() + defer n.routeListenerMux.Unlock() + if n.routeListener == nil { + return + } + + go func(l RouteListener) { + l.OnNewRouteSetting() + }(n.routeListener) +} + +func (n *notifier) hasDiff(a []string, b []string) bool { + if len(a) != len(b) { + return true + } + for i, v := range a { + if v != b[i] { + return true + } + } + return false +} diff --git a/client/internal/routemanager/systemops_nonandroid_test.go b/client/internal/routemanager/systemops_nonandroid_test.go index f93509780..59d4cb72c 100644 --- a/client/internal/routemanager/systemops_nonandroid_test.go +++ b/client/internal/routemanager/systemops_nonandroid_test.go @@ -37,7 +37,7 @@ func TestAddRemoveRoutes(t *testing.T) { if err != nil { t.Fatal(err) } - wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", iface.DefaultMTU, nil, nil, newNet) + wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", iface.DefaultMTU, nil, newNet) require.NoError(t, err, "should create testing WGIface interface") defer wgInterface.Close() diff --git a/client/server/server.go b/client/server/server.go index 44a1acebf..5743e57ed 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -102,7 +102,7 @@ func (s *Server) Start() error { } go func() { - if err := internal.RunClient(ctx, config, s.statusRecorder, nil, nil); err != nil { + if err := internal.RunClient(ctx, config, s.statusRecorder, nil, nil, nil); err != nil { log.Errorf("init connections: %v", err) } }() @@ -391,7 +391,7 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes } go func() { - if err := internal.RunClient(ctx, s.config, s.statusRecorder, nil, nil); err != nil { + if err := internal.RunClient(ctx, s.config, s.statusRecorder, nil, nil, nil); err != nil { log.Errorf("run client connection: %v", err) return } diff --git a/iface/iface_android.go b/iface/iface_android.go index 52bbc2ed4..8b6e55f96 100644 --- a/iface/iface_android.go +++ b/iface/iface_android.go @@ -7,7 +7,7 @@ import ( ) // NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(ifaceName string, address string, mtu int, routes []string, tunAdapter TunAdapter, transportNet transport.Net) (*WGIface, error) { +func NewWGIFace(ifaceName string, address string, mtu int, tunAdapter TunAdapter, transportNet transport.Net) (*WGIface, error) { wgIFace := &WGIface{ mu: sync.Mutex{}, } @@ -17,7 +17,7 @@ func NewWGIFace(ifaceName string, address string, mtu int, routes []string, tunA return wgIFace, err } - tun := newTunDevice(wgAddress, mtu, routes, tunAdapter, transportNet) + tun := newTunDevice(wgAddress, mtu, tunAdapter, transportNet) wgIFace.tun = tun wgIFace.configurer = newWGConfigurer(tun) @@ -26,3 +26,8 @@ func NewWGIFace(ifaceName string, address string, mtu int, routes []string, tunA return wgIFace, nil } + +// SetInitialRoutes store the given routes and on the tun creation will be used +func (w *WGIface) SetInitialRoutes(routes []string) { + w.tun.SetRoutes(routes) +} diff --git a/iface/iface_nonandroid.go b/iface/iface_nonandroid.go index 2e8233bd6..fca7059f0 100644 --- a/iface/iface_nonandroid.go +++ b/iface/iface_nonandroid.go @@ -9,7 +9,7 @@ import ( ) // NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(iFaceName string, address string, mtu int, routes []string, tunAdapter TunAdapter, transportNet transport.Net) (*WGIface, error) { +func NewWGIFace(iFaceName string, address string, mtu int, tunAdapter TunAdapter, transportNet transport.Net) (*WGIface, error) { wgIFace := &WGIface{ mu: sync.Mutex{}, } @@ -25,3 +25,8 @@ func NewWGIFace(iFaceName string, address string, mtu int, routes []string, tunA wgIFace.userspaceBind = !WireGuardModuleIsLoaded() return wgIFace, nil } + +// SetInitialRoutes unused function on non Android +func (w *WGIface) SetInitialRoutes(routes []string) { + +} diff --git a/iface/iface_test.go b/iface/iface_test.go index 7ac8c2179..3e0759d87 100644 --- a/iface/iface_test.go +++ b/iface/iface_test.go @@ -39,7 +39,7 @@ func TestWGIface_UpdateAddr(t *testing.T) { t.Fatal(err) } - iface, err := NewWGIFace(ifaceName, addr, DefaultMTU, nil, nil, newNet) + iface, err := NewWGIFace(ifaceName, addr, DefaultMTU, nil, newNet) if err != nil { t.Fatal(err) } @@ -103,7 +103,7 @@ func Test_CreateInterface(t *testing.T) { if err != nil { t.Fatal(err) } - iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, nil, newNet) + iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, newNet) if err != nil { t.Fatal(err) } @@ -136,7 +136,7 @@ func Test_Close(t *testing.T) { if err != nil { t.Fatal(err) } - iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, nil, newNet) + iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, newNet) if err != nil { t.Fatal(err) } @@ -168,7 +168,7 @@ func Test_ConfigureInterface(t *testing.T) { if err != nil { t.Fatal(err) } - iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, nil, newNet) + iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, newNet) if err != nil { t.Fatal(err) } @@ -219,7 +219,7 @@ func Test_UpdatePeer(t *testing.T) { if err != nil { t.Fatal(err) } - iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, nil, newNet) + iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, newNet) if err != nil { t.Fatal(err) } @@ -282,7 +282,7 @@ func Test_RemovePeer(t *testing.T) { if err != nil { t.Fatal(err) } - iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, nil, newNet) + iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, newNet) if err != nil { t.Fatal(err) } @@ -335,7 +335,7 @@ func Test_ConnectPeers(t *testing.T) { if err != nil { t.Fatal(err) } - iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, DefaultMTU, nil, nil, newNet) + iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, DefaultMTU, nil, newNet) if err != nil { t.Fatal(err) } @@ -356,7 +356,7 @@ func Test_ConnectPeers(t *testing.T) { if err != nil { t.Fatal(err) } - iface2, err := NewWGIFace(peer2ifaceName, peer2wgIP, DefaultMTU, nil, nil, newNet) + iface2, err := NewWGIFace(peer2ifaceName, peer2wgIP, DefaultMTU, nil, newNet) if err != nil { t.Fatal(err) } diff --git a/iface/tun_android.go b/iface/tun_android.go index 93cfe522c..9f4f6e192 100644 --- a/iface/tun_android.go +++ b/iface/tun_android.go @@ -25,16 +25,19 @@ type tunDevice struct { wrapper *DeviceWrapper } -func newTunDevice(address WGAddress, mtu int, routes []string, tunAdapter TunAdapter, transportNet transport.Net) *tunDevice { +func newTunDevice(address WGAddress, mtu int, tunAdapter TunAdapter, transportNet transport.Net) *tunDevice { return &tunDevice{ address: address, mtu: mtu, - routes: routes, tunAdapter: tunAdapter, iceBind: bind.NewICEBind(transportNet), } } +func (t *tunDevice) SetRoutes(routes []string) { + t.routes = routes +} + func (t *tunDevice) Create() error { var err error routesString := t.routesToString() diff --git a/management/client/client.go b/management/client/client.go index 3ca79d3b6..2f903d210 100644 --- a/management/client/client.go +++ b/management/client/client.go @@ -15,4 +15,5 @@ type Client interface { Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte) (*proto.LoginResponse, error) Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte) (*proto.LoginResponse, error) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) + GetRoutes() ([]*proto.Route, error) } diff --git a/management/client/mock.go b/management/client/mock.go index f81b0b30c..589a4f784 100644 --- a/management/client/mock.go +++ b/management/client/mock.go @@ -56,3 +56,8 @@ func (m *MockClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.D } return m.GetDeviceAuthorizationFlowFunc(serverKey) } + +// GetRoutes mock implementation of GetRoutes from mgm.Client interface +func (m *MockClient) GetRoutes() ([]*proto.Route, error) { + return nil, nil +} diff --git a/signal/client/grpc.go b/signal/client/grpc.go index f1106926c..08430e8ef 100644 --- a/signal/client/grpc.go +++ b/signal/client/grpc.go @@ -366,7 +366,7 @@ func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient, } else if err != nil { return err } - log.Debugf("received a new message from Peer [fingerprint: %s]", msg.Key) + log.Tracef("received a new message from Peer [fingerprint: %s]", msg.Key) decryptedMessage, err := c.decryptMessage(msg) if err != nil {