Add route management for Android interface (#801)

Support client route management feature on Android
This commit is contained in:
Zoltan Papp 2023-04-17 11:15:37 +02:00 committed by GitHub
parent 1803cf3678
commit 4616bc5258
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
32 changed files with 486 additions and 355 deletions

View File

@ -144,13 +144,19 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
peerConfig := loginResp.GetPeerConfig() peerConfig := loginResp.GetPeerConfig()
engineConfig, err := createEngineConfig(myPrivateKey, config, peerConfig, tunAdapter, iFaceDiscover) engineConfig, err := createEngineConfig(myPrivateKey, config, peerConfig)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return wrapErr(err) return wrapErr(err)
} }
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, statusRecorder) md, err := newMobileDependency(tunAdapter, iFaceDiscover, mgmClient)
if err != nil {
log.Error(err)
return wrapErr(err)
}
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, md, statusRecorder)
err = engine.Start() err = engine.Start()
if err != nil { if err != nil {
log.Errorf("error while starting Netbird Connection Engine: %s", err) log.Errorf("error while starting Netbird Connection Engine: %s", err)
@ -194,13 +200,10 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
} }
// createEngineConfig converts configuration received from Management Service to EngineConfig // createEngineConfig converts configuration received from Management Service to EngineConfig
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover) (*EngineConfig, error) { func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
engineConf := &EngineConfig{ engineConf := &EngineConfig{
WgIfaceName: config.WgIface, WgIfaceName: config.WgIface,
WgAddr: peerConfig.Address, WgAddr: peerConfig.Address,
TunAdapter: tunAdapter,
IFaceDiscover: iFaceDiscover,
IFaceBlackList: config.IFaceBlackList, IFaceBlackList: config.IFaceBlackList,
DisableIPv6Discovery: config.DisableIPv6Discovery, DisableIPv6Discovery: config.DisableIPv6Discovery,
WgPrivateKey: key, WgPrivateKey: key,

View File

@ -206,7 +206,7 @@ func TestUpdateDNSServer(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), iface.DefaultMTU, nil, newNet) wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), iface.DefaultMTU, nil, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -20,7 +20,6 @@ import (
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/proxy" "github.com/netbirdio/netbird/client/internal/proxy"
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
nbssh "github.com/netbirdio/netbird/client/ssh" nbssh "github.com/netbirdio/netbird/client/ssh"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
@ -47,10 +46,6 @@ var ErrResetConnection = fmt.Errorf("reset connection")
type EngineConfig struct { type EngineConfig struct {
WgPort int WgPort int
WgIfaceName string WgIfaceName string
// TunAdapter is option. It is necessary for mobile version.
TunAdapter iface.TunAdapter
IFaceDiscover stdnet.ExternalIFaceDiscover
// WgAddr is a Wireguard local address (Netbird Network IP) // WgAddr is a Wireguard local address (Netbird Network IP)
WgAddr string WgAddr string
@ -90,7 +85,9 @@ type Engine struct {
// syncMsgMux is used to guarantee sequential Management Service message processing // syncMsgMux is used to guarantee sequential Management Service message processing
syncMsgMux *sync.Mutex syncMsgMux *sync.Mutex
config *EngineConfig config *EngineConfig
mobileDep MobileDependency
// STUNs is a list of STUN servers used by ICE // STUNs is a list of STUN servers used by ICE
STUNs []*ice.URL STUNs []*ice.URL
// TURNs is a list of STUN servers used by ICE // TURNs is a list of STUN servers used by ICE
@ -130,7 +127,7 @@ type Peer struct {
func NewEngine( func NewEngine(
ctx context.Context, cancel context.CancelFunc, ctx context.Context, cancel context.CancelFunc,
signalClient signal.Client, mgmClient mgm.Client, signalClient signal.Client, mgmClient mgm.Client,
config *EngineConfig, statusRecorder *peer.Status, config *EngineConfig, mobileDep MobileDependency, statusRecorder *peer.Status,
) *Engine { ) *Engine {
return &Engine{ return &Engine{
ctx: ctx, ctx: ctx,
@ -140,6 +137,7 @@ func NewEngine(
peerConns: make(map[string]*peer.Conn), peerConns: make(map[string]*peer.Conn),
syncMsgMux: &sync.Mutex{}, syncMsgMux: &sync.Mutex{},
config: config, config: config,
mobileDep: mobileDep,
STUNs: []*ice.URL{}, STUNs: []*ice.URL{},
TURNs: []*ice.URL{}, TURNs: []*ice.URL{},
networkSerial: 0, networkSerial: 0,
@ -181,7 +179,7 @@ func (e *Engine) Start() error {
if err != nil { if err != nil {
log.Errorf("failed to create pion's stdnet: %s", err) log.Errorf("failed to create pion's stdnet: %s", err)
} }
e.wgInterface, err = iface.NewWGIFace(wgIFaceName, wgAddr, iface.DefaultMTU, e.config.TunAdapter, transportNet) e.wgInterface, err = iface.NewWGIFace(wgIFaceName, wgAddr, iface.DefaultMTU, e.mobileDep.Routes, e.mobileDep.TunAdapter, transportNet)
if err != nil { if err != nil {
log.Errorf("failed creating wireguard interface instance %s: [%s]", wgIFaceName, err.Error()) log.Errorf("failed creating wireguard interface instance %s: [%s]", wgIFaceName, err.Error())
return err return err
@ -834,7 +832,7 @@ func (e Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, er
UserspaceBind: e.wgInterface.IsUserspaceBind(), UserspaceBind: e.wgInterface.IsUserspaceBind(),
} }
peerConn, err := peer.NewConn(config, e.statusRecorder, e.config.TunAdapter, e.config.IFaceDiscover) peerConn, err := peer.NewConn(config, e.statusRecorder, e.mobileDep.TunAdapter, e.mobileDep.IFaceDiscover)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -3,5 +3,5 @@ package internal
import "github.com/netbirdio/netbird/client/internal/stdnet" import "github.com/netbirdio/netbird/client/internal/stdnet"
func (e *Engine) newStdNet() (*stdnet.Net, error) { func (e *Engine) newStdNet() (*stdnet.Net, error) {
return stdnet.NewNetWithDiscover(e.config.IFaceDiscover, e.config.IFaceBlackList) return stdnet.NewNetWithDiscover(e.mobileDep.IFaceDiscover, e.config.IFaceBlackList)
} }

View File

@ -74,7 +74,7 @@ func TestEngine_SSH(t *testing.T) {
WgAddr: "100.64.0.1/24", WgAddr: "100.64.0.1/24",
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
}, peer.NewRecorder("https://mgm")) }, MobileDependency{}, peer.NewRecorder("https://mgm"))
engine.dnsServer = &dns.MockServer{ engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
@ -208,12 +208,12 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
WgAddr: "100.64.0.1/24", WgAddr: "100.64.0.1/24",
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
}, peer.NewRecorder("https://mgm")) }, MobileDependency{}, peer.NewRecorder("https://mgm"))
newNet, err := stdnet.NewNet() newNet, err := stdnet.NewNet()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU, nil, newNet) engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU, nil, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -404,7 +404,7 @@ func TestEngine_Sync(t *testing.T) {
WgAddr: "100.64.0.1/24", WgAddr: "100.64.0.1/24",
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
}, peer.NewRecorder("https://mgm")) }, MobileDependency{}, peer.NewRecorder("https://mgm"))
engine.dnsServer = &dns.MockServer{ engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
@ -562,12 +562,12 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
WgAddr: wgAddr, WgAddr: wgAddr,
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
}, peer.NewRecorder("https://mgm")) }, MobileDependency{}, peer.NewRecorder("https://mgm"))
newNet, err := stdnet.NewNet() newNet, err := stdnet.NewNet()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil, newNet) engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil, nil, newNet)
assert.NoError(t, err, "shouldn't return error") assert.NoError(t, err, "shouldn't return error")
input := struct { input := struct {
inputSerial uint64 inputSerial uint64
@ -731,12 +731,12 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
WgAddr: wgAddr, WgAddr: wgAddr,
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
}, peer.NewRecorder("https://mgm")) }, MobileDependency{}, peer.NewRecorder("https://mgm"))
newNet, err := stdnet.NewNet() newNet, err := stdnet.NewNet()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil, newNet) engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU, nil, nil, newNet)
assert.NoError(t, err, "shouldn't return error") assert.NoError(t, err, "shouldn't return error")
mockRouteManager := &routemanager.MockManager{ mockRouteManager := &routemanager.MockManager{
@ -1000,7 +1000,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
WgPort: wgPort, WgPort: wgPort,
} }
return NewEngine(ctx, cancel, signalClient, mgmtClient, conf, peer.NewRecorder("https://mgm")), nil return NewEngine(ctx, cancel, signalClient, mgmtClient, conf, MobileDependency{}, peer.NewRecorder("https://mgm")), nil
} }
func startSignal() (*grpc.Server, string, error) { func startSignal() (*grpc.Server, string, error) {

View File

@ -0,0 +1,13 @@
package internal
import (
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/iface"
)
// MobileDependency collect all dependencies for mobile platform
type MobileDependency struct {
TunAdapter iface.TunAdapter
IFaceDiscover stdnet.ExternalIFaceDiscover
Routes []string
}

View File

@ -0,0 +1,29 @@
package internal
import (
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/iface"
mgm "github.com/netbirdio/netbird/management/client"
)
func newMobileDependency(tunAdapter iface.TunAdapter, ifaceDiscover stdnet.ExternalIFaceDiscover, mgmClient *mgm.GrpcClient) (MobileDependency, error) {
md := MobileDependency{
TunAdapter: tunAdapter,
IFaceDiscover: ifaceDiscover,
}
err := md.readMap(mgmClient)
return md, err
}
func (d *MobileDependency) readMap(mgmClient *mgm.GrpcClient) error {
routes, err := mgmClient.GetRoutes()
if err != nil {
return err
}
d.Routes = make([]string, len(routes))
for i, r := range routes {
d.Routes[i] = r.GetNetwork()
}
return nil
}

View File

@ -0,0 +1,13 @@
//go:build !android
package internal
import (
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/iface"
mgm "github.com/netbirdio/netbird/management/client"
)
func newMobileDependency(tunAdapter iface.TunAdapter, ifaceDiscover stdnet.ExternalIFaceDiscover, mgmClient *mgm.GrpcClient) (MobileDependency, error) {
return MobileDependency{}, nil
}

View File

@ -1,12 +1,15 @@
//go:build !android
package routemanager package routemanager
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
"github.com/google/nftables"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
import "github.com/google/nftables"
const ( const (
ipv6Forwarding = "netbird-rt-ipv6-forwarding" ipv6Forwarding = "netbird-rt-ipv6-forwarding"

View File

@ -1,14 +1,17 @@
//go:build !android
package routemanager package routemanager
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/coreos/go-iptables/iptables"
log "github.com/sirupsen/logrus"
"net/netip" "net/netip"
"os/exec" "os/exec"
"strings" "strings"
"sync" "sync"
"github.com/coreos/go-iptables/iptables"
log "github.com/sirupsen/logrus"
) )
func isIptablesSupported() bool { func isIptablesSupported() bool {

View File

@ -1,10 +1,13 @@
//go:build !android
package routemanager package routemanager
import ( import (
"context" "context"
"testing"
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"testing"
) )
func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {

View File

@ -1,9 +1,130 @@
package routemanager package routemanager
import "github.com/netbirdio/netbird/route" import (
"context"
"runtime"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/version"
)
// Manager is a route manager interface // Manager is a route manager interface
type Manager interface { type Manager interface {
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error
Stop() Stop()
} }
// DefaultManager is the default instance of a route manager
type DefaultManager struct {
ctx context.Context
stop context.CancelFunc
mux sync.Mutex
clientNetworks map[string]*clientNetwork
serverRouter *serverRouter
statusRecorder *peer.Status
wgInterface *iface.WGIface
pubKey string
}
// NewManager returns a new route manager
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status) *DefaultManager {
mCTX, cancel := context.WithCancel(ctx)
return &DefaultManager{
ctx: mCTX,
stop: cancel,
clientNetworks: make(map[string]*clientNetwork),
serverRouter: newServerRouter(ctx, wgInterface),
statusRecorder: statusRecorder,
wgInterface: wgInterface,
pubKey: pubKey,
}
}
// Stop stops the manager watchers and clean firewall rules
func (m *DefaultManager) Stop() {
m.stop()
m.serverRouter.cleanUp()
}
// UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
select {
case <-m.ctx.Done():
log.Infof("not updating routes as context is closed")
return m.ctx.Err()
default:
m.mux.Lock()
defer m.mux.Unlock()
newClientRoutesIDMap := make(map[string][]*route.Route)
newServerRoutesMap := make(map[string]*route.Route)
ownNetworkIDs := make(map[string]bool)
for _, newRoute := range newRoutes {
networkID := route.GetHAUniqueID(newRoute)
if newRoute.Peer == m.pubKey {
ownNetworkIDs[networkID] = true
// only linux is supported for now
if runtime.GOOS != "linux" {
log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS)
continue
}
newServerRoutesMap[newRoute.ID] = newRoute
}
}
for _, newRoute := range newRoutes {
networkID := route.GetHAUniqueID(newRoute)
if !ownNetworkIDs[networkID] {
// if prefix is too small, lets assume is a possible default route which is not yet supported
// we skip this route management
if newRoute.Network.Bits() < 7 {
log.Errorf("this agent version: %s, doesn't support default routes, received %s, skiping this route",
version.NetbirdVersion(), newRoute.Network)
continue
}
newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute)
}
}
m.updateClientNetworks(updateSerial, newClientRoutesIDMap)
err := m.serverRouter.updateRoutes(newServerRoutesMap)
if err != nil {
return err
}
return nil
}
}
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) {
// removing routes that do not exist as per the update from the Management service.
for id, client := range m.clientNetworks {
_, found := networks[id]
if !found {
log.Debugf("stopping client network watcher, %s", id)
client.stop()
delete(m.clientNetworks, id)
}
}
for id, routes := range networks {
clientNetworkWatcher, found := m.clientNetworks[id]
if !found {
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network)
m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher()
}
update := routesUpdate{
updateSerial: updateSerial,
routes: routes,
}
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(update)
}
}

View File

@ -1,31 +0,0 @@
package routemanager
import (
"context"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
)
// DefaultManager dummy router manager for Android
type DefaultManager struct {
ctx context.Context
serverRouter *serverRouter
wgInterface *iface.WGIface
}
// NewManager returns a new dummy route manager what doing nothing
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status) *DefaultManager {
return &DefaultManager{}
}
// UpdateRoutes ...
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
return nil
}
// Stop ...
func (m *DefaultManager) Stop() {
}

View File

@ -1,186 +0,0 @@
//go:build !android
package routemanager
import (
"context"
"fmt"
"runtime"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/version"
)
// DefaultManager is the default instance of a route manager
type DefaultManager struct {
ctx context.Context
stop context.CancelFunc
mux sync.Mutex
clientNetworks map[string]*clientNetwork
serverRoutes map[string]*route.Route
serverRouter *serverRouter
statusRecorder *peer.Status
wgInterface *iface.WGIface
pubKey string
}
// NewManager returns a new route manager
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status) *DefaultManager {
mCTX, cancel := context.WithCancel(ctx)
return &DefaultManager{
ctx: mCTX,
stop: cancel,
clientNetworks: make(map[string]*clientNetwork),
serverRoutes: make(map[string]*route.Route),
serverRouter: &serverRouter{
routes: make(map[string]*route.Route),
netForwardHistoryEnabled: isNetForwardHistoryEnabled(),
firewall: NewFirewall(ctx),
},
statusRecorder: statusRecorder,
wgInterface: wgInterface,
pubKey: pubKey,
}
}
// Stop stops the manager watchers and clean firewall rules
func (m *DefaultManager) Stop() {
m.stop()
m.serverRouter.firewall.CleanRoutingRules()
}
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) {
// removing routes that do not exist as per the update from the Management service.
for id, client := range m.clientNetworks {
_, found := networks[id]
if !found {
log.Debugf("stopping client network watcher, %s", id)
client.stop()
delete(m.clientNetworks, id)
}
}
for id, routes := range networks {
clientNetworkWatcher, found := m.clientNetworks[id]
if !found {
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network)
m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher()
}
update := routesUpdate{
updateSerial: updateSerial,
routes: routes,
}
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(update)
}
}
func (m *DefaultManager) updateServerRoutes(routesMap map[string]*route.Route) error {
serverRoutesToRemove := make([]string, 0)
if len(routesMap) > 0 {
err := m.serverRouter.firewall.RestoreOrCreateContainers()
if err != nil {
return fmt.Errorf("couldn't initialize firewall containers, got err: %v", err)
}
}
for routeID := range m.serverRoutes {
update, found := routesMap[routeID]
if !found || !update.IsEqual(m.serverRoutes[routeID]) {
serverRoutesToRemove = append(serverRoutesToRemove, routeID)
continue
}
}
for _, routeID := range serverRoutesToRemove {
oldRoute := m.serverRoutes[routeID]
err := m.removeFromServerNetwork(oldRoute)
if err != nil {
log.Errorf("unable to remove route id: %s, network %s, from server, got: %v",
oldRoute.ID, oldRoute.Network, err)
}
delete(m.serverRoutes, routeID)
}
for id, newRoute := range routesMap {
_, found := m.serverRoutes[id]
if found {
continue
}
err := m.addToServerNetwork(newRoute)
if err != nil {
log.Errorf("unable to add route %s from server, got: %v", newRoute.ID, err)
continue
}
m.serverRoutes[id] = newRoute
}
if len(m.serverRoutes) > 0 {
err := enableIPForwarding()
if err != nil {
return err
}
}
return nil
}
// UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
select {
case <-m.ctx.Done():
log.Infof("not updating routes as context is closed")
return m.ctx.Err()
default:
m.mux.Lock()
defer m.mux.Unlock()
newClientRoutesIDMap := make(map[string][]*route.Route)
newServerRoutesMap := make(map[string]*route.Route)
ownNetworkIDs := make(map[string]bool)
for _, newRoute := range newRoutes {
networkID := route.GetHAUniqueID(newRoute)
if newRoute.Peer == m.pubKey {
ownNetworkIDs[networkID] = true
// only linux is supported for now
if runtime.GOOS != "linux" {
log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS)
continue
}
newServerRoutesMap[newRoute.ID] = newRoute
}
}
for _, newRoute := range newRoutes {
networkID := route.GetHAUniqueID(newRoute)
if !ownNetworkIDs[networkID] {
// if prefix is too small, lets assume is a possible default route which is not yet supported
// we skip this route management
if newRoute.Network.Bits() < 7 {
log.Errorf("this agent version: %s, doesn't support default routes, received %s, skiping this route",
version.NetbirdVersion(), newRoute.Network)
continue
}
newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute)
}
}
m.updateClientNetworks(updateSerial, newClientRoutesIDMap)
err := m.updateServerRoutes(newServerRoutesMap)
if err != nil {
return err
}
return nil
}
}

View File

@ -392,11 +392,12 @@ func TestManagerUpdateRoutes(t *testing.T) {
for n, testCase := range testCases { for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
newNet, err := stdnet.NewNet() newNet, err := stdnet.NewNet()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", iface.DefaultMTU, nil, newNet) wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", iface.DefaultMTU, nil, nil, newNet)
require.NoError(t, err, "should create testing WGIface interface") require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close() defer wgInterface.Close()
@ -419,7 +420,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
require.Len(t, routeManager.clientNetworks, testCase.clientNetworkWatchersExpected, "client networks size should match") require.Len(t, routeManager.clientNetworks, testCase.clientNetworkWatchersExpected, "client networks size should match")
if testCase.shouldCheckServerRoutes { if testCase.shouldCheckServerRoutes {
require.Len(t, routeManager.serverRoutes, testCase.serverRoutesExpected, "server networks size should match") require.Len(t, routeManager.serverRouter.routes, testCase.serverRoutesExpected, "server networks size should match")
} }
}) })
} }

View File

@ -1,16 +1,19 @@
//go:build !android
package routemanager package routemanager
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
"net" "net"
"net/netip" "net/netip"
"sync" "sync"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
) )
import "github.com/google/nftables"
const ( const (
nftablesTable = "netbird-rt" nftablesTable = "netbird-rt"

View File

@ -1,12 +1,15 @@
//go:build !android
package routemanager package routemanager
import ( import (
"context" "context"
"testing"
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"testing"
) )
func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) { func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) {

View File

@ -0,0 +1,24 @@
package routemanager
import (
"net/netip"
"github.com/netbirdio/netbird/route"
)
type routerPair struct {
ID string
source string
destination string
masquerade bool
}
func routeToRouterPair(source string, route *route.Route) routerPair {
parsed := netip.MustParsePrefix(source).Masked()
return routerPair{
ID: route.ID,
source: parsed.String(),
destination: route.Network.Masked().String(),
masquerade: route.Masquerade,
}
}

View File

@ -1,67 +0,0 @@
package routemanager
import (
"github.com/netbirdio/netbird/route"
log "github.com/sirupsen/logrus"
"net/netip"
"sync"
)
type serverRouter struct {
routes map[string]*route.Route
// best effort to keep net forward configuration as it was
netForwardHistoryEnabled bool
mux sync.Mutex
firewall firewallManager
}
type routerPair struct {
ID string
source string
destination string
masquerade bool
}
func routeToRouterPair(source string, route *route.Route) routerPair {
parsed := netip.MustParsePrefix(source).Masked()
return routerPair{
ID: route.ID,
source: parsed.String(),
destination: route.Network.Masked().String(),
masquerade: route.Masquerade,
}
}
func (m *DefaultManager) removeFromServerNetwork(route *route.Route) error {
select {
case <-m.ctx.Done():
log.Infof("not removing from server network because context is done")
return m.ctx.Err()
default:
m.serverRouter.mux.Lock()
defer m.serverRouter.mux.Unlock()
err := m.serverRouter.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route))
if err != nil {
return err
}
delete(m.serverRouter.routes, route.ID)
return nil
}
}
func (m *DefaultManager) addToServerNetwork(route *route.Route) error {
select {
case <-m.ctx.Done():
log.Infof("not adding to server network because context is done")
return m.ctx.Err()
default:
m.serverRouter.mux.Lock()
defer m.serverRouter.mux.Unlock()
err := m.serverRouter.firewall.InsertRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route))
if err != nil {
return err
}
m.serverRouter.routes[route.ID] = route
return nil
}
}

View File

@ -0,0 +1,21 @@
package routemanager
import (
"context"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
)
type serverRouter struct {
}
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface) *serverRouter {
return &serverRouter{}
}
func (r *serverRouter) updateRoutes(routesMap map[string]*route.Route) error {
return nil
}
func (r *serverRouter) cleanUp() {}

View File

@ -0,0 +1,120 @@
//go:build !android
package routemanager
import (
"context"
"fmt"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
)
type serverRouter struct {
mux sync.Mutex
ctx context.Context
routes map[string]*route.Route
firewall firewallManager
wgInterface *iface.WGIface
}
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface) *serverRouter {
return &serverRouter{
ctx: ctx,
routes: make(map[string]*route.Route),
firewall: NewFirewall(ctx),
wgInterface: wgInterface,
}
}
func (m *serverRouter) updateRoutes(routesMap map[string]*route.Route) error {
serverRoutesToRemove := make([]string, 0)
if len(routesMap) > 0 {
err := m.firewall.RestoreOrCreateContainers()
if err != nil {
return fmt.Errorf("couldn't initialize firewall containers, got err: %v", err)
}
}
for routeID := range m.routes {
update, found := routesMap[routeID]
if !found || !update.IsEqual(m.routes[routeID]) {
serverRoutesToRemove = append(serverRoutesToRemove, routeID)
}
}
for _, routeID := range serverRoutesToRemove {
oldRoute := m.routes[routeID]
err := m.removeFromServerNetwork(oldRoute)
if err != nil {
log.Errorf("unable to remove route id: %s, network %s, from server, got: %v",
oldRoute.ID, oldRoute.Network, err)
}
delete(m.routes, routeID)
}
for id, newRoute := range routesMap {
_, found := m.routes[id]
if found {
continue
}
err := m.addToServerNetwork(newRoute)
if err != nil {
log.Errorf("unable to add route %s from server, got: %v", newRoute.ID, err)
continue
}
m.routes[id] = newRoute
}
if len(m.routes) > 0 {
err := enableIPForwarding()
if err != nil {
return err
}
}
return nil
}
func (m *serverRouter) removeFromServerNetwork(route *route.Route) error {
select {
case <-m.ctx.Done():
log.Infof("not removing from server network because context is done")
return m.ctx.Err()
default:
m.mux.Lock()
defer m.mux.Unlock()
err := m.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route))
if err != nil {
return err
}
delete(m.routes, route.ID)
return nil
}
}
func (m *serverRouter) addToServerNetwork(route *route.Route) error {
select {
case <-m.ctx.Done():
log.Infof("not adding to server network because context is done")
return m.ctx.Err()
default:
m.mux.Lock()
defer m.mux.Unlock()
err := m.firewall.InsertRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route))
if err != nil {
return err
}
m.routes[route.ID] = route
return nil
}
}
func (m *serverRouter) cleanUp() {
m.firewall.CleanRoutingRules()
}

View File

@ -0,0 +1,13 @@
package routemanager
import (
"net/netip"
)
func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error {
return nil
}
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error {
return nil
}

View File

@ -1,10 +1,13 @@
//go:build !android
package routemanager package routemanager
import ( import (
"github.com/vishvananda/netlink"
"net" "net"
"net/netip" "net/netip"
"os" "os"
"github.com/vishvananda/netlink"
) )
const ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward" const ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward"
@ -62,12 +65,3 @@ func enableIPForwarding() error {
err := os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644) err := os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644)
return err return err
} }
func isNetForwardHistoryEnabled() bool {
out, err := os.ReadFile(ipv4ForwardingPath)
if err != nil {
// todo
panic(err)
}
return string(out) == "1"
}

View File

@ -1,11 +1,14 @@
//go:build !android
package routemanager package routemanager
import ( import (
"fmt" "fmt"
"github.com/libp2p/go-netroute"
log "github.com/sirupsen/logrus"
"net" "net"
"net/netip" "net/netip"
"github.com/libp2p/go-netroute"
log "github.com/sirupsen/logrus"
) )
var errRouteNotFound = fmt.Errorf("route not found") var errRouteNotFound = fmt.Errorf("route not found")

View File

@ -37,7 +37,7 @@ func TestAddRemoveRoutes(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", iface.DefaultMTU, nil, newNet) wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", iface.DefaultMTU, nil, nil, newNet)
require.NoError(t, err, "should create testing WGIface interface") require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close() defer wgInterface.Close()

View File

@ -4,10 +4,11 @@
package routemanager package routemanager
import ( import (
log "github.com/sirupsen/logrus"
"net/netip" "net/netip"
"os/exec" "os/exec"
"runtime" "runtime"
log "github.com/sirupsen/logrus"
) )
func addToRouteTable(prefix netip.Prefix, addr string) error { func addToRouteTable(prefix netip.Prefix, addr string) error {
@ -34,8 +35,3 @@ func enableIPForwarding() error {
log.Infof("enable IP forwarding is not implemented on %s", runtime.GOOS) log.Infof("enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil return nil
} }
func isNetForwardHistoryEnabled() bool {
log.Infof("check netforward history is not implemented on %s", runtime.GOOS)
return false
}

View File

@ -7,7 +7,7 @@ import (
) )
// NewWGIFace Creates a new WireGuard interface instance // NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, mtu int, tunAdapter TunAdapter, transportNet transport.Net) (*WGIface, error) { func NewWGIFace(ifaceName string, address string, mtu int, routes []string, tunAdapter TunAdapter, transportNet transport.Net) (*WGIface, error) {
wgIFace := &WGIface{ wgIFace := &WGIface{
mu: sync.Mutex{}, mu: sync.Mutex{},
} }
@ -17,7 +17,7 @@ func NewWGIFace(iFaceName string, address string, mtu int, tunAdapter TunAdapter
return wgIFace, err return wgIFace, err
} }
tun := newTunDevice(wgAddress, mtu, tunAdapter, transportNet) tun := newTunDevice(wgAddress, mtu, routes, tunAdapter, transportNet)
wgIFace.tun = tun wgIFace.tun = tun
wgIFace.configurer = newWGConfigurer(tun) wgIFace.configurer = newWGConfigurer(tun)

View File

@ -9,7 +9,7 @@ import (
) )
// NewWGIFace Creates a new WireGuard interface instance // NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, mtu int, tunAdapter TunAdapter, transportNet transport.Net) (*WGIface, error) { func NewWGIFace(iFaceName string, address string, mtu int, routes []string, tunAdapter TunAdapter, transportNet transport.Net) (*WGIface, error) {
wgIFace := &WGIface{ wgIFace := &WGIface{
mu: sync.Mutex{}, mu: sync.Mutex{},
} }

View File

@ -39,7 +39,7 @@ func TestWGIface_UpdateAddr(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
iface, err := NewWGIFace(ifaceName, addr, DefaultMTU, nil, newNet) iface, err := NewWGIFace(ifaceName, addr, DefaultMTU, nil, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -103,7 +103,7 @@ func Test_CreateInterface(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, newNet) iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -136,7 +136,7 @@ func Test_Close(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, newNet) iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -168,7 +168,7 @@ func Test_ConfigureInterface(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, newNet) iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -219,7 +219,7 @@ func Test_UpdatePeer(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, newNet) iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -282,7 +282,7 @@ func Test_RemovePeer(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, newNet) iface, err := NewWGIFace(ifaceName, wgIP, DefaultMTU, nil, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -335,7 +335,7 @@ func Test_ConnectPeers(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, DefaultMTU, nil, newNet) iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, DefaultMTU, nil, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -356,7 +356,7 @@ func Test_ConnectPeers(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
iface2, err := NewWGIFace(peer2ifaceName, peer2wgIP, DefaultMTU, nil, newNet) iface2, err := NewWGIFace(peer2ifaceName, peer2wgIP, DefaultMTU, nil, nil, newNet)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -2,6 +2,6 @@ package iface
// TunAdapter is an interface for create tun device from externel service // TunAdapter is an interface for create tun device from externel service
type TunAdapter interface { type TunAdapter interface {
ConfigureInterface(address string, mtu int) (int, error) ConfigureInterface(address string, mtu int, routes string) (int, error)
UpdateAddr(address string) error UpdateAddr(address string) error
} }

View File

@ -2,6 +2,7 @@ package iface
import ( import (
"net" "net"
"strings"
"github.com/pion/transport/v2" "github.com/pion/transport/v2"
@ -17,6 +18,7 @@ import (
type tunDevice struct { type tunDevice struct {
address WGAddress address WGAddress
mtu int mtu int
routes []string
tunAdapter TunAdapter tunAdapter TunAdapter
fd int fd int
@ -26,10 +28,11 @@ type tunDevice struct {
iceBind *bind.ICEBind iceBind *bind.ICEBind
} }
func newTunDevice(address WGAddress, mtu int, tunAdapter TunAdapter, transportNet transport.Net) *tunDevice { func newTunDevice(address WGAddress, mtu int, routes []string, tunAdapter TunAdapter, transportNet transport.Net) *tunDevice {
return &tunDevice{ return &tunDevice{
address: address, address: address,
mtu: mtu, mtu: mtu,
routes: routes,
tunAdapter: tunAdapter, tunAdapter: tunAdapter,
iceBind: bind.NewICEBind(transportNet), iceBind: bind.NewICEBind(transportNet),
} }
@ -37,7 +40,8 @@ func newTunDevice(address WGAddress, mtu int, tunAdapter TunAdapter, transportNe
func (t *tunDevice) Create() error { func (t *tunDevice) Create() error {
var err error var err error
t.fd, err = t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu) routesString := t.routesToString()
t.fd, err = t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, routesString)
if err != nil { if err != nil {
log.Errorf("failed to create Android interface: %s", err) log.Errorf("failed to create Android interface: %s", err)
return err return err
@ -115,3 +119,7 @@ func (t *tunDevice) Close() (err error) {
return return
} }
func (t *tunDevice) routesToString() string {
return strings.Join(t.routes, ";")
}

View File

@ -172,6 +172,49 @@ func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error
return nil return nil
} }
// GetRoutes return with the routes
func (c *GrpcClient) GetRoutes() ([]*proto.Route, error) {
serverPubKey, err := c.GetServerPublicKey()
if err != nil {
log.Debugf("failed getting Management Service public key: %s", err)
return nil, err
}
ctx, cancelStream := context.WithCancel(c.ctx)
defer cancelStream()
stream, err := c.connectToStream(ctx, *serverPubKey)
if err != nil {
log.Debugf("failed to open Management Service stream: %s", err)
return nil, err
}
defer func() {
_ = stream.CloseSend()
}()
update, err := stream.Recv()
if err == io.EOF {
log.Debugf("Management stream has been closed by server: %s", err)
return nil, err
}
if err != nil {
log.Debugf("disconnected from Management Service sync stream: %v", err)
return nil, err
}
decryptedResp := &proto.SyncResponse{}
err = encryption.DecryptMessage(*serverPubKey, c.key, update.Body, decryptedResp)
if err != nil {
log.Errorf("failed decrypting update message from Management Service: %s", err)
return nil, err
}
if decryptedResp.GetNetworkMap() == nil {
return nil, fmt.Errorf("invalid msg, required network map")
}
return decryptedResp.GetNetworkMap().GetRoutes(), nil
}
func (c *GrpcClient) connectToStream(ctx context.Context, serverPubKey wgtypes.Key) (proto.ManagementService_SyncClient, error) { func (c *GrpcClient) connectToStream(ctx context.Context, serverPubKey wgtypes.Key) (proto.ManagementService_SyncClient, error) {
req := &proto.SyncRequest{} req := &proto.SyncRequest{}