- Revert typos in turnCfg string

- merge main
This commit is contained in:
Zoltan Papp 2024-07-08 15:05:29 +02:00
commit b3715b5fad
286 changed files with 11781 additions and 6153 deletions

View File

@ -14,7 +14,7 @@ jobs:
test:
strategy:
matrix:
store: ['jsonfile', 'sqlite']
store: ['sqlite']
runs-on: macos-latest
steps:
- name: Install Go

View File

@ -0,0 +1,39 @@
name: Test Code FreeBSD
on:
push:
branches:
- main
pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Test in FreeBSD
id: test
uses: vmactions/freebsd-vm@v1
with:
usesh: true
prepare: |
pkg install -y curl
pkg install -y git
run: |
set -x
curl -o go.tar.gz https://go.dev/dl/go1.21.11.freebsd-amd64.tar.gz -L
tar zxf go.tar.gz
mv go /usr/local/go
ln -s /usr/local/go/bin/go /usr/local/bin/go
go mod tidy
go test -timeout 5m -p 1 ./iface/...
go test -timeout 5m -p 1 ./client/...
cd client
go build .
cd ..

View File

@ -15,7 +15,7 @@ jobs:
strategy:
matrix:
arch: [ '386','amd64' ]
store: [ 'jsonfile', 'sqlite', 'postgres']
store: [ 'sqlite', 'postgres']
runs-on: ubuntu-latest
steps:
- name: Install Go
@ -86,7 +86,10 @@ jobs:
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
- name: Generate RouteManager Test bin
run: CGO_ENABLED=1 go test -c -o routemanager-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/...
run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager
- name: Generate SystemOps Test bin
run: CGO_ENABLED=1 go test -c -o systemops-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/systemops
- name: Generate nftables Manager Test bin
run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/...
@ -108,6 +111,9 @@ jobs:
- name: Run RouteManager tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1
- name: Run SystemOps tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager/systemops --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/systemops-testing.bin -test.timeout 5m -test.parallel 1
- name: Run nftables Manager tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/firewall --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/nftablesmanager-testing.bin -test.timeout 5m -test.parallel 1

View File

@ -173,7 +173,7 @@ jobs:
retention-days: 3
release_ui_darwin:
runs-on: macos-11
runs-on: macos-latest
steps:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV

View File

@ -178,34 +178,79 @@ jobs:
- name: Checkout code
uses: actions/checkout@v3
- name: run script
- name: run script with Zitadel PostgreSQL
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh
- name: test Caddy file gen
- name: test Caddy file gen postgres
run: test -f Caddyfile
- name: test docker-compose file gen
- name: test docker-compose file gen postgres
run: test -f docker-compose.yml
- name: test management.json file gen
- name: test management.json file gen postgres
run: test -f management.json
- name: test turnserver.conf file gen
- name: test turnserver.conf file gen postgres
run: |
set -x
test -f turnserver.conf
grep external-ip turnserver.conf
- name: test zitadel.env file gen
- name: test zitadel.env file gen postgres
run: test -f zitadel.env
- name: test dashboard.env file gen
- name: test dashboard.env file gen postgres
run: test -f dashboard.env
- name: test zdb.env file gen postgres
run: test -f zdb.env
- name: Postgres run cleanup
run: |
docker-compose down --volumes --rmi all
rm -rf docker-compose.yml Caddyfile zitadel.env dashboard.env machinekey/zitadel-admin-sa.token turnserver.conf management.json zdb.env
- name: run script with Zitadel CockroachDB
run: bash -x infrastructure_files/getting-started-with-zitadel.sh
env:
NETBIRD_DOMAIN: use-ip
ZITADEL_DATABASE: cockroach
- name: test Caddy file gen CockroachDB
run: test -f Caddyfile
- name: test docker-compose file gen CockroachDB
run: test -f docker-compose.yml
- name: test management.json file gen CockroachDB
run: test -f management.json
- name: test turnserver.conf file gen CockroachDB
run: |
set -x
test -f turnserver.conf
grep external-ip turnserver.conf
- name: test zitadel.env file gen CockroachDB
run: test -f zitadel.env
- name: test dashboard.env file gen CockroachDB
run: test -f dashboard.env
test-download-geolite2-script:
runs-on: ubuntu-latest
steps:
- name: Install jq
run: sudo apt-get update && sudo apt-get install -y unzip sqlite3
- name: Checkout code
uses: actions/checkout@v3
- name: test script
run: bash -x infrastructure_files/download-geolite2.sh
- name: test mmdb file exists
run: test -f GeoLite2-City.mmdb
- name: test geonames file exists
run: test -f geonames.db

View File

@ -3,8 +3,10 @@ builds:
- id: netbird-ui-darwin
dir: client/ui
binary: netbird-ui
env: [CGO_ENABLED=1]
env:
- CGO_ENABLED=1
- MACOSX_DEPLOYMENT_TARGET=11.0
- MACOS_DEPLOYMENT_TARGET=11.0
goos:
- darwin
goarch:

View File

@ -1,4 +1,4 @@
FROM alpine:3.18.5
FROM alpine:3.19
RUN apk add --no-cache ca-certificates iptables ip6tables
ENV NB_FOREGROUND_MODE=true
ENTRYPOINT [ "/usr/local/bin/netbird","up"]

View File

@ -36,6 +36,7 @@ const (
disableAutoConnectFlag = "disable-auto-connect"
serverSSHAllowedFlag = "allow-server-ssh"
extraIFaceBlackListFlag = "extra-iface-blacklist"
dnsRouteIntervalFlag = "dns-router-interval"
)
var (
@ -68,7 +69,9 @@ var (
autoConnectDisabled bool
extraIFaceBlackList []string
anonymizeFlag bool
rootCmd = &cobra.Command{
dnsRouteInterval time.Duration
rootCmd = &cobra.Command{
Use: "netbird",
Short: "",
Long: "",

View File

@ -2,6 +2,7 @@ package cmd
import (
"fmt"
"strings"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
@ -66,18 +67,60 @@ func routesList(cmd *cobra.Command, _ []string) error {
return nil
}
cmd.Println("Available Routes:")
for _, route := range resp.Routes {
selectedStatus := "Not Selected"
if route.GetSelected() {
selectedStatus = "Selected"
}
cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetNetwork(), selectedStatus)
}
printRoutes(cmd, resp)
return nil
}
func printRoutes(cmd *cobra.Command, resp *proto.ListRoutesResponse) {
cmd.Println("Available Routes:")
for _, route := range resp.Routes {
printRoute(cmd, route)
}
}
func printRoute(cmd *cobra.Command, route *proto.Route) {
selectedStatus := getSelectedStatus(route)
domains := route.GetDomains()
if len(domains) > 0 {
printDomainRoute(cmd, route, domains, selectedStatus)
} else {
printNetworkRoute(cmd, route, selectedStatus)
}
}
func getSelectedStatus(route *proto.Route) string {
if route.GetSelected() {
return "Selected"
}
return "Not Selected"
}
func printDomainRoute(cmd *cobra.Command, route *proto.Route, domains []string, selectedStatus string) {
cmd.Printf("\n - ID: %s\n Domains: %s\n Status: %s\n", route.GetID(), strings.Join(domains, ", "), selectedStatus)
resolvedIPs := route.GetResolvedIPs()
if len(resolvedIPs) > 0 {
printResolvedIPs(cmd, domains, resolvedIPs)
} else {
cmd.Printf(" Resolved IPs: -\n")
}
}
func printNetworkRoute(cmd *cobra.Command, route *proto.Route, selectedStatus string) {
cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetNetwork(), selectedStatus)
}
func printResolvedIPs(cmd *cobra.Command, domains []string, resolvedIPs map[string]*proto.IPList) {
cmd.Printf(" Resolved IPs:\n")
for _, domain := range domains {
if ipList, exists := resolvedIPs[domain]; exists {
cmd.Printf(" [%s]: %s\n", domain, strings.Join(ipList.GetIps(), ", "))
}
}
}
func routesSelect(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd)
if err != nil {

View File

@ -807,11 +807,7 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
}
for i, route := range peer.Routes {
prefix, err := netip.ParsePrefix(route)
if err == nil {
ip := a.AnonymizeIPString(prefix.Addr().String())
peer.Routes[i] = fmt.Sprintf("%s/%d", ip, prefix.Bits())
}
peer.Routes[i] = anonymizeRoute(a, route)
}
}
@ -847,12 +843,21 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview)
}
for i, route := range overview.Routes {
prefix, err := netip.ParsePrefix(route)
if err == nil {
ip := a.AnonymizeIPString(prefix.Addr().String())
overview.Routes[i] = fmt.Sprintf("%s/%d", ip, prefix.Bits())
}
overview.Routes[i] = anonymizeRoute(a, route)
}
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
}
func anonymizeRoute(a *anonymize.Anonymizer, route string) string {
prefix, err := netip.ParsePrefix(route)
if err == nil {
ip := a.AnonymizeIPString(prefix.Addr().String())
return fmt.Sprintf("%s/%d", ip, prefix.Bits())
}
domains := strings.Split(route, ", ")
for i, domain := range domains {
domains[i] = a.AnonymizeDomain(domain)
}
return strings.Join(domains, ", ")
}

View File

@ -7,6 +7,9 @@ import (
"testing"
"time"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/util"
@ -53,7 +56,10 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
t.Fatal(err)
}
s := grpc.NewServer()
sigProto.RegisterSignalExchangeServer(s, sig.NewServer())
srv, err := sig.NewServer(otel.Meter(""))
require.NoError(t, err)
sigProto.RegisterSignalExchangeServer(s, srv)
go func() {
if err := s.Serve(lis); err != nil {
panic(err)
@ -70,7 +76,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
t.Fatal(err)
}
s := grpc.NewServer()
store, cleanUp, err := mgmt.NewTestStoreFromJson(config.Datadir)
store, cleanUp, err := mgmt.NewTestStoreFromJson(context.Background(), config.Datadir)
if err != nil {
t.Fatal(err)
}
@ -81,13 +87,13 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
if err != nil {
return nil, nil
}
iv, _ := integrations.NewIntegratedValidator(eventStore)
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv)
iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv)
if err != nil {
t.Fatal(err)
}
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, "")
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil {
t.Fatal(err)
}
@ -102,7 +108,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
}
func startClientDaemon(
t *testing.T, ctx context.Context, managementURL, configPath string,
t *testing.T, ctx context.Context, _, configPath string,
) (*grpc.Server, net.Listener) {
t.Helper()
lis, err := net.Listen("tcp", "127.0.0.1:0")

View File

@ -7,11 +7,13 @@ import (
"net/netip"
"runtime"
"strings"
"time"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/durationpb"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
@ -40,8 +42,12 @@ func init() {
upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground")
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", false, "Enable network monitoring")
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor,
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux. `+
`E.g. --network-monitor=false to disable or --network-monitor=true to enable.`,
)
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
}
func upFunc(cmd *cobra.Command, args []string) error {
@ -137,6 +143,10 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
}
}
if cmd.Flag(dnsRouteIntervalFlag).Changed {
ic.DNSRouteInterval = &dnsRouteInterval
}
config, err := internal.UpdateOrCreateConfig(ic)
if err != nil {
return fmt.Errorf("get config file: %v", err)
@ -237,6 +247,10 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
loginRequest.NetworkMonitor = &networkMonitor
}
if cmd.Flag(dnsRouteIntervalFlag).Changed {
loginRequest.DnsRouteInterval = durationpb.New(dnsRouteInterval)
}
var loginErr error
var loginResp *proto.LoginResponse

30
client/errors/errors.go Normal file
View File

@ -0,0 +1,30 @@
package errors
import (
"fmt"
"strings"
"github.com/hashicorp/go-multierror"
)
func formatError(es []error) string {
if len(es) == 0 {
return fmt.Sprintf("0 error occurred:\n\t* %s", es[0])
}
points := make([]string, len(es))
for i, err := range es {
points[i] = fmt.Sprintf("* %s", err)
}
return fmt.Sprintf(
"%d errors occurred:\n\t%s",
len(es), strings.Join(points, "\n\t"))
}
func FormatErrorOrNil(err *multierror.Error) error {
if err != nil {
err.ErrorFormat = formatError
}
return err.ErrorOrNil()
}

View File

@ -74,12 +74,12 @@ func (i *routerManager) InsertRoutingRules(pair firewall.RouterPair) error {
return nil
}
err = i.insertRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair)
err = i.addNATRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair)
if err != nil {
return err
}
err = i.insertRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair))
err = i.addNATRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair))
if err != nil {
return err
}
@ -101,6 +101,7 @@ func (i *routerManager) insertRoutingRule(keyFormat, table, chain, jump string,
}
delete(i.rules, ruleKey)
}
err = i.iptablesClient.Insert(table, chain, 1, rule...)
if err != nil {
return fmt.Errorf("error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
@ -317,6 +318,13 @@ func (i *routerManager) createChain(table, newChain string) error {
return fmt.Errorf("couldn't create chain %s in %s table, error: %v", newChain, table, err)
}
// Add the loopback return rule to the NAT chain
loopbackRule := []string{"-o", "lo", "-j", "RETURN"}
err = i.iptablesClient.Insert(table, newChain, 1, loopbackRule...)
if err != nil {
return fmt.Errorf("failed to add loopback return rule to %s: %v", chainRTNAT, err)
}
err = i.iptablesClient.Append(table, newChain, "-j", "RETURN")
if err != nil {
return fmt.Errorf("couldn't create chain %s default rule, error: %v", newChain, err)
@ -326,6 +334,30 @@ func (i *routerManager) createChain(table, newChain string) error {
return nil
}
// addNATRule appends an iptables rule pair to the nat chain
func (i *routerManager) addNATRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(keyFormat, pair.ID)
rule := genRuleSpec(jump, pair.Source, pair.Destination)
existingRule, found := i.rules[ruleKey]
if found {
err := i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
if err != nil {
return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err)
}
delete(i.rules, ruleKey)
}
// inserting after loopback ignore rule
err := i.iptablesClient.Insert(table, chain, 2, rule...)
if err != nil {
return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err)
}
i.rules[ruleKey] = rule
return nil
}
// genRuleSpec generates rule specification
func genRuleSpec(jump, source, destination string) []string {
return []string{"-s", source, "-d", destination, "-j", jump}

View File

@ -95,7 +95,7 @@ func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.InsertRoutingRules(pair)
return m.router.AddRoutingRules(pair)
}
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {

View File

@ -22,6 +22,8 @@ const (
userDataAcceptForwardRuleSrc = "frwacceptsrc"
userDataAcceptForwardRuleDst = "frwacceptdst"
loopbackInterface = "lo\x00"
)
// some presets for building nftable rules
@ -126,6 +128,22 @@ func (r *router) createContainers() error {
Type: nftables.ChainTypeNAT,
})
// Add RETURN rule for loopback interface
loRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte(loopbackInterface),
},
&expr.Verdict{Kind: expr.VerdictReturn},
},
}
r.conn.InsertRule(loRule)
err := r.refreshRulesMap()
if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
@ -138,28 +156,28 @@ func (r *router) createContainers() error {
return nil
}
// InsertRoutingRules inserts a nftable rule pair to the forwarding chain and if enabled, to the nat chain
func (r *router) InsertRoutingRules(pair manager.RouterPair) error {
// AddRoutingRules appends a nftable rule pair to the forwarding chain and if enabled, to the nat chain
func (r *router) AddRoutingRules(pair manager.RouterPair) error {
err := r.refreshRulesMap()
if err != nil {
return err
}
err = r.insertRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false)
err = r.addRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false)
if err != nil {
return err
}
err = r.insertRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false)
err = r.addRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false)
if err != nil {
return err
}
if pair.Masquerade {
err = r.insertRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true)
err = r.addRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true)
if err != nil {
return err
}
err = r.insertRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true)
err = r.addRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true)
if err != nil {
return err
}
@ -177,8 +195,8 @@ func (r *router) InsertRoutingRules(pair manager.RouterPair) error {
return nil
}
// insertRoutingRule inserts a nftable rule to the conn client flush queue
func (r *router) insertRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error {
// addRoutingRule inserts a nftable rule to the conn client flush queue
func (r *router) addRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
@ -199,7 +217,7 @@ func (r *router) insertRoutingRule(format, chainName string, pair manager.Router
}
}
r.rules[ruleKey] = r.conn.InsertRule(&nftables.Rule{
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainName],
Exprs: expression,

View File

@ -47,7 +47,7 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
require.NoError(t, err, "shouldn't return error")
err = manager.InsertRoutingRules(testCase.InputPair)
err = manager.AddRoutingRules(testCase.InputPair)
defer func() {
_ = manager.RemoveRoutingRules(testCase.InputPair)
}()

View File

@ -6,13 +6,16 @@ import (
"net/url"
"os"
"reflect"
"runtime"
"strings"
"time"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/iface"
mgm "github.com/netbirdio/netbird/management/client"
@ -53,6 +56,7 @@ type ConfigInput struct {
NetworkMonitor *bool
DisableAutoConnect *bool
ExtraIFaceBlackList []string
DNSRouteInterval *time.Duration
}
// Config Configuration type
@ -64,7 +68,7 @@ type Config struct {
AdminURL *url.URL
WgIface string
WgPort int
NetworkMonitor bool
NetworkMonitor *bool
IFaceBlackList []string
DisableIPv6Discovery bool
RosenpassEnabled bool
@ -95,6 +99,9 @@ type Config struct {
// DisableAutoConnect determines whether the client should not start with the service
// it's set to false by default due to backwards compatibility
DisableAutoConnect bool
// DNSRouteInterval is the interval in which the DNS routes are updated
DNSRouteInterval time.Duration
}
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
@ -304,12 +311,21 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
if input.NetworkMonitor != nil && *input.NetworkMonitor != config.NetworkMonitor {
if input.NetworkMonitor != nil && input.NetworkMonitor != config.NetworkMonitor {
log.Infof("switching Network Monitor to %t", *input.NetworkMonitor)
config.NetworkMonitor = *input.NetworkMonitor
config.NetworkMonitor = input.NetworkMonitor
updated = true
}
if config.NetworkMonitor == nil {
// enable network monitoring by default on windows and darwin clients
if runtime.GOOS == "windows" || runtime.GOOS == "darwin" {
enabled := true
config.NetworkMonitor = &enabled
updated = true
}
}
if input.CustomDNSAddress != nil && string(input.CustomDNSAddress) != config.CustomDNSAddress {
log.Infof("updating custom DNS address %#v (old value %#v)",
string(input.CustomDNSAddress), config.CustomDNSAddress)
@ -357,6 +373,18 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
if input.DNSRouteInterval != nil && *input.DNSRouteInterval != config.DNSRouteInterval {
log.Infof("updating DNS route interval to %s (old value %s)",
input.DNSRouteInterval.String(), config.DNSRouteInterval.String())
config.DNSRouteInterval = *input.DNSRouteInterval
updated = true
} else if config.DNSRouteInterval == 0 {
config.DNSRouteInterval = dynamic.DefaultInterval
log.Infof("using default DNS route interval %s", config.DNSRouteInterval)
updated = true
}
return updated, nil
}

View File

@ -264,8 +264,10 @@ func (c *ConnectClient) run(
return wrapErr(err)
}
checks := loginResp.GetChecks()
c.engineMutex.Lock()
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe)
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe, checks)
c.engineMutex.Unlock()
err = c.engine.Start()
@ -342,6 +344,10 @@ func (c *ConnectClient) Engine() *Engine {
// createEngineConfig converts configuration received from Management Service to EngineConfig
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
nm := false
if config.NetworkMonitor != nil {
nm = *config.NetworkMonitor
}
engineConf := &EngineConfig{
WgIfaceName: config.WgIface,
WgAddr: peerConfig.Address,
@ -349,13 +355,14 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
DisableIPv6Discovery: config.DisableIPv6Discovery,
WgPrivateKey: key,
WgPort: config.WgPort,
NetworkMonitor: config.NetworkMonitor,
NetworkMonitor: nm,
SSHKey: []byte(config.SSHKey),
NATExternalIPs: config.NATExternalIPs,
CustomDNSAddress: config.CustomDNSAddress,
RosenpassEnabled: config.RosenpassEnabled,
RosenpassPermissive: config.RosenpassPermissive,
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
DNSRouteInterval: config.DNSRouteInterval,
}
if config.PreSharedKey != "" {

View File

@ -0,0 +1,6 @@
package dns
const (
fileUncleanShutdownResolvConfLocation = "/var/db/netbird/resolv.conf"
fileUncleanShutdownManagerTypeLocation = "/var/db/netbird/manager"
)

View File

@ -0,0 +1,8 @@
//go:build !android
package dns
const (
fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf"
fileUncleanShutdownManagerTypeLocation = "/var/lib/netbird/manager"
)

View File

@ -1,4 +1,4 @@
//go:build !android
//go:build (linux && !android) || freebsd
package dns

View File

@ -1,4 +1,4 @@
//go:build !android
//go:build (linux && !android) || freebsd
package dns

View File

@ -1,4 +1,4 @@
//go:build !android
//go:build (linux && !android) || freebsd
package dns

View File

@ -1,4 +1,4 @@
//go:build !android
//go:build (linux && !android) || freebsd
package dns

View File

@ -1,4 +1,4 @@
//go:build !android
//go:build (linux && !android) || freebsd
package dns

View File

@ -1,4 +1,4 @@
//go:build !android
//go:build (linux && !android) || freebsd
package dns

View File

@ -1,4 +1,4 @@
//go:build !android
//go:build (linux && !android) || freebsd
package dns

View File

@ -1,4 +1,4 @@
//go:build !android
//go:build (linux && !android) || freebsd
package dns
@ -108,7 +108,7 @@ func getOSDNSManagerType() (osManagerType, error) {
if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
return networkManager, nil
}
if strings.Contains(text, "systemd-resolved") && isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) {
if strings.Contains(text, "systemd-resolved") && isSystemdResolvedRunning() {
if checkStub() {
return systemdManager, nil
} else {
@ -116,16 +116,10 @@ func getOSDNSManagerType() (osManagerType, error) {
}
}
if strings.Contains(text, "resolvconf") {
if isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) {
var value string
err = getSystemdDbusProperty(systemdDbusResolvConfModeProperty, &value)
if err == nil {
if value == systemdDbusResolvConfModeForeign {
return systemdManager, nil
}
}
log.Errorf("got an error while checking systemd resolv conf mode, error: %s", err)
if isSystemdResolveConfMode() {
return systemdManager, nil
}
return resolvConfManager, nil
}
}

View File

@ -1,4 +1,4 @@
//go:build !android
//go:build (linux && !android) || freebsd
package dns

View File

@ -1,4 +1,4 @@
//go:build !android
//go:build (linux && !android) || freebsd
package dns

View File

@ -39,6 +39,10 @@ func (w *mocWGIface) Address() iface.WGAddress {
}
}
func (w *mocWGIface) ToInterface() *net.Interface {
panic("implement me")
}
func (w *mocWGIface) GetFilter() iface.PacketFilter {
return w.filter
}
@ -261,7 +265,7 @@ func TestUpdateDNSServer(t *testing.T) {
if err != nil {
t.Fatal(err)
}
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), 33100, privKey.String(), iface.DefaultMTU, newNet, nil)
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil)
if err != nil {
t.Fatal(err)
}
@ -339,7 +343,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
}
privKey, _ := wgtypes.GeneratePrivateKey()
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", 33100, privKey.String(), iface.DefaultMTU, newNet, nil)
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil)
if err != nil {
t.Errorf("build interface wireguard: %v", err)
return
@ -797,7 +801,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
}
privKey, _ := wgtypes.GeneratePrivateKey()
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", 33100, privKey.String(), iface.DefaultMTU, newNet, nil)
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil)
if err != nil {
t.Fatalf("build interface wireguard: %v", err)
return nil, err

View File

@ -1,4 +1,4 @@
//go:build !android
//go:build (linux && !android) || freebsd
package dns

View File

@ -0,0 +1,20 @@
package dns
import (
"errors"
"fmt"
)
var errNotImplemented = errors.New("not implemented")
func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) {
return nil, fmt.Errorf("systemd dns management: %w on freebsd", errNotImplemented)
}
func isSystemdResolvedRunning() bool {
return false
}
func isSystemdResolveConfMode() bool {
return false
}

View File

@ -242,3 +242,25 @@ func getSystemdDbusProperty(property string, store any) error {
return v.Store(store)
}
func isSystemdResolvedRunning() bool {
return isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode)
}
func isSystemdResolveConfMode() bool {
if !isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) {
return false
}
var value string
if err := getSystemdDbusProperty(systemdDbusResolvConfModeProperty, &value); err != nil {
log.Errorf("got an error while checking systemd resolv conf mode, error: %s", err)
return false
}
if value == systemdDbusResolvConfModeForeign {
return true
}
return false
}

View File

@ -1,4 +1,4 @@
//go:build !android
//go:build (linux && !android) || freebsd
package dns
@ -14,11 +14,6 @@ import (
log "github.com/sirupsen/logrus"
)
const (
fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf"
fileUncleanShutdownManagerTypeLocation = "/var/lib/netbird/manager"
)
func CheckUncleanShutdown(wgIface string) error {
if _, err := os.Stat(fileUncleanShutdownResolvConfLocation); err != nil {
if errors.Is(err, fs.ErrNotExist) {

View File

@ -78,6 +78,11 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}()
log.WithField("question", r.Question[0]).Trace("received an upstream question")
// set the AuthenticatedData flag and the EDNS0 buffer size to 4096 bytes to support larger dns records
if r.Extra == nil {
r.SetEdns0(4096, false)
r.MsgHdr.AuthenticatedData = true
}
select {
case <-u.ctx.Done():

View File

@ -2,12 +2,17 @@
package dns
import "github.com/netbirdio/netbird/iface"
import (
"net"
"github.com/netbirdio/netbird/iface"
)
// WGIface defines subset methods of interface required for manager
type WGIface interface {
Name() string
Address() iface.WGAddress
ToInterface() *net.Interface
IsUserspaceBind() bool
GetFilter() iface.PacketFilter
GetDevice() *iface.DeviceWrapper

View File

@ -10,6 +10,7 @@ import (
"net/netip"
"reflect"
"runtime"
"slices"
"strings"
"sync"
"sync/atomic"
@ -30,12 +31,15 @@ import (
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/internal/rosenpass"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/wgproxy"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/iface/bind"
mgm "github.com/netbirdio/netbird/management/client"
"github.com/netbirdio/netbird/management/domain"
mgmProto "github.com/netbirdio/netbird/management/proto"
auth "github.com/netbirdio/netbird/relay/auth/hmac"
relayClient "github.com/netbirdio/netbird/relay/client"
@ -43,6 +47,7 @@ import (
signal "github.com/netbirdio/netbird/signal/client"
sProto "github.com/netbirdio/netbird/signal/proto"
"github.com/netbirdio/netbird/util"
nbnet "github.com/netbirdio/netbird/util/net"
)
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
@ -93,6 +98,8 @@ type EngineConfig struct {
RosenpassPermissive bool
ServerSSHAllowed bool
DNSRouteInterval time.Duration
}
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
@ -105,8 +112,8 @@ type Engine struct {
// peerConns is a map that holds all the peers that are known to this peer
peerConns map[string]*peer.Conn
beforePeerHook peer.BeforeAddPeerHookFunc
afterPeerHook peer.AfterRemovePeerHookFunc
beforePeerHook nbnet.AddHookFunc
afterPeerHook nbnet.RemoveHookFunc
// rpManager is a Rosenpass manager
rpManager *rosenpass.Manager
@ -159,6 +166,9 @@ type Engine struct {
relayProbe *Probe
wgProbe *Probe
// checks are the client-applied posture checks that need to be evaluated on the client
checks []*mgmProto.Checks
relayManager *relayClient.Manager
}
@ -178,6 +188,7 @@ func NewEngine(
config *EngineConfig,
mobileDep MobileDependency,
statusRecorder *peer.Status,
checks []*mgmProto.Checks,
) *Engine {
return NewEngineWithProbes(
clientCtx,
@ -192,6 +203,7 @@ func NewEngine(
nil,
nil,
nil,
checks,
)
}
@ -209,6 +221,7 @@ func NewEngineWithProbes(
signalProbe *Probe,
relayProbe *Probe,
wgProbe *Probe,
checks []*mgmProto.Checks,
) *Engine {
return &Engine{
clientCtx: clientCtx,
@ -230,6 +243,7 @@ func NewEngineWithProbes(
signalProbe: signalProbe,
relayProbe: relayProbe,
wgProbe: wgProbe,
checks: checks,
}
}
@ -277,8 +291,6 @@ func (e *Engine) Start() error {
}
e.ctx, e.cancel = context.WithCancel(e.clientCtx)
e.wgProxyFactory = wgproxy.NewFactory(e.ctx, e.config.WgPort)
wgIface, err := e.newWgIface()
if err != nil {
log.Errorf("failed creating wireguard interface instance %s: [%s]", e.config.WgIfaceName, err)
@ -286,6 +298,9 @@ func (e *Engine) Start() error {
}
e.wgInterface = wgIface
userspace := e.wgInterface.IsUserspaceBind()
e.wgProxyFactory = wgproxy.NewFactory(e.ctx, userspace, e.config.WgPort)
if e.config.RosenpassEnabled {
log.Infof("rosenpass is enabled")
if e.config.RosenpassPermissive {
@ -310,7 +325,7 @@ func (e *Engine) Start() error {
}
e.dnsServer = dnsServer
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, initialRoutes)
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, initialRoutes)
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
if err != nil {
log.Errorf("Failed to initialize route manager: %s", err)
@ -498,6 +513,10 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
// todo update signal
}
if err := e.updateChecksIfNew(update.Checks); err != nil {
return err
}
if update.GetNetworkMap() != nil {
// only apply new changes and ignore old ones
err := e.updateNetworkMap(update.GetNetworkMap())
@ -505,7 +524,27 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return err
}
}
return nil
}
// updateChecksIfNew updates checks if there are changes and sync new meta with management
func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
// if checks are equal, we skip the update
if isChecksEqual(e.checks, checks) {
return nil
}
e.checks = checks
info, err := system.GetInfoWithChecks(e.ctx, checks)
if err != nil {
log.Warnf("failed to get system info with checks: %v", err)
info = system.GetInfo(e.ctx)
}
if err := e.mgmClient.SyncMeta(info); err != nil {
log.Errorf("could not sync meta: error %s", err)
return err
}
return nil
}
@ -521,8 +560,8 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
} else {
if sshConf.GetSshEnabled() {
if runtime.GOOS == "windows" {
log.Warnf("running SSH server on Windows is not supported")
if runtime.GOOS == "windows" || runtime.GOOS == "freebsd" {
log.Warnf("running SSH server on %s is not supported", runtime.GOOS)
return nil
}
// start SSH server if it wasn't running
@ -595,7 +634,14 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
// E.g. when a new peer has been registered and we are allowed to connect to it.
func (e *Engine) receiveManagementEvents() {
go func() {
err := e.mgmClient.Sync(e.ctx, e.handleSync)
info, err := system.GetInfoWithChecks(e.ctx, e.checks)
if err != nil {
log.Warnf("failed to get system info with checks: %v", err)
info = system.GetInfo(e.ctx)
}
// err = e.mgmClient.Sync(info, e.handleSync)
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
if err != nil {
// happens if management is unavailable for a long time.
// We want to cancel the operation of the whole client
@ -662,6 +708,20 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
return nil
}
protoRoutes := networkMap.GetRoutes()
if protoRoutes == nil {
protoRoutes = []*mgmProto.Route{}
}
_, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes))
if err != nil {
log.Errorf("failed to update clientRoutes, err: %v", err)
}
e.clientRoutesMu.Lock()
e.clientRoutes = clientRoutes
e.clientRoutesMu.Unlock()
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
e.updateOfflinePeers(networkMap.GetOfflinePeers())
@ -703,19 +763,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
}
}
}
protoRoutes := networkMap.GetRoutes()
if protoRoutes == nil {
protoRoutes = []*mgmProto.Route{}
}
_, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes))
if err != nil {
log.Errorf("failed to update clientRoutes, err: %v", err)
}
e.clientRoutesMu.Lock()
e.clientRoutes = clientRoutes
e.clientRoutesMu.Unlock()
protoDNSConfig := networkMap.GetDNSConfig()
if protoDNSConfig == nil {
@ -743,15 +790,24 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
routes := make([]*route.Route, 0)
for _, protoRoute := range protoRoutes {
_, prefix, _ := route.ParseNetwork(protoRoute.Network)
var prefix netip.Prefix
if len(protoRoute.Domains) == 0 {
var err error
if prefix, err = netip.ParsePrefix(protoRoute.Network); err != nil {
log.Errorf("Failed to parse prefix %s: %v", protoRoute.Network, err)
continue
}
}
convertedRoute := &route.Route{
ID: route.ID(protoRoute.ID),
Network: prefix,
Domains: domain.FromPunycodeList(protoRoute.Domains),
NetID: route.NetID(protoRoute.NetID),
NetworkType: route.NetworkType(protoRoute.NetworkType),
Peer: protoRoute.Peer,
Metric: int(protoRoute.Metric),
Masquerade: protoRoute.Masquerade,
KeepRoute: protoRoute.KeepRoute,
}
routes = append(routes, convertedRoute)
}
@ -1105,7 +1161,8 @@ func (e *Engine) close() {
}
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
netMap, err := e.mgmClient.GetNetworkMap()
info := system.GetInfo(e.ctx)
netMap, err := e.mgmClient.GetNetworkMap(info)
if err != nil {
return nil, nil, err
}
@ -1134,7 +1191,7 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) {
default:
}
return iface.NewWGIFace(e.config.WgIfaceName, e.config.WgAddr, e.config.WgPort, e.config.WgPrivateKey.String(), iface.DefaultMTU, transportNet, mArgs)
return iface.NewWGIFace(e.config.WgIfaceName, e.config.WgAddr, e.config.WgPort, e.config.WgPrivateKey.String(), iface.DefaultMTU, transportNet, mArgs, e.addrViaRoutes)
}
func (e *Engine) wgInterfaceCreate() (err error) {
@ -1309,6 +1366,15 @@ func (e *Engine) probeTURNs() []relay.ProbeResult {
return relay.ProbeAll(e.ctx, relay.ProbeTURN, e.TURNs)
}
func (e *Engine) restartEngine() {
if err := e.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
if err := e.Start(); err != nil {
log.Errorf("Failed to start engine: %v", err)
}
}
func (e *Engine) startNetworkMonitor() {
if !e.config.NetworkMonitor {
log.Infof("Network monitor is disabled, not starting")
@ -1317,17 +1383,54 @@ func (e *Engine) startNetworkMonitor() {
e.networkMonitor = networkmonitor.New()
go func() {
var mu sync.Mutex
var debounceTimer *time.Timer
// Start the network monitor with a callback, Start will block until the monitor is stopped,
// a network change is detected, or an error occurs on start up
err := e.networkMonitor.Start(e.ctx, func() {
log.Infof("Network monitor detected network change, restarting engine")
if err := e.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
if err := e.Start(); err != nil {
log.Errorf("Failed to start engine: %v", err)
// This function is called when a network change is detected
mu.Lock()
defer mu.Unlock()
if debounceTimer != nil {
debounceTimer.Stop()
}
// Set a new timer to debounce rapid network changes
debounceTimer = time.AfterFunc(1*time.Second, func() {
// This function is called after the debounce period
mu.Lock()
defer mu.Unlock()
log.Infof("Network monitor detected network change, restarting engine")
e.restartEngine()
})
})
if err != nil && !errors.Is(err, networkmonitor.ErrStopped) {
log.Errorf("Network monitor: %v", err)
}
}()
}
func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
var vpnRoutes []netip.Prefix
for _, routes := range e.GetClientRoutes() {
if len(routes) > 0 && routes[0] != nil {
vpnRoutes = append(vpnRoutes, routes[0].Network)
}
}
if isVpn, prefix := systemops.IsAddrRouted(addr, vpnRoutes); isVpn {
return true, prefix, nil
}
return false, netip.Prefix{}, nil
}
// isChecksEqual checks if two slices of checks are equal.
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {
return slices.Equal(checks.Files, oChecks.Files)
})
}

View File

@ -17,6 +17,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
@ -58,9 +59,9 @@ var (
)
func TestEngine_SSH(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("skipping TestEngine_SSH on Windows")
// todo resolve test execution on freebsd
if runtime.GOOS == "windows" || runtime.GOOS == "freebsd" {
t.Skip("skipping TestEngine_SSH")
}
key, err := wgtypes.GeneratePrivateKey()
@ -80,7 +81,7 @@ func TestEngine_SSH(t *testing.T) {
WgPort: 33100,
ServerSSHAllowed: true,
},
MobileDependency{}, peer.NewRecorder("https://mgm"))
MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
@ -176,7 +177,7 @@ func TestEngine_SSH(t *testing.T) {
t.Fatal(err)
}
//time.Sleep(250 * time.Millisecond)
// time.Sleep(250 * time.Millisecond)
assert.NotNil(t, engine.sshServer)
assert.Contains(t, sshPeersRemoved, "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=")
@ -215,16 +216,16 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil)
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil, nil)
if err != nil {
t.Fatal(err)
}
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder, nil)
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, nil)
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
}
@ -397,7 +398,7 @@ func TestEngine_Sync(t *testing.T) {
// feed updates to Engine via mocked Management client
updates := make(chan *mgmtProto.SyncResponse)
defer close(updates)
syncFunc := func(ctx context.Context, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
syncFunc := func(ctx context.Context, info *system.Info, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
for msg := range updates {
err := msgHandler(msg)
if err != nil {
@ -412,7 +413,7 @@ func TestEngine_Sync(t *testing.T) {
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.ctx = ctx
engine.dnsServer = &dns.MockServer{
@ -572,13 +573,13 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
WgAddr: wgAddr,
WgPrivateKey: key,
WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.ctx = ctx
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil)
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil, nil)
assert.NoError(t, err, "shouldn't return error")
input := struct {
inputSerial uint64
@ -743,14 +744,14 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
WgAddr: wgAddr,
WgPrivateKey: key,
WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.ctx = ctx
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, 33100, key.String(), iface.DefaultMTU, newNet, nil)
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, 33100, key.String(), iface.DefaultMTU, newNet, nil, nil)
assert.NoError(t, err, "shouldn't return error")
mockRouteManager := &routemanager.MockManager{
@ -816,13 +817,13 @@ func TestEngine_MultiplePeers(t *testing.T) {
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
defer cancel()
sigServer, signalAddr, err := startSignal()
sigServer, signalAddr, err := startSignal(t)
if err != nil {
t.Fatal(err)
return
}
defer sigServer.Stop()
mgmtServer, mgmtAddr, err := startManagement(dir)
mgmtServer, mgmtAddr, err := startManagement(t, dir)
if err != nil {
t.Fatal(err)
return
@ -1015,12 +1016,14 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
}
relayMgr := relayClient.NewManager(ctx, "", key.PublicKey().String())
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm")), nil
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil
e.ctx = ctx
return e, err
}
func startSignal() (*grpc.Server, string, error) {
func startSignal(t *testing.T) (*grpc.Server, string, error) {
t.Helper()
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
lis, err := net.Listen("tcp", "localhost:0")
@ -1028,7 +1031,9 @@ func startSignal() (*grpc.Server, string, error) {
log.Fatalf("failed to listen: %v", err)
}
proto.RegisterSignalExchangeServer(s, signalServer.NewServer())
srv, err := signalServer.NewServer(otel.Meter(""))
require.NoError(t, err)
proto.RegisterSignalExchangeServer(s, srv)
go func() {
if err = s.Serve(lis); err != nil {
@ -1039,7 +1044,9 @@ func startSignal() (*grpc.Server, string, error) {
return s, lis.Addr().String(), nil
}
func startManagement(dataDir string) (*grpc.Server, string, error) {
func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error) {
t.Helper()
config := &server.Config{
Stuns: []*server.Host{},
TURNConfig: &server.TURNConfig{},
@ -1056,23 +1063,25 @@ func startManagement(dataDir string) (*grpc.Server, string, error) {
return nil, "", err
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, _, err := server.NewTestStoreFromJson(config.Datadir)
store, cleanUp, err := server.NewTestStoreFromJson(context.Background(), config.Datadir)
if err != nil {
return nil, "", err
}
t.Cleanup(cleanUp)
peersUpdateManager := server.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}
if err != nil {
return nil, "", err
}
ia, _ := integrations.NewIntegratedValidator(eventStore)
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
if err != nil {
return nil, "", err
}
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, "")
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil {
return nil, "", err
}

View File

@ -5,8 +5,6 @@ package networkmonitor
import (
"context"
"fmt"
"net"
"net/netip"
"syscall"
"unsafe"
@ -14,10 +12,10 @@ import (
"golang.org/x/net/route"
"golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interface, nexthopv6 netip.Addr, intfv6 *net.Interface, callback func()) error {
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error {
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
if err != nil {
return fmt.Errorf("failed to open routing socket: %v", err)
@ -47,24 +45,6 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
switch msg.Type {
// handle interface state changes
case unix.RTM_IFINFO:
ifinfo, err := parseInterfaceMessage(buf[:n])
if err != nil {
log.Errorf("Network monitor: error parsing interface message: %v", err)
continue
}
if msg.Flags&unix.IFF_UP != 0 {
continue
}
if (intfv4 == nil || ifinfo.Index != intfv4.Index) && (intfv6 == nil || ifinfo.Index != intfv6.Index) {
continue
}
log.Infof("Network monitor: monitored interface (%s) is down.", ifinfo.Name)
go callback()
// handle route changes
case unix.RTM_ADD, syscall.RTM_DELETE:
route, err := parseRouteMessage(buf[:n])
@ -86,7 +66,7 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
go callback()
case unix.RTM_DELETE:
if intfv4 != nil && route.Gw.Compare(nexthopv4) == 0 || intfv6 != nil && route.Gw.Compare(nexthopv6) == 0 {
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
go callback()
}
@ -96,25 +76,7 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
}
}
func parseInterfaceMessage(buf []byte) (*route.InterfaceMessage, error) {
msgs, err := route.ParseRIB(route.RIBTypeInterface, buf)
if err != nil {
return nil, fmt.Errorf("parse RIB: %v", err)
}
if len(msgs) != 1 {
return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
}
msg, ok := msgs[0].(*route.InterfaceMessage)
if !ok {
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
}
return msg, nil
}
func parseRouteMessage(buf []byte) (*routemanager.Route, error) {
func parseRouteMessage(buf []byte) (*systemops.Route, error) {
msgs, err := route.ParseRIB(route.RIBTypeRoute, buf)
if err != nil {
return nil, fmt.Errorf("parse RIB: %v", err)
@ -129,5 +91,5 @@ func parseRouteMessage(buf []byte) (*routemanager.Route, error) {
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
}
return routemanager.MsgToRoute(msg)
return systemops.MsgToRoute(msg)
}

View File

@ -6,14 +6,13 @@ import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"runtime/debug"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
// Start begins monitoring network changes. When a change is detected, it calls the callback asynchronously and returns.
@ -29,23 +28,22 @@ func (nw *NetworkMonitor) Start(ctx context.Context, callback func()) (err error
nw.wg.Add(1)
defer nw.wg.Done()
var nexthop4, nexthop6 netip.Addr
var intf4, intf6 *net.Interface
var nexthop4, nexthop6 systemops.Nexthop
operation := func() error {
var errv4, errv6 error
nexthop4, intf4, errv4 = routemanager.GetNextHop(netip.IPv4Unspecified())
nexthop6, intf6, errv6 = routemanager.GetNextHop(netip.IPv6Unspecified())
nexthop4, errv4 = systemops.GetNextHop(netip.IPv4Unspecified())
nexthop6, errv6 = systemops.GetNextHop(netip.IPv6Unspecified())
if errv4 != nil && errv6 != nil {
return errors.New("failed to get default next hops")
}
if errv4 == nil {
log.Debugf("Network monitor: IPv4 default route: %s, interface: %s", nexthop4, intf4.Name)
log.Debugf("Network monitor: IPv4 default route: %s, interface: %s", nexthop4.IP, nexthop4.Intf.Name)
}
if errv6 == nil {
log.Debugf("Network monitor: IPv6 default route: %s, interface: %s", nexthop6, intf6.Name)
log.Debugf("Network monitor: IPv6 default route: %s, interface: %s", nexthop6.IP, nexthop6.Intf.Name)
}
// continue if either route was found
@ -65,7 +63,7 @@ func (nw *NetworkMonitor) Start(ctx context.Context, callback func()) (err error
}
}()
if err := checkChange(ctx, nexthop4, intf4, nexthop6, intf6, callback); err != nil {
if err := checkChange(ctx, nexthop4, nexthop6, callback); err != nil {
return fmt.Errorf("check change: %w", err)
}

View File

@ -6,27 +6,22 @@ import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"syscall"
log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interface, nexthop6 netip.Addr, intfv6 *net.Interface, callback func()) error {
if intfv4 == nil && intfv6 == nil {
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error {
if nexthopv4.Intf == nil && nexthopv6.Intf == nil {
return errors.New("no interfaces available")
}
linkChan := make(chan netlink.LinkUpdate)
done := make(chan struct{})
defer close(done)
if err := netlink.LinkSubscribe(linkChan, done); err != nil {
return fmt.Errorf("subscribe to link updates: %v", err)
}
routeChan := make(chan netlink.RouteUpdate)
if err := netlink.RouteSubscribe(routeChan, done); err != nil {
return fmt.Errorf("subscribe to route updates: %v", err)
@ -38,25 +33,6 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
case <-ctx.Done():
return ErrStopped
// handle interface state changes
case update := <-linkChan:
if (intfv4 == nil || update.Index != int32(intfv4.Index)) && (intfv6 == nil || update.Index != int32(intfv6.Index)) {
continue
}
switch update.Header.Type {
case syscall.RTM_DELLINK:
log.Infof("Network monitor: monitored interface (%s) is gone", update.Link.Attrs().Name)
go callback()
return nil
case syscall.RTM_NEWLINK:
if (update.IfInfomsg.Flags&syscall.IFF_RUNNING) == 0 && update.Link.Attrs().OperState == netlink.OperDown {
log.Infof("Network monitor: monitored interface (%s) is down.", update.Link.Attrs().Name)
go callback()
return nil
}
}
// handle route changes
case route := <-routeChan:
// default route and main table
@ -70,7 +46,7 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
go callback()
return nil
case syscall.RTM_DELROUTE:
if intfv4 != nil && route.Gw.Equal(nexthopv4.AsSlice()) || intfv6 != nil && route.Gw.Equal(nexthop6.AsSlice()) {
if nexthopv4.Intf != nil && route.Gw.Equal(nexthopv4.IP.AsSlice()) || nexthopv6.Intf != nil && route.Gw.Equal(nexthopv6.IP.AsSlice()) {
log.Infof("Network monitor: default route removed: via %s, interface %d", route.Gw, route.LinkIndex)
go callback()
return nil

View File

@ -5,11 +5,12 @@ import (
"fmt"
"net"
"net/netip"
"strings"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
const (
@ -25,20 +26,16 @@ const (
const interval = 10 * time.Second
func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interface, nexthopv6 netip.Addr, intfv6 *net.Interface, callback func()) error {
var neighborv4, neighborv6 *routemanager.Neighbor
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error {
var neighborv4, neighborv6 *systemops.Neighbor
{
initialNeighbors, err := getNeighbors()
if err != nil {
return fmt.Errorf("get neighbors: %w", err)
}
if n, ok := initialNeighbors[nexthopv4]; ok {
neighborv4 = &n
}
if n, ok := initialNeighbors[nexthopv6]; ok {
neighborv6 = &n
}
neighborv4 = assignNeighbor(nexthopv4, initialNeighbors)
neighborv6 = assignNeighbor(nexthopv6, initialNeighbors)
}
log.Debugf("Network monitor: initial IPv4 neighbor: %v, IPv6 neighbor: %v", neighborv4, neighborv6)
@ -50,7 +47,7 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
case <-ctx.Done():
return ErrStopped
case <-ticker.C:
if changed(nexthopv4, intfv4, neighborv4, nexthopv6, intfv6, neighborv6) {
if changed(nexthopv4, neighborv4, nexthopv6, neighborv6) {
go callback()
return nil
}
@ -58,13 +55,21 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
}
}
func assignNeighbor(nexthop systemops.Nexthop, initialNeighbors map[netip.Addr]systemops.Neighbor) *systemops.Neighbor {
if n, ok := initialNeighbors[nexthop.IP]; ok &&
n.State != unreachable &&
n.State != incomplete &&
n.State != tbd {
return &n
}
return nil
}
func changed(
nexthopv4 netip.Addr,
intfv4 *net.Interface,
neighborv4 *routemanager.Neighbor,
nexthopv6 netip.Addr,
intfv6 *net.Interface,
neighborv6 *routemanager.Neighbor,
nexthopv4 systemops.Nexthop,
neighborv4 *systemops.Neighbor,
nexthopv6 systemops.Nexthop,
neighborv6 *systemops.Neighbor,
) bool {
neighbors, err := getNeighbors()
if err != nil {
@ -81,7 +86,7 @@ func changed(
return false
}
if routeChanged(nexthopv4, intfv4, routes) || routeChanged(nexthopv6, intfv6, routes) {
if routeChanged(nexthopv4, nexthopv4.Intf, routes) || routeChanged(nexthopv6, nexthopv6.Intf, routes) {
return true
}
@ -89,44 +94,74 @@ func changed(
}
// routeChanged checks if the default routes still point to our nexthop/interface
func routeChanged(nexthop netip.Addr, intf *net.Interface, routes map[netip.Prefix]routemanager.Route) bool {
if !nexthop.IsValid() {
func routeChanged(nexthop systemops.Nexthop, intf *net.Interface, routes []systemops.Route) bool {
if !nexthop.IP.IsValid() {
return false
}
var unspec netip.Prefix
if nexthop.Is6() {
unspec = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
} else {
unspec = netip.PrefixFrom(netip.IPv4Unspecified(), 0)
}
unspec := getUnspecifiedPrefix(nexthop.IP)
defaultRoutes, foundMatchingRoute := processRoutes(nexthop, intf, routes, unspec)
if r, ok := routes[unspec]; ok {
if r.Nexthop != nexthop || compareIntf(r.Interface, intf) != 0 {
intf := "<nil>"
if r.Interface != nil {
intf = r.Interface.Name
}
log.Infof("network monitor: default route changed: %s via %s (%s)", r.Destination, r.Nexthop, intf)
return true
}
} else {
log.Infof("network monitor: default route is gone")
log.Tracef("network monitor: all default routes:\n%s", strings.Join(defaultRoutes, "\n"))
if !foundMatchingRoute {
logRouteChange(nexthop.IP, intf)
return true
}
return false
}
func neighborChanged(nexthop netip.Addr, neighbor *routemanager.Neighbor, neighbors map[netip.Addr]routemanager.Neighbor) bool {
func getUnspecifiedPrefix(ip netip.Addr) netip.Prefix {
if ip.Is6() {
return netip.PrefixFrom(netip.IPv6Unspecified(), 0)
}
return netip.PrefixFrom(netip.IPv4Unspecified(), 0)
}
func processRoutes(nexthop systemops.Nexthop, intf *net.Interface, routes []systemops.Route, unspec netip.Prefix) ([]string, bool) {
var defaultRoutes []string
foundMatchingRoute := false
for _, r := range routes {
if r.Destination == unspec {
routeInfo := formatRouteInfo(r)
defaultRoutes = append(defaultRoutes, routeInfo)
if r.Nexthop == nexthop.IP && compareIntf(r.Interface, intf) == 0 {
foundMatchingRoute = true
log.Debugf("network monitor: found matching default route: %s", routeInfo)
}
}
}
return defaultRoutes, foundMatchingRoute
}
func formatRouteInfo(r systemops.Route) string {
newIntf := "<nil>"
if r.Interface != nil {
newIntf = r.Interface.Name
}
return fmt.Sprintf("Nexthop: %s, Interface: %s", r.Nexthop, newIntf)
}
func logRouteChange(ip netip.Addr, intf *net.Interface) {
oldIntf := "<nil>"
if intf != nil {
oldIntf = intf.Name
}
log.Infof("network monitor: default route for %s (%s) is gone or changed", ip, oldIntf)
}
func neighborChanged(nexthop systemops.Nexthop, neighbor *systemops.Neighbor, neighbors map[netip.Addr]systemops.Neighbor) bool {
if neighbor == nil {
return false
}
// TODO: consider non-local nexthops, e.g. on point-to-point interfaces
if n, ok := neighbors[nexthop]; ok {
if n.State != reachable && n.State != permanent {
if n, ok := neighbors[nexthop.IP]; ok {
if n.State == unreachable || n.State == incomplete {
log.Infof("network monitor: neighbor %s (%s) is not reachable: %s", neighbor.IPAddress, neighbor.LinkLayerAddress, stateFromInt(n.State))
return true
} else if n.InterfaceIndex != neighbor.InterfaceIndex {
@ -150,13 +185,13 @@ func neighborChanged(nexthop netip.Addr, neighbor *routemanager.Neighbor, neighb
return false
}
func getNeighbors() (map[netip.Addr]routemanager.Neighbor, error) {
entries, err := routemanager.GetNeighbors()
func getNeighbors() (map[netip.Addr]systemops.Neighbor, error) {
entries, err := systemops.GetNeighbors()
if err != nil {
return nil, fmt.Errorf("get neighbors: %w", err)
}
neighbours := make(map[netip.Addr]routemanager.Neighbor, len(entries))
neighbours := make(map[netip.Addr]systemops.Neighbor, len(entries))
for _, entry := range entries {
neighbours[entry.IPAddress] = entry
}
@ -164,18 +199,13 @@ func getNeighbors() (map[netip.Addr]routemanager.Neighbor, error) {
return neighbours, nil
}
func getRoutes() (map[netip.Prefix]routemanager.Route, error) {
entries, err := routemanager.GetRoutes()
func getRoutes() ([]systemops.Route, error) {
entries, err := systemops.GetRoutes()
if err != nil {
return nil, fmt.Errorf("get routes: %w", err)
}
routes := make(map[netip.Prefix]routemanager.Route, len(entries))
for _, entry := range entries {
routes[entry.Destination] = entry
}
return routes, nil
return entries, nil
}
func stateFromInt(state uint8) string {

View File

@ -62,9 +62,6 @@ type ConnConfig struct {
ICEConfig ICEConfig
}
type BeforeAddPeerHookFunc func(connID nbnet.ConnectionID, IP net.IP) error
type AfterRemovePeerHookFunc func(connID nbnet.ConnectionID) error
type WorkerCallbacks struct {
OnRelayReadyCallback func(info RelayConnInfo)
OnRelayStatusChanged func(ConnStatus)
@ -99,8 +96,8 @@ type Conn struct {
workerRelay *WorkerRelay
connID nbnet.ConnectionID
beforeAddPeerHooks []BeforeAddPeerHookFunc
afterRemovePeerHooks []AfterRemovePeerHookFunc
beforeAddPeerHooks []nbnet.AddHookFunc
afterRemovePeerHooks []nbnet.RemoveHookFunc
endpointRelay *net.UDPAddr
@ -266,11 +263,10 @@ func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMa
conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
}
func (conn *Conn) AddBeforeAddPeerHook(hook BeforeAddPeerHookFunc) {
func (conn *Conn) AddBeforeAddPeerHook(hook nbnet.AddHookFunc) {
conn.beforeAddPeerHooks = append(conn.beforeAddPeerHooks, hook)
}
func (conn *Conn) AddAfterRemovePeerHook(hook AfterRemovePeerHookFunc) {
func (conn *Conn) AddAfterRemovePeerHook(hook nbnet.RemoveHookFunc) {
conn.afterRemovePeerHooks = append(conn.afterRemovePeerHooks, hook)
}

View File

@ -46,7 +46,7 @@ func TestNewConn_interfaceFilter(t *testing.T) {
}
func TestConn_GetKey(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort)
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
@ -61,7 +61,7 @@ func TestConn_GetKey(t *testing.T) {
}
func TestConn_OnRemoteOffer(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort)
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
@ -98,7 +98,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
}
func TestConn_OnRemoteAnswer(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort)
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
@ -134,7 +134,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
wg.Wait()
}
func TestConn_Status(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort)
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
@ -172,7 +172,7 @@ func TestConn_Status(t *testing.T) {
func TestConn_Switch(t *testing.T) {
ctx := context.Background()
wgProxyFactory := wgproxy.NewFactory(ctx, connConf.LocalWgPort)
wgProxyFactory := wgproxy.NewFactory(ctx, false, connConf.LocalWgPort)
connConfAlice := ConnConfig{
Key: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",

View File

@ -2,14 +2,17 @@ package peer
import (
"errors"
"net/netip"
"sync"
"time"
"golang.org/x/exp/maps"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/management/domain"
)
// State contains the latest state of a peer
@ -37,25 +40,25 @@ type State struct {
// AddRoute add a single route to routes map
func (s *State) AddRoute(network string) {
s.Mux.Lock()
defer s.Mux.Unlock()
if s.routes == nil {
s.routes = make(map[string]struct{})
}
s.routes[network] = struct{}{}
s.Mux.Unlock()
}
// SetRoutes set state routes
func (s *State) SetRoutes(routes map[string]struct{}) {
s.Mux.Lock()
defer s.Mux.Unlock()
s.routes = routes
s.Mux.Unlock()
}
// DeleteRoute removes a route from the network amp
func (s *State) DeleteRoute(network string) {
s.Mux.Lock()
defer s.Mux.Unlock()
delete(s.routes, network)
s.Mux.Unlock()
}
// GetRoutes return routes map
@ -117,22 +120,23 @@ type FullStatus struct {
// Status holds a state of peers, signal, management connections and relays
type Status struct {
mux sync.Mutex
peers map[string]State
changeNotify map[string]chan struct{}
signalState bool
signalError error
managementState bool
managementError error
relayStates []relay.ProbeResult
localPeer LocalPeerState
offlinePeers []State
mgmAddress string
signalAddress string
notifier *notifier
rosenpassEnabled bool
rosenpassPermissive bool
nsGroupStates []NSGroupState
mux sync.Mutex
peers map[string]State
changeNotify map[string]chan struct{}
signalState bool
signalError error
managementState bool
managementError error
relayStates []relay.ProbeResult
localPeer LocalPeerState
offlinePeers []State
mgmAddress string
signalAddress string
notifier *notifier
rosenpassEnabled bool
rosenpassPermissive bool
nsGroupStates []NSGroupState
resolvedDomainsStates map[domain.Domain][]netip.Prefix
// To reduce the number of notification invocation this bool will be true when need to call the notification
// Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events
@ -143,11 +147,12 @@ type Status struct {
// NewRecorder returns a new Status instance
func NewRecorder(mgmAddress string) *Status {
return &Status{
peers: make(map[string]State),
changeNotify: make(map[string]chan struct{}),
offlinePeers: make([]State, 0),
notifier: newNotifier(),
mgmAddress: mgmAddress,
peers: make(map[string]State),
changeNotify: make(map[string]chan struct{}),
offlinePeers: make([]State, 0),
notifier: newNotifier(),
mgmAddress: mgmAddress,
resolvedDomainsStates: make(map[domain.Domain][]netip.Prefix),
}
}
@ -188,7 +193,7 @@ func (d *Status) GetPeer(peerPubKey string) (State, error) {
state, ok := d.peers[peerPubKey]
if !ok {
return State{}, errors.New("peer not found")
return State{}, iface.ErrPeerNotFound
}
return state, nil
}
@ -429,6 +434,18 @@ func (d *Status) UpdateDNSStates(dnsStates []NSGroupState) {
d.nsGroupStates = dnsStates
}
func (d *Status) UpdateResolvedDomainsStates(domain domain.Domain, prefixes []netip.Prefix) {
d.mux.Lock()
defer d.mux.Unlock()
d.resolvedDomainsStates[domain] = prefixes
}
func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
d.mux.Lock()
defer d.mux.Unlock()
delete(d.resolvedDomainsStates, domain)
}
func (d *Status) GetRosenpassState() RosenpassState {
return RosenpassState{
d.rosenpassEnabled,
@ -493,6 +510,12 @@ func (d *Status) GetDNSStates() []NSGroupState {
return d.nsGroupStates
}
func (d *Status) GetResolvedDomainsStates() map[domain.Domain][]netip.Prefix {
d.mux.Lock()
defer d.mux.Unlock()
return maps.Clone(d.resolvedDomainsStates)
}
// GetFullStatus gets full status
func (d *Status) GetFullStatus() FullStatus {
d.mux.Lock()

View File

@ -3,19 +3,20 @@ package routemanager
import (
"context"
"fmt"
"net"
"net/netip"
"time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/static"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
)
const minRangeBits = 7
type routerPeerStatus struct {
connected bool
relayed bool
@ -28,33 +29,42 @@ type routesUpdate struct {
routes []*route.Route
}
// RouteHandler defines the interface for handling routes
type RouteHandler interface {
String() string
AddRoute(ctx context.Context) error
RemoveRoute() error
AddAllowedIPs(peerKey string) error
RemoveAllowedIPs() error
}
type clientNetwork struct {
ctx context.Context
stop context.CancelFunc
cancel context.CancelFunc
statusRecorder *peer.Status
wgInterface *iface.WGIface
routes map[route.ID]*route.Route
routeUpdate chan routesUpdate
peerStateUpdate chan struct{}
routePeersNotifiers map[string]chan struct{}
chosenRoute *route.Route
network netip.Prefix
currentChosen *route.Route
handler RouteHandler
updateSerial uint64
}
func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, statusRecorder *peer.Status, network netip.Prefix) *clientNetwork {
func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration, wgInterface *iface.WGIface, statusRecorder *peer.Status, rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *clientNetwork {
ctx, cancel := context.WithCancel(ctx)
client := &clientNetwork{
ctx: ctx,
stop: cancel,
cancel: cancel,
statusRecorder: statusRecorder,
wgInterface: wgInterface,
routes: make(map[route.ID]*route.Route),
routePeersNotifiers: make(map[string]chan struct{}),
routeUpdate: make(chan routesUpdate),
peerStateUpdate: make(chan struct{}),
network: network,
handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder),
}
return client
}
@ -86,8 +96,8 @@ func (c *clientNetwork) getRouterPeerStatuses() map[route.ID]routerPeerStatus {
// * Metric: Routes with lower metrics (better) are prioritized.
// * Non-relayed: Routes without relays are preferred.
// * Direct connections: Routes with direct peer connections are favored.
// * Stability: In case of equal scores, the currently active route (if any) is maintained.
// * Latency: Routes with lower latency are prioritized.
// * Stability: In case of equal scores, the currently active route (if any) is maintained.
//
// It returns the ID of the selected optimal route.
func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]routerPeerStatus) route.ID {
@ -96,8 +106,8 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
currScore := float64(0)
currID := route.ID("")
if c.chosenRoute != nil {
currID = c.chosenRoute.ID
if c.currentChosen != nil {
currID = c.currentChosen.ID
}
for _, r := range c.routes {
@ -151,18 +161,18 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
peers = append(peers, r.Peer)
}
log.Warnf("the network %s has not been assigned a routing peer as no peers from the list %s are currently connected", c.network, peers)
log.Warnf("The network [%v] has not been assigned a routing peer as no peers from the list %s are currently connected", c.handler, peers)
case chosen != currID:
// we compare the current score + 10ms to the chosen score to avoid flapping between routes
if currScore != 0 && currScore+0.01 > chosenScore {
log.Debugf("keeping current routing peer because the score difference with latency is less than 0.01(10ms), current: %f, new: %f", currScore, chosenScore)
log.Debugf("Keeping current routing peer because the score difference with latency is less than 0.01(10ms), current: %f, new: %f", currScore, chosenScore)
return currID
}
var p string
if rt := c.routes[chosen]; rt != nil {
p = rt.Peer
}
log.Infof("new chosen route is %s with peer %s with score %f for network %s", chosen, p, chosenScore, c.network)
log.Infof("New chosen route is %s with peer %s with score %f for network [%v]", chosen, p, chosenScore, c.handler)
}
return chosen
@ -196,98 +206,103 @@ func (c *clientNetwork) startPeersStatusChangeWatcher() {
}
}
func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
state, err := c.statusRecorder.GetPeer(peerKey)
if err != nil {
return fmt.Errorf("get peer state: %v", err)
}
func (c *clientNetwork) removeRouteFromWireguardPeer() error {
c.removeStateRoute()
state.DeleteRoute(c.network.String())
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
if state.ConnStatus != peer.StatusConnected {
return nil
}
err = c.wgInterface.RemoveAllowedIP(peerKey, c.network.String())
if err != nil {
return fmt.Errorf("remove allowed IP %s removed for peer %s, err: %v",
c.network, c.chosenRoute.Peer, err)
if err := c.handler.RemoveAllowedIPs(); err != nil {
return fmt.Errorf("remove allowed IPs: %w", err)
}
return nil
}
func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
if c.chosenRoute != nil {
if err := removeVPNRoute(c.network, c.getAsInterface()); err != nil {
return fmt.Errorf("remove route %s from system, err: %v", c.network, err)
}
if err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer); err != nil {
return fmt.Errorf("remove route: %v", err)
}
if c.currentChosen == nil {
return nil
}
return nil
var merr *multierror.Error
if err := c.removeRouteFromWireguardPeer(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err))
}
if err := c.handler.RemoveRoute(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove route: %w", err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
routerPeerStatuses := c.getRouterPeerStatuses()
chosen := c.getBestRouteFromStatuses(routerPeerStatuses)
newChosenID := c.getBestRouteFromStatuses(routerPeerStatuses)
// If no route is chosen, remove the route from the peer and system
if chosen == "" {
if newChosenID == "" {
if err := c.removeRouteFromPeerAndSystem(); err != nil {
return fmt.Errorf("remove route from peer and system: %v", err)
return fmt.Errorf("remove route for peer %s: %w", c.currentChosen.Peer, err)
}
c.chosenRoute = nil
c.currentChosen = nil
return nil
}
// If the chosen route is the same as the current route, do nothing
if c.chosenRoute != nil && c.chosenRoute.ID == chosen {
if c.chosenRoute.IsEqual(c.routes[chosen]) {
return nil
}
if c.currentChosen != nil && c.currentChosen.ID == newChosenID &&
c.currentChosen.IsEqual(c.routes[newChosenID]) {
return nil
}
if c.chosenRoute != nil {
// If a previous route exists, remove it from the peer
if err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer); err != nil {
return fmt.Errorf("remove route from peer: %v", err)
if c.currentChosen == nil {
// If they were not previously assigned to another peer, add routes to the system first
if err := c.handler.AddRoute(c.ctx); err != nil {
return fmt.Errorf("add route: %w", err)
}
} else {
// otherwise add the route to the system
if err := addVPNRoute(c.network, c.getAsInterface()); err != nil {
return fmt.Errorf("route %s couldn't be added for peer %s, err: %v",
c.network.String(), c.wgInterface.Address().IP.String(), err)
// Otherwise, remove the allowed IPs from the previous peer first
if err := c.removeRouteFromWireguardPeer(); err != nil {
return fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
}
}
c.chosenRoute = c.routes[chosen]
c.currentChosen = c.routes[newChosenID]
state, err := c.statusRecorder.GetPeer(c.chosenRoute.Peer)
if err != nil {
log.Errorf("Failed to get peer state: %v", err)
} else {
state.AddRoute(c.network.String())
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
if err := c.handler.AddAllowedIPs(c.currentChosen.Peer); err != nil {
return fmt.Errorf("add allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
}
if err := c.wgInterface.AddAllowedIP(c.chosenRoute.Peer, c.network.String()); err != nil {
log.Errorf("couldn't add allowed IP %s added for peer %s, err: %v",
c.network, c.chosenRoute.Peer, err)
}
c.addStateRoute()
return nil
}
func (c *clientNetwork) addStateRoute() {
state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer)
if err != nil {
log.Errorf("Failed to get peer state: %v", err)
return
}
state.AddRoute(c.handler.String())
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
}
func (c *clientNetwork) removeStateRoute() {
state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer)
if err != nil {
log.Errorf("Failed to get peer state: %v", err)
return
}
state.DeleteRoute(c.handler.String())
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
}
func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) {
go func() {
c.routeUpdate <- update
@ -318,24 +333,23 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
for {
select {
case <-c.ctx.Done():
log.Debugf("stopping watcher for network %s", c.network)
err := c.removeRouteFromPeerAndSystem()
if err != nil {
log.Errorf("Couldn't remove route from peer and system for network %s: %v", c.network, err)
log.Debugf("Stopping watcher for network [%v]", c.handler)
if err := c.removeRouteFromPeerAndSystem(); err != nil {
log.Errorf("Failed to remove routes for [%v]: %v", c.handler, err)
}
return
case <-c.peerStateUpdate:
err := c.recalculateRouteAndUpdatePeerAndSystem()
if err != nil {
log.Errorf("Couldn't recalculate route and update peer and system: %v", err)
log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err)
}
case update := <-c.routeUpdate:
if update.updateSerial < c.updateSerial {
log.Warnf("Received a routes update with smaller serial number, ignoring it")
log.Warnf("Received a routes update with smaller serial number (%d -> %d), ignoring it", c.updateSerial, update.updateSerial)
continue
}
log.Debugf("Received a new client network route update for %s", c.network)
log.Debugf("Received a new client network route update for [%v]", c.handler)
c.handleUpdate(update)
@ -343,7 +357,7 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
err := c.recalculateRouteAndUpdatePeerAndSystem()
if err != nil {
log.Errorf("Couldn't recalculate route and update peer and system for network %s: %v", c.network, err)
log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err)
}
c.startPeersStatusChangeWatcher()
@ -351,14 +365,9 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
}
}
func (c *clientNetwork) getAsInterface() *net.Interface {
intf, err := net.InterfaceByName(c.wgInterface.Name())
if err != nil {
log.Warnf("Couldn't get interface by name %s: %v", c.wgInterface.Name(), err)
intf = &net.Interface{
Name: c.wgInterface.Name(),
}
func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status) RouteHandler {
if rt.IsDynamic() {
return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder)
}
return intf
return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter)
}

View File

@ -5,6 +5,7 @@ import (
"testing"
"time"
"github.com/netbirdio/netbird/client/internal/routemanager/static"
"github.com/netbirdio/netbird/route"
)
@ -340,9 +341,9 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
// create new clientNetwork
client := &clientNetwork{
network: netip.MustParsePrefix("192.168.0.0/24"),
routes: tc.existingRoutes,
chosenRoute: currentRoute,
handler: static.NewRoute(&route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")}, nil, nil),
routes: tc.existingRoutes,
currentChosen: currentRoute,
}
chosenRoute := client.getBestRouteFromStatuses(tc.statuses)

View File

@ -0,0 +1,378 @@
package dynamic
import (
"context"
"fmt"
"net"
"net/netip"
"strings"
"sync"
"time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route"
)
const (
DefaultInterval = time.Minute
minInterval = 2 * time.Second
failureInterval = 5 * time.Second
addAllowedIP = "add allowed IP %s: %w"
)
type domainMap map[domain.Domain][]netip.Prefix
type resolveResult struct {
domain domain.Domain
prefix netip.Prefix
err error
}
type Route struct {
route *route.Route
routeRefCounter *refcounter.RouteRefCounter
allowedIPsRefcounter *refcounter.AllowedIPsRefCounter
interval time.Duration
dynamicDomains domainMap
mu sync.Mutex
currentPeerKey string
cancel context.CancelFunc
statusRecorder *peer.Status
}
func NewRoute(
rt *route.Route,
routeRefCounter *refcounter.RouteRefCounter,
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
interval time.Duration,
statusRecorder *peer.Status,
) *Route {
return &Route{
route: rt,
routeRefCounter: routeRefCounter,
allowedIPsRefcounter: allowedIPsRefCounter,
interval: interval,
dynamicDomains: domainMap{},
statusRecorder: statusRecorder,
}
}
func (r *Route) String() string {
s, err := r.route.Domains.String()
if err != nil {
return r.route.Domains.PunycodeString()
}
return s
}
func (r *Route) AddRoute(ctx context.Context) error {
r.mu.Lock()
defer r.mu.Unlock()
if r.cancel != nil {
r.cancel()
}
ctx, r.cancel = context.WithCancel(ctx)
go r.startResolver(ctx)
return nil
}
// RemoveRoute will stop the dynamic resolver and remove all dynamic routes.
// It doesn't touch allowed IPs, these should be removed separately and before calling this method.
func (r *Route) RemoveRoute() error {
r.mu.Lock()
defer r.mu.Unlock()
if r.cancel != nil {
r.cancel()
}
var merr *multierror.Error
for domain, prefixes := range r.dynamicDomains {
for _, prefix := range prefixes {
if _, err := r.routeRefCounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %w", prefix, err))
}
}
log.Debugf("Removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", "))
r.statusRecorder.DeleteResolvedDomainsStates(domain)
}
r.dynamicDomains = domainMap{}
return nberrors.FormatErrorOrNil(merr)
}
func (r *Route) AddAllowedIPs(peerKey string) error {
r.mu.Lock()
defer r.mu.Unlock()
var merr *multierror.Error
for domain, domainPrefixes := range r.dynamicDomains {
for _, prefix := range domainPrefixes {
if err := r.incrementAllowedIP(domain, prefix, peerKey); err != nil {
merr = multierror.Append(merr, fmt.Errorf(addAllowedIP, prefix, err))
}
}
}
r.currentPeerKey = peerKey
return nberrors.FormatErrorOrNil(merr)
}
func (r *Route) RemoveAllowedIPs() error {
r.mu.Lock()
defer r.mu.Unlock()
var merr *multierror.Error
for _, domainPrefixes := range r.dynamicDomains {
for _, prefix := range domainPrefixes {
if _, err := r.allowedIPsRefcounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %w", prefix, err))
}
}
}
r.currentPeerKey = ""
return nberrors.FormatErrorOrNil(merr)
}
func (r *Route) startResolver(ctx context.Context) {
log.Debugf("Starting dynamic route resolver for domains [%v]", r)
interval := r.interval
if interval < minInterval {
interval = minInterval
log.Warnf("Dynamic route resolver interval %s is too low, setting to minimum value %s", r.interval, minInterval)
}
ticker := time.NewTicker(interval)
defer ticker.Stop()
if err := r.update(ctx); err != nil {
log.Errorf("Failed to resolve domains for route [%v]: %v", r, err)
if interval > failureInterval {
ticker.Reset(failureInterval)
}
}
for {
select {
case <-ctx.Done():
log.Debugf("Stopping dynamic route resolver for domains [%v]", r)
return
case <-ticker.C:
if err := r.update(ctx); err != nil {
log.Errorf("Failed to resolve domains for route [%v]: %v", r, err)
// Use a lower ticker interval if the update fails
if interval > failureInterval {
ticker.Reset(failureInterval)
}
} else if interval > failureInterval {
// Reset to the original interval if the update succeeds
ticker.Reset(interval)
}
}
}
}
func (r *Route) update(ctx context.Context) error {
if resolved, err := r.resolveDomains(); err != nil {
return fmt.Errorf("resolve domains: %w", err)
} else if err := r.updateDynamicRoutes(ctx, resolved); err != nil {
return fmt.Errorf("update dynamic routes: %w", err)
}
return nil
}
func (r *Route) resolveDomains() (domainMap, error) {
results := make(chan resolveResult)
go r.resolve(results)
resolved := domainMap{}
var merr *multierror.Error
for result := range results {
if result.err != nil {
merr = multierror.Append(merr, result.err)
} else {
resolved[result.domain] = append(resolved[result.domain], result.prefix)
}
}
return resolved, nberrors.FormatErrorOrNil(merr)
}
func (r *Route) resolve(results chan resolveResult) {
var wg sync.WaitGroup
for _, d := range r.route.Domains {
wg.Add(1)
go func(domain domain.Domain) {
defer wg.Done()
ips, err := net.LookupIP(string(domain))
if err != nil {
results <- resolveResult{domain: domain, err: fmt.Errorf("resolve d %s: %w", domain.SafeString(), err)}
return
}
for _, ip := range ips {
prefix, err := util.GetPrefixFromIP(ip)
if err != nil {
results <- resolveResult{domain: domain, err: fmt.Errorf("get prefix from IP %s: %w", ip.String(), err)}
return
}
results <- resolveResult{domain: domain, prefix: prefix}
}
}(d)
}
wg.Wait()
close(results)
}
func (r *Route) updateDynamicRoutes(ctx context.Context, newDomains domainMap) error {
r.mu.Lock()
defer r.mu.Unlock()
if ctx.Err() != nil {
return ctx.Err()
}
var merr *multierror.Error
for domain, newPrefixes := range newDomains {
oldPrefixes := r.dynamicDomains[domain]
toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes)
addedPrefixes, err := r.addRoutes(domain, toAdd)
if err != nil {
merr = multierror.Append(merr, err)
} else if len(addedPrefixes) > 0 {
log.Debugf("Added dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", addedPrefixes), " ", ", "))
}
removedPrefixes, err := r.removeRoutes(toRemove)
if err != nil {
merr = multierror.Append(merr, err)
} else if len(removedPrefixes) > 0 {
log.Debugf("Removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", removedPrefixes), " ", ", "))
}
updatedPrefixes := combinePrefixes(oldPrefixes, removedPrefixes, addedPrefixes)
r.dynamicDomains[domain] = updatedPrefixes
r.statusRecorder.UpdateResolvedDomainsStates(domain, updatedPrefixes)
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *Route) addRoutes(domain domain.Domain, prefixes []netip.Prefix) ([]netip.Prefix, error) {
var addedPrefixes []netip.Prefix
var merr *multierror.Error
for _, prefix := range prefixes {
if _, err := r.routeRefCounter.Increment(prefix, nil); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add dynamic route for IP %s: %w", prefix, err))
continue
}
if r.currentPeerKey != "" {
if err := r.incrementAllowedIP(domain, prefix, r.currentPeerKey); err != nil {
merr = multierror.Append(merr, fmt.Errorf(addAllowedIP, prefix, err))
}
}
addedPrefixes = append(addedPrefixes, prefix)
}
return addedPrefixes, merr.ErrorOrNil()
}
func (r *Route) removeRoutes(prefixes []netip.Prefix) ([]netip.Prefix, error) {
if r.route.KeepRoute {
return nil, nil
}
var removedPrefixes []netip.Prefix
var merr *multierror.Error
for _, prefix := range prefixes {
if _, err := r.routeRefCounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %w", prefix, err))
}
if r.currentPeerKey != "" {
if _, err := r.allowedIPsRefcounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %w", prefix, err))
}
}
removedPrefixes = append(removedPrefixes, prefix)
}
return removedPrefixes, merr.ErrorOrNil()
}
func (r *Route) incrementAllowedIP(domain domain.Domain, prefix netip.Prefix, peerKey string) error {
if ref, err := r.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil {
return fmt.Errorf(addAllowedIP, prefix, err)
} else if ref.Count > 1 && ref.Out != peerKey {
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
prefix.Addr(),
domain.SafeString(),
ref.Out,
)
}
return nil
}
func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toRemove []netip.Prefix) {
prefixSet := make(map[netip.Prefix]bool)
for _, prefix := range oldPrefixes {
prefixSet[prefix] = false
}
for _, prefix := range newPrefixes {
if _, exists := prefixSet[prefix]; exists {
prefixSet[prefix] = true
} else {
toAdd = append(toAdd, prefix)
}
}
for prefix, inUse := range prefixSet {
if !inUse {
toRemove = append(toRemove, prefix)
}
}
return
}
func combinePrefixes(oldPrefixes, removedPrefixes, addedPrefixes []netip.Prefix) []netip.Prefix {
prefixSet := make(map[netip.Prefix]struct{})
for _, prefix := range oldPrefixes {
prefixSet[prefix] = struct{}{}
}
for _, prefix := range removedPrefixes {
delete(prefixSet, prefix)
}
for _, prefix := range addedPrefixes {
prefixSet[prefix] = struct{}{}
}
var combinedPrefixes []netip.Prefix
for prefix := range prefixSet {
combinedPrefixes = append(combinedPrefixes, prefix)
}
return combinedPrefixes
}

View File

@ -2,18 +2,23 @@ package routemanager
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"net/url"
"runtime"
"sync"
"time"
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
"github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
@ -21,14 +26,9 @@ import (
"github.com/netbirdio/netbird/version"
)
var defaultv4 = netip.PrefixFrom(netip.IPv4Unspecified(), 0)
// nolint:unused
var defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
// Manager is a route manager interface
type Manager interface {
Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error)
Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
TriggerSelection(route.HAMap)
GetRouteSelector() *routeselector.RouteSelector
@ -40,31 +40,71 @@ type Manager interface {
// DefaultManager is the default instance of a route manager
type DefaultManager struct {
ctx context.Context
stop context.CancelFunc
mux sync.Mutex
clientNetworks map[route.HAUniqueID]*clientNetwork
routeSelector *routeselector.RouteSelector
serverRouter serverRouter
statusRecorder *peer.Status
wgInterface *iface.WGIface
pubKey string
notifier *notifier
ctx context.Context
stop context.CancelFunc
mux sync.Mutex
clientNetworks map[route.HAUniqueID]*clientNetwork
routeSelector *routeselector.RouteSelector
serverRouter serverRouter
sysOps *systemops.SysOps
statusRecorder *peer.Status
wgInterface *iface.WGIface
pubKey string
notifier *notifier
routeRefCounter *refcounter.RouteRefCounter
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter
dnsRouteInterval time.Duration
}
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status, initialRoutes []*route.Route) *DefaultManager {
func NewManager(
ctx context.Context,
pubKey string,
dnsRouteInterval time.Duration,
wgInterface *iface.WGIface,
statusRecorder *peer.Status,
initialRoutes []*route.Route,
) *DefaultManager {
mCTX, cancel := context.WithCancel(ctx)
sysOps := systemops.NewSysOps(wgInterface)
dm := &DefaultManager{
ctx: mCTX,
stop: cancel,
clientNetworks: make(map[route.HAUniqueID]*clientNetwork),
routeSelector: routeselector.NewRouteSelector(),
statusRecorder: statusRecorder,
wgInterface: wgInterface,
pubKey: pubKey,
notifier: newNotifier(),
ctx: mCTX,
stop: cancel,
dnsRouteInterval: dnsRouteInterval,
clientNetworks: make(map[route.HAUniqueID]*clientNetwork),
routeSelector: routeselector.NewRouteSelector(),
sysOps: sysOps,
statusRecorder: statusRecorder,
wgInterface: wgInterface,
pubKey: pubKey,
notifier: newNotifier(),
}
dm.routeRefCounter = refcounter.New(
func(prefix netip.Prefix, _ any) (any, error) {
return nil, sysOps.AddVPNRoute(prefix, wgInterface.ToInterface())
},
func(prefix netip.Prefix, _ any) error {
return sysOps.RemoveVPNRoute(prefix, wgInterface.ToInterface())
},
)
dm.allowedIPsRefCounter = refcounter.New(
func(prefix netip.Prefix, peerKey string) (string, error) {
// save peerKey to use it in the remove function
return peerKey, wgInterface.AddAllowedIP(peerKey, prefix.String())
},
func(prefix netip.Prefix, peerKey string) error {
if err := wgInterface.RemoveAllowedIP(peerKey, prefix.String()); err != nil {
if !errors.Is(err, iface.ErrPeerNotFound) && !errors.Is(err, iface.ErrAllowedIPNotFound) {
return err
}
log.Tracef("Remove allowed IPs %s for %s: %v", prefix, peerKey, err)
}
return nil
},
)
if runtime.GOOS == "android" {
cr := dm.clientRoutes(initialRoutes)
dm.notifier.setInitialClientRoutes(cr)
@ -73,12 +113,12 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface,
}
// Init sets up the routing
func (m *DefaultManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
if nbnet.CustomRoutingDisabled() {
return nil, nil, nil
}
if err := cleanupRouting(); err != nil {
if err := m.sysOps.CleanupRouting(); err != nil {
log.Warnf("Failed cleaning up routing: %v", err)
}
@ -86,7 +126,7 @@ func (m *DefaultManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePee
signalAddress := m.statusRecorder.GetSignalState().URL
ips := resolveURLsToIPs([]string{mgmtAddress, signalAddress})
beforePeerHook, afterPeerHook, err := setupRouting(ips, m.wgInterface)
beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips)
if err != nil {
return nil, nil, fmt.Errorf("setup routing: %w", err)
}
@ -110,8 +150,19 @@ func (m *DefaultManager) Stop() {
m.serverRouter.cleanUp()
}
if m.routeRefCounter != nil {
if err := m.routeRefCounter.Flush(); err != nil {
log.Errorf("Error flushing route ref counter: %v", err)
}
}
if m.allowedIPsRefCounter != nil {
if err := m.allowedIPsRefCounter.Flush(); err != nil {
log.Errorf("Error flushing allowed IPs ref counter: %v", err)
}
}
if !nbnet.CustomRoutingDisabled() {
if err := cleanupRouting(); err != nil {
if err := m.sysOps.CleanupRouting(); err != nil {
log.Errorf("Error cleaning up routing: %v", err)
} else {
log.Info("Routing cleanup complete")
@ -185,7 +236,7 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
continue
}
clientNetworkWatcher := newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network)
clientNetworkWatcher := newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter)
m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher()
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes})
@ -197,7 +248,7 @@ func (m *DefaultManager) stopObsoleteClients(networks route.HAMap) {
for id, client := range m.clientNetworks {
if _, ok := networks[id]; !ok {
log.Debugf("Stopping client network watcher, %s", id)
client.stop()
client.cancel()
delete(m.clientNetworks, id)
}
}
@ -210,7 +261,7 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout
for id, routes := range networks {
clientNetworkWatcher, found := m.clientNetworks[id]
if !found {
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network)
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter)
m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher()
}
@ -228,7 +279,7 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]
ownNetworkIDs := make(map[route.HAUniqueID]bool)
for _, newRoute := range newRoutes {
haID := route.GetHAUniqueID(newRoute)
haID := newRoute.GetHAUniqueID()
if newRoute.Peer == m.pubKey {
ownNetworkIDs[haID] = true
// only linux is supported for now
@ -241,9 +292,9 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]
}
for _, newRoute := range newRoutes {
haID := route.GetHAUniqueID(newRoute)
haID := newRoute.GetHAUniqueID()
if !ownNetworkIDs[haID] {
if !isPrefixSupported(newRoute.Network) {
if !isRouteSupported(newRoute) {
continue
}
newClientRoutesIDMap[haID] = append(newClientRoutesIDMap[haID], newRoute)
@ -255,23 +306,23 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]
func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Route {
_, crMap := m.classifyRoutes(initialRoutes)
rs := make([]*route.Route, 0)
rs := make([]*route.Route, 0, len(crMap))
for _, routes := range crMap {
rs = append(rs, routes...)
}
return rs
}
func isPrefixSupported(prefix netip.Prefix) bool {
if !nbnet.CustomRoutingDisabled() {
func isRouteSupported(route *route.Route) bool {
if !nbnet.CustomRoutingDisabled() || route.IsDynamic() {
return true
}
// If prefix is too small, lets assume it is a possible default prefix which is not yet supported
// we skip this prefix management
if prefix.Bits() <= minRangeBits {
if route.Network.Bits() <= vars.MinRangeBits {
log.Warnf("This agent version: %s, doesn't support default routes, received %s, skipping this prefix",
version.NetbirdVersion(), prefix)
version.NetbirdVersion(), route.Network)
return false
}
return true

View File

@ -407,7 +407,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
if err != nil {
t.Fatal(err)
}
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil)
require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close()
@ -416,7 +416,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
statusRecorder := peer.NewRecorder("https://mgm")
ctx := context.TODO()
routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil)
routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil)
_, _, err = routeManager.Init()
@ -436,7 +436,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
require.NoError(t, err, "should update routes")
expectedWatchers := testCase.clientNetworkWatchersExpected
if (runtime.GOOS == "linux" || runtime.GOOS == "windows" || runtime.GOOS == "darwin") && testCase.clientNetworkWatchersExpectedAllowed != 0 {
if testCase.clientNetworkWatchersExpectedAllowed != 0 {
expectedWatchers = testCase.clientNetworkWatchersExpectedAllowed
}
require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match")

View File

@ -6,10 +6,10 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/util/net"
)
// MockManager is the mock instance of a route manager
@ -20,7 +20,7 @@ type MockManager struct {
StopFunc func()
}
func (m *MockManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) {
return nil, nil, nil
}

View File

@ -0,0 +1,155 @@
package refcounter
import (
"errors"
"fmt"
"net/netip"
"sync"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
)
// ErrIgnore can be returned by AddFunc to indicate that the counter not be incremented for the given prefix.
var ErrIgnore = errors.New("ignore")
type Ref[O any] struct {
Count int
Out O
}
type AddFunc[I, O any] func(prefix netip.Prefix, in I) (out O, err error)
type RemoveFunc[I, O any] func(prefix netip.Prefix, out O) error
type Counter[I, O any] struct {
// refCountMap keeps track of the reference Ref for prefixes
refCountMap map[netip.Prefix]Ref[O]
refCountMu sync.Mutex
// idMap keeps track of the prefixes associated with an ID for removal
idMap map[string][]netip.Prefix
idMu sync.Mutex
add AddFunc[I, O]
remove RemoveFunc[I, O]
}
// New creates a new Counter instance
func New[I, O any](add AddFunc[I, O], remove RemoveFunc[I, O]) *Counter[I, O] {
return &Counter[I, O]{
refCountMap: map[netip.Prefix]Ref[O]{},
idMap: map[string][]netip.Prefix{},
add: add,
remove: remove,
}
}
// Increment increments the reference count for the given prefix.
// If this is the first reference to the prefix, the AddFunc is called.
func (rm *Counter[I, O]) Increment(prefix netip.Prefix, in I) (Ref[O], error) {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
ref := rm.refCountMap[prefix]
log.Tracef("Increasing ref count %d for prefix %s with [%v]", ref.Count, prefix, ref.Out)
// Call AddFunc only if it's a new prefix
if ref.Count == 0 {
log.Tracef("Adding for prefix %s with [%v]", prefix, ref.Out)
out, err := rm.add(prefix, in)
if errors.Is(err, ErrIgnore) {
return ref, nil
}
if err != nil {
return ref, fmt.Errorf("failed to add for prefix %s: %w", prefix, err)
}
ref.Out = out
}
ref.Count++
rm.refCountMap[prefix] = ref
return ref, nil
}
// IncrementWithID increments the reference count for the given prefix and groups it under the given ID.
// If this is the first reference to the prefix, the AddFunc is called.
func (rm *Counter[I, O]) IncrementWithID(id string, prefix netip.Prefix, in I) (Ref[O], error) {
rm.idMu.Lock()
defer rm.idMu.Unlock()
ref, err := rm.Increment(prefix, in)
if err != nil {
return ref, fmt.Errorf("with ID: %w", err)
}
rm.idMap[id] = append(rm.idMap[id], prefix)
return ref, nil
}
// Decrement decrements the reference count for the given prefix.
// If the reference count reaches 0, the RemoveFunc is called.
func (rm *Counter[I, O]) Decrement(prefix netip.Prefix) (Ref[O], error) {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
ref, ok := rm.refCountMap[prefix]
if !ok {
log.Tracef("No reference found for prefix %s", prefix)
return ref, nil
}
log.Tracef("Decreasing ref count %d for prefix %s with [%v]", ref.Count, prefix, ref.Out)
if ref.Count == 1 {
log.Tracef("Removing for prefix %s with [%v]", prefix, ref.Out)
if err := rm.remove(prefix, ref.Out); err != nil {
return ref, fmt.Errorf("remove for prefix %s: %w", prefix, err)
}
delete(rm.refCountMap, prefix)
} else {
ref.Count--
rm.refCountMap[prefix] = ref
}
return ref, nil
}
// DecrementWithID decrements the reference count for all prefixes associated with the given ID.
// If the reference count reaches 0, the RemoveFunc is called.
func (rm *Counter[I, O]) DecrementWithID(id string) error {
rm.idMu.Lock()
defer rm.idMu.Unlock()
var merr *multierror.Error
for _, prefix := range rm.idMap[id] {
if _, err := rm.Decrement(prefix); err != nil {
merr = multierror.Append(merr, err)
}
}
delete(rm.idMap, id)
return nberrors.FormatErrorOrNil(merr)
}
// Flush removes all references and calls RemoveFunc for each prefix.
func (rm *Counter[I, O]) Flush() error {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
var merr *multierror.Error
for prefix := range rm.refCountMap {
log.Tracef("Removing for prefix %s", prefix)
ref := rm.refCountMap[prefix]
if err := rm.remove(prefix, ref.Out); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove for prefix %s: %w", prefix, err))
}
}
rm.refCountMap = map[netip.Prefix]Ref[O]{}
rm.idMap = map[string][]netip.Prefix{}
return nberrors.FormatErrorOrNil(merr)
}

View File

@ -0,0 +1,7 @@
package refcounter
// RouteRefCounter is a Counter for Route, it doesn't take any input on Increment and doesn't use any output on Decrement
type RouteRefCounter = Counter[any, any]
// AllowedIPsRefCounter is a Counter for AllowedIPs, it takes a peer key on Increment and passes it back to Decrement
type AllowedIPsRefCounter = Counter[string, string]

View File

@ -1,127 +0,0 @@
//go:build !android && !ios
package routemanager
import (
"errors"
"fmt"
"net"
"net/netip"
"sync"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nbnet "github.com/netbirdio/netbird/util/net"
)
type ref struct {
count int
nexthop netip.Addr
intf *net.Interface
}
type RouteManager struct {
// refCountMap keeps track of the reference ref for prefixes
refCountMap map[netip.Prefix]ref
// prefixMap keeps track of the prefixes associated with a connection ID for removal
prefixMap map[nbnet.ConnectionID][]netip.Prefix
addRoute AddRouteFunc
removeRoute RemoveRouteFunc
mutex sync.Mutex
}
type AddRouteFunc func(prefix netip.Prefix) (nexthop netip.Addr, intf *net.Interface, err error)
type RemoveRouteFunc func(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error
func NewRouteManager(addRoute AddRouteFunc, removeRoute RemoveRouteFunc) *RouteManager {
// TODO: read initial routing table into refCountMap
return &RouteManager{
refCountMap: map[netip.Prefix]ref{},
prefixMap: map[nbnet.ConnectionID][]netip.Prefix{},
addRoute: addRoute,
removeRoute: removeRoute,
}
}
func (rm *RouteManager) AddRouteRef(connID nbnet.ConnectionID, prefix netip.Prefix) error {
rm.mutex.Lock()
defer rm.mutex.Unlock()
ref := rm.refCountMap[prefix]
log.Debugf("Increasing route ref count %d for prefix %s", ref.count, prefix)
// Add route to the system, only if it's a new prefix
if ref.count == 0 {
log.Debugf("Adding route for prefix %s", prefix)
nexthop, intf, err := rm.addRoute(prefix)
if errors.Is(err, ErrRouteNotFound) {
return nil
}
if errors.Is(err, ErrRouteNotAllowed) {
log.Debugf("Adding route for prefix %s: %s", prefix, err)
}
if err != nil {
return fmt.Errorf("failed to add route for prefix %s: %w", prefix, err)
}
ref.nexthop = nexthop
ref.intf = intf
}
ref.count++
rm.refCountMap[prefix] = ref
rm.prefixMap[connID] = append(rm.prefixMap[connID], prefix)
return nil
}
func (rm *RouteManager) RemoveRouteRef(connID nbnet.ConnectionID) error {
rm.mutex.Lock()
defer rm.mutex.Unlock()
prefixes, ok := rm.prefixMap[connID]
if !ok {
log.Debugf("No prefixes found for connection ID %s", connID)
return nil
}
var result *multierror.Error
for _, prefix := range prefixes {
ref := rm.refCountMap[prefix]
log.Debugf("Decreasing route ref count %d for prefix %s", ref.count, prefix)
if ref.count == 1 {
log.Debugf("Removing route for prefix %s", prefix)
// TODO: don't fail if the route is not found
if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil {
result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err))
continue
}
delete(rm.refCountMap, prefix)
} else {
ref.count--
rm.refCountMap[prefix] = ref
}
}
delete(rm.prefixMap, connID)
return result.ErrorOrNil()
}
// Flush removes all references and routes from the system
func (rm *RouteManager) Flush() error {
rm.mutex.Lock()
defer rm.mutex.Unlock()
var result *multierror.Error
for prefix := range rm.refCountMap {
log.Debugf("Removing route for prefix %s", prefix)
ref := rm.refCountMap[prefix]
if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil {
result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err))
}
}
rm.refCountMap = map[netip.Prefix]ref{}
rm.prefixMap = map[nbnet.ConnectionID][]netip.Prefix{}
return result.ErrorOrNil()
}

View File

@ -12,6 +12,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
)
@ -70,7 +71,7 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[route.ID]*route.Route)
}
if len(m.routes) > 0 {
err := enableIPForwarding()
err := systemops.EnableIPForwarding()
if err != nil {
return err
}
@ -88,7 +89,7 @@ func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error
m.mux.Lock()
defer m.mux.Unlock()
routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), route)
routerPair, err := routeToRouterPair(route)
if err != nil {
return fmt.Errorf("parse prefix: %w", err)
}
@ -117,7 +118,7 @@ func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error {
m.mux.Lock()
defer m.mux.Unlock()
routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), route)
routerPair, err := routeToRouterPair(route)
if err != nil {
return fmt.Errorf("parse prefix: %w", err)
}
@ -133,7 +134,13 @@ func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error {
if state.Routes == nil {
state.Routes = map[string]struct{}{}
}
state.Routes[route.Network.String()] = struct{}{}
routeStr := route.Network.String()
if route.IsDynamic() {
routeStr = route.Domains.SafeString()
}
state.Routes[routeStr] = struct{}{}
m.statusRecorder.UpdateLocalPeerState(state)
return nil
@ -144,7 +151,7 @@ func (m *defaultServerRouter) cleanUp() {
m.mux.Lock()
defer m.mux.Unlock()
for _, r := range m.routes {
routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), r)
routerPair, err := routeToRouterPair(r)
if err != nil {
log.Errorf("Failed to convert route to router pair: %v", err)
continue
@ -162,15 +169,27 @@ func (m *defaultServerRouter) cleanUp() {
m.statusRecorder.UpdateLocalPeerState(state)
}
func routeToRouterPair(source string, route *route.Route) (firewall.RouterPair, error) {
parsed, err := netip.ParsePrefix(source)
if err != nil {
return firewall.RouterPair{}, err
func routeToRouterPair(route *route.Route) (firewall.RouterPair, error) {
// TODO: add ipv6
source := getDefaultPrefix(route.Network)
destination := route.Network.Masked().String()
if route.IsDynamic() {
// TODO: add ipv6
destination = "0.0.0.0/0"
}
return firewall.RouterPair{
ID: string(route.ID),
Source: parsed.String(),
Destination: route.Network.Masked().String(),
Source: source.String(),
Destination: destination,
Masquerade: route.Masquerade,
}, nil
}
func getDefaultPrefix(prefix netip.Prefix) netip.Prefix {
if prefix.Addr().Is6() {
return netip.PrefixFrom(netip.IPv6Unspecified(), 0)
}
return netip.PrefixFrom(netip.IPv4Unspecified(), 0)
}

View File

@ -0,0 +1,57 @@
package static
import (
"context"
"fmt"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/route"
)
type Route struct {
route *route.Route
routeRefCounter *refcounter.RouteRefCounter
allowedIPsRefcounter *refcounter.AllowedIPsRefCounter
}
func NewRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *Route {
return &Route{
route: rt,
routeRefCounter: routeRefCounter,
allowedIPsRefcounter: allowedIPsRefCounter,
}
}
// Route route methods
func (r *Route) String() string {
return r.route.Network.String()
}
func (r *Route) AddRoute(context.Context) error {
_, err := r.routeRefCounter.Increment(r.route.Network, nil)
return err
}
func (r *Route) RemoveRoute() error {
_, err := r.routeRefCounter.Decrement(r.route.Network)
return err
}
func (r *Route) AddAllowedIPs(peerKey string) error {
if ref, err := r.allowedIPsRefcounter.Increment(r.route.Network, peerKey); err != nil {
return fmt.Errorf("add allowed IP %s: %w", r.route.Network, err)
} else if ref.Count > 1 && ref.Out != peerKey {
log.Warnf("Prefix [%s] is already routed by peer [%s]. HA routing disabled",
r.route.Network,
ref.Out,
)
}
return nil
}
func (r *Route) RemoveAllowedIPs() error {
_, err := r.allowedIPsRefcounter.Decrement(r.route.Network)
return err
}

View File

@ -0,0 +1,103 @@
// go:build !android
package sysctl
import (
"fmt"
"net"
"os"
"strconv"
"strings"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/iface"
)
const (
rpFilterPath = "net.ipv4.conf.all.rp_filter"
rpFilterInterfacePath = "net.ipv4.conf.%s.rp_filter"
srcValidMarkPath = "net.ipv4.conf.all.src_valid_mark"
)
// Setup configures sysctl settings for RP filtering and source validation.
func Setup(wgIface *iface.WGIface) (map[string]int, error) {
keys := map[string]int{}
var result *multierror.Error
oldVal, err := Set(srcValidMarkPath, 1, false)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[srcValidMarkPath] = oldVal
}
oldVal, err = Set(rpFilterPath, 2, true)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[rpFilterPath] = oldVal
}
interfaces, err := net.Interfaces()
if err != nil {
result = multierror.Append(result, fmt.Errorf("list interfaces: %w", err))
}
for _, intf := range interfaces {
if intf.Name == "lo" || wgIface != nil && intf.Name == wgIface.Name() {
continue
}
i := fmt.Sprintf(rpFilterInterfacePath, intf.Name)
oldVal, err := Set(i, 2, true)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[i] = oldVal
}
}
return keys, nberrors.FormatErrorOrNil(result)
}
// Set sets a sysctl configuration, if onlyIfOne is true it will only set the new value if it's set to 1
func Set(key string, desiredValue int, onlyIfOne bool) (int, error) {
path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/"))
currentValue, err := os.ReadFile(path)
if err != nil {
return -1, fmt.Errorf("read sysctl %s: %w", key, err)
}
currentV, err := strconv.Atoi(strings.TrimSpace(string(currentValue)))
if err != nil && len(currentValue) > 0 {
return -1, fmt.Errorf("convert current desiredValue to int: %w", err)
}
if currentV == desiredValue || onlyIfOne && currentV != 1 {
return currentV, nil
}
//nolint:gosec
if err := os.WriteFile(path, []byte(strconv.Itoa(desiredValue)), 0644); err != nil {
return currentV, fmt.Errorf("write sysctl %s: %w", key, err)
}
log.Debugf("Set sysctl %s from %d to %d", key, currentV, desiredValue)
return currentV, nil
}
// Cleanup resets sysctl settings to their original values.
func Cleanup(originalSettings map[string]int) error {
var result *multierror.Error
for key, value := range originalSettings {
_, err := Set(key, value, false)
if err != nil {
result = multierror.Append(result, err)
}
}
return nberrors.FormatErrorOrNil(result)
}

View File

@ -1,414 +0,0 @@
//go:build !android && !ios
package routemanager
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"runtime"
"strconv"
"github.com/hashicorp/go-multierror"
"github.com/libp2p/go-netroute"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
nbnet "github.com/netbirdio/netbird/util/net"
)
var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1)
var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
var ErrRouteNotFound = errors.New("route not found")
var ErrRouteNotAllowed = errors.New("route not allowed")
// TODO: fix: for default our wg address now appears as the default gw
func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
addr := netip.IPv4Unspecified()
if prefix.Addr().Is6() {
addr = netip.IPv6Unspecified()
}
defaultGateway, _, err := GetNextHop(addr)
if err != nil && !errors.Is(err, ErrRouteNotFound) {
return fmt.Errorf("get existing route gateway: %s", err)
}
if !prefix.Contains(defaultGateway) {
log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", defaultGateway, prefix)
return nil
}
gatewayPrefix := netip.PrefixFrom(defaultGateway, 32)
if defaultGateway.Is6() {
gatewayPrefix = netip.PrefixFrom(defaultGateway, 128)
}
ok, err := existsInRouteTable(gatewayPrefix)
if err != nil {
return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err)
}
if ok {
log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix)
return nil
}
gatewayHop, intf, err := GetNextHop(defaultGateway)
if err != nil && !errors.Is(err, ErrRouteNotFound) {
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
}
log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop)
return addToRouteTable(gatewayPrefix, gatewayHop, intf)
}
func GetNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) {
r, err := netroute.New()
if err != nil {
return netip.Addr{}, nil, fmt.Errorf("new netroute: %w", err)
}
intf, gateway, preferredSrc, err := r.Route(ip.AsSlice())
if err != nil {
log.Debugf("Failed to get route for %s: %v", ip, err)
return netip.Addr{}, nil, ErrRouteNotFound
}
log.Debugf("Route for %s: interface %v nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc)
if gateway == nil {
if preferredSrc == nil {
return netip.Addr{}, nil, ErrRouteNotFound
}
log.Debugf("No next hop found for ip %s, using preferred source %s", ip, preferredSrc)
addr, err := ipToAddr(preferredSrc, intf)
if err != nil {
return netip.Addr{}, nil, fmt.Errorf("convert preferred source to address: %w", err)
}
return addr.Unmap(), intf, nil
}
addr, err := ipToAddr(gateway, intf)
if err != nil {
return netip.Addr{}, nil, fmt.Errorf("convert gateway to address: %w", err)
}
return addr, intf, nil
}
// converts a net.IP to a netip.Addr including the zone based on the passed interface
func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
addr, ok := netip.AddrFromSlice(ip)
if !ok {
return netip.Addr{}, fmt.Errorf("failed to convert IP address to netip.Addr: %s", ip)
}
if intf != nil && (addr.IsLinkLocalMulticast() || addr.IsLinkLocalUnicast()) {
log.Tracef("Adding zone %s to address %s", intf.Name, addr)
if runtime.GOOS == "windows" {
addr = addr.WithZone(strconv.Itoa(intf.Index))
} else {
addr = addr.WithZone(intf.Name)
}
}
return addr.Unmap(), nil
}
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
routes, err := getRoutesFromTable()
if err != nil {
return false, fmt.Errorf("get routes from table: %w", err)
}
for _, tableRoute := range routes {
if tableRoute == prefix {
return true, nil
}
}
return false, nil
}
func isSubRange(prefix netip.Prefix) (bool, error) {
routes, err := getRoutesFromTable()
if err != nil {
return false, fmt.Errorf("get routes from table: %w", err)
}
for _, tableRoute := range routes {
if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() {
return true, nil
}
}
return false, nil
}
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
// If the next hop or interface is pointing to the VPN interface, it will return the initial values.
func addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIface, initialNextHop netip.Addr, initialIntf *net.Interface) (netip.Addr, *net.Interface, error) {
addr := prefix.Addr()
switch {
case addr.IsLoopback(),
addr.IsLinkLocalUnicast(),
addr.IsLinkLocalMulticast(),
addr.IsInterfaceLocalMulticast(),
addr.IsUnspecified(),
addr.IsMulticast():
return netip.Addr{}, nil, ErrRouteNotAllowed
}
// Determine the exit interface and next hop for the prefix, so we can add a specific route
nexthop, intf, err := GetNextHop(addr)
if err != nil {
return netip.Addr{}, nil, fmt.Errorf("get next hop: %w", err)
}
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop, prefix, intf)
exitNextHop := nexthop
exitIntf := intf
vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP)
if !ok {
return netip.Addr{}, nil, fmt.Errorf("failed to convert vpn address to netip.Addr")
}
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
if exitNextHop == vpnAddr || exitIntf != nil && exitIntf.Name == vpnIntf.Name() {
log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix)
exitNextHop = initialNextHop
exitIntf = initialIntf
}
log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop)
if err := addToRouteTable(prefix, exitNextHop, exitIntf); err != nil {
return netip.Addr{}, nil, fmt.Errorf("add route to table: %w", err)
}
return exitNextHop, exitIntf, nil
}
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
// in two /1 prefixes to avoid replacing the existing default route
func genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if prefix == defaultv4 {
if err := addToRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
return err
}
if err := addToRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil {
if err2 := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err2 != nil {
log.Warnf("Failed to rollback route addition: %s", err2)
}
return err
}
// TODO: remove once IPv6 is supported on the interface
if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
return fmt.Errorf("add unreachable route split 1: %w", err)
}
if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil {
log.Warnf("Failed to rollback route addition: %s", err2)
}
return fmt.Errorf("add unreachable route split 2: %w", err)
}
return nil
} else if prefix == defaultv6 {
if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
return fmt.Errorf("add unreachable route split 1: %w", err)
}
if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil {
log.Warnf("Failed to rollback route addition: %s", err2)
}
return fmt.Errorf("add unreachable route split 2: %w", err)
}
return nil
}
return addNonExistingRoute(prefix, intf)
}
// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table
func addNonExistingRoute(prefix netip.Prefix, intf *net.Interface) error {
ok, err := existsInRouteTable(prefix)
if err != nil {
return fmt.Errorf("exists in route table: %w", err)
}
if ok {
log.Warnf("Skipping adding a new route for network %s because it already exists", prefix)
return nil
}
ok, err = isSubRange(prefix)
if err != nil {
return fmt.Errorf("sub range: %w", err)
}
if ok {
err := addRouteForCurrentDefaultGateway(prefix)
if err != nil {
log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err)
}
}
return addToRouteTable(prefix, netip.Addr{}, intf)
}
// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given,
// it will remove the split /1 prefixes
func genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if prefix == defaultv4 {
var result *multierror.Error
if err := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
if err := removeFromRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
// TODO: remove once IPv6 is supported on the interface
if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
return result.ErrorOrNil()
} else if prefix == defaultv6 {
var result *multierror.Error
if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
return result.ErrorOrNil()
}
return removeFromRouteTable(prefix, netip.Addr{}, intf)
}
func getPrefixFromIP(ip net.IP) (*netip.Prefix, error) {
addr, ok := netip.AddrFromSlice(ip)
if !ok {
return nil, fmt.Errorf("parse IP address: %s", ip)
}
addr = addr.Unmap()
var prefixLength int
switch {
case addr.Is4():
prefixLength = 32
case addr.Is6():
prefixLength = 128
default:
return nil, fmt.Errorf("invalid IP address: %s", addr)
}
prefix := netip.PrefixFrom(addr, prefixLength)
return &prefix, nil
}
func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
initialNextHopV4, initialIntfV4, err := GetNextHop(netip.IPv4Unspecified())
if err != nil && !errors.Is(err, ErrRouteNotFound) {
log.Errorf("Unable to get initial v4 default next hop: %v", err)
}
initialNextHopV6, initialIntfV6, err := GetNextHop(netip.IPv6Unspecified())
if err != nil && !errors.Is(err, ErrRouteNotFound) {
log.Errorf("Unable to get initial v6 default next hop: %v", err)
}
*routeManager = NewRouteManager(
func(prefix netip.Prefix) (netip.Addr, *net.Interface, error) {
addr := prefix.Addr()
nexthop, intf := initialNextHopV4, initialIntfV4
if addr.Is6() {
nexthop, intf = initialNextHopV6, initialIntfV6
}
return addRouteToNonVPNIntf(prefix, wgIface, nexthop, intf)
},
removeFromRouteTable,
)
return setupHooks(*routeManager, initAddresses)
}
func cleanupRoutingWithRouteManager(routeManager *RouteManager) error {
if routeManager == nil {
return nil
}
// TODO: Remove hooks selectively
nbnet.RemoveDialerHooks()
nbnet.RemoveListenerHooks()
if err := routeManager.Flush(); err != nil {
return fmt.Errorf("flush route manager: %w", err)
}
return nil
}
func setupHooks(routeManager *RouteManager, initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
prefix, err := getPrefixFromIP(ip)
if err != nil {
return fmt.Errorf("convert ip to prefix: %w", err)
}
if err := routeManager.AddRouteRef(connID, *prefix); err != nil {
return fmt.Errorf("adding route reference: %v", err)
}
return nil
}
afterHook := func(connID nbnet.ConnectionID) error {
if err := routeManager.RemoveRouteRef(connID); err != nil {
return fmt.Errorf("remove route reference: %w", err)
}
return nil
}
for _, ip := range initAddresses {
if err := beforeHook("init", ip); err != nil {
log.Errorf("Failed to add route reference: %v", err)
}
}
nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error {
if ctx.Err() != nil {
return ctx.Err()
}
var result *multierror.Error
for _, ip := range resolvedIPs {
result = multierror.Append(result, beforeHook(connID, ip.IP))
}
return result.ErrorOrNil()
})
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
return afterHook(connID)
})
nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error {
return beforeHook(connID, ip.IP)
})
nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error {
return afterHook(connID)
})
return beforeHook, afterHook, nil
}

View File

@ -0,0 +1,18 @@
//go:build darwin || dragonfly || netbsd || openbsd
package systemops
import "syscall"
// filterRoutesByFlags - return true if need to ignore such route message because it consists specific flags.
func filterRoutesByFlags(routeMessageFlags int) bool {
if routeMessageFlags&syscall.RTF_UP == 0 {
return true
}
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE|syscall.RTF_WASCLONED) != 0 {
return true
}
return false
}

View File

@ -0,0 +1,19 @@
//go:build: freebsd
package systemops
import "syscall"
// filterRoutesByFlags - return true if need to ignore such route message because it consists specific flags.
func filterRoutesByFlags(routeMessageFlags int) bool {
if routeMessageFlags&syscall.RTF_UP == 0 {
return true
}
// NOTE: syscall.RTF_WASCLONED deprecated in FreeBSD 8.0 (https://www.freebsd.org/releases/8.0R/relnotes-detailed/)
// a concept of cloned route (a route generated by an entry with RTF_CLONING flag) is deprecated.
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE) != 0 {
return true
}
return false
}

View File

@ -0,0 +1,27 @@
package systemops
import (
"net"
"net/netip"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/iface"
)
type Nexthop struct {
IP netip.Addr
Intf *net.Interface
}
type ExclusionCounter = refcounter.Counter[any, Nexthop]
type SysOps struct {
refCounter *ExclusionCounter
wgInterface *iface.WGIface
}
func NewSysOps(wgInterface *iface.WGIface) *SysOps {
return &SysOps{
wgInterface: wgInterface,
}
}

View File

@ -1,6 +1,6 @@
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
package routemanager
package systemops
import (
"errors"
@ -43,8 +43,7 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
return nil, fmt.Errorf("unexpected RIB message type: %d", m.Type)
}
if m.Flags&syscall.RTF_UP == 0 ||
m.Flags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE|syscall.RTF_WASCLONED) != 0 {
if filterRoutesByFlags(m.Flags) {
continue
}
@ -93,7 +92,7 @@ func toNetIP(a route.Addr) netip.Addr {
case *route.Inet6Addr:
ip := netip.AddrFrom16(t.IP)
if t.ZoneID != 0 {
ip.WithZone(strconv.Itoa(t.ZoneID))
ip = ip.WithZone(strconv.Itoa(t.ZoneID))
}
return ip
default:
@ -101,6 +100,7 @@ func toNetIP(a route.Addr) netip.Addr {
}
}
// ones returns the number of leading ones in the mask.
func ones(a route.Addr) (int, error) {
switch t := a.(type) {
case *route.Inet4Addr:
@ -114,6 +114,7 @@ func ones(a route.Addr) (int, error) {
}
}
// MsgToRoute converts a route message to a Route.
func MsgToRoute(msg *route.RouteMessage) (*Route, error) {
dstIP, nexthop, dstMask := msg.Addrs[0], msg.Addrs[1], msg.Addrs[2]

View File

@ -1,6 +1,6 @@
//go:build !ios
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
package routemanager
package systemops
import (
"fmt"
@ -13,6 +13,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/route"
)
var expectedVPNint = "utun100"
@ -35,13 +36,15 @@ func TestConcurrentRoutes(t *testing.T) {
baseIP := netip.MustParseAddr("192.0.2.0")
intf := &net.Interface{Name: "lo0"}
r := NewSysOps(nil)
var wg sync.WaitGroup
for i := 0; i < 1024; i++ {
wg.Add(1)
go func(ip netip.Addr) {
defer wg.Done()
prefix := netip.PrefixFrom(ip, 32)
if err := addToRouteTable(prefix, netip.Addr{}, intf); err != nil {
if err := r.addToRouteTable(prefix, Nexthop{netip.Addr{}, intf}); err != nil {
t.Errorf("Failed to add route for %s: %v", prefix, err)
}
}(baseIP)
@ -57,7 +60,7 @@ func TestConcurrentRoutes(t *testing.T) {
go func(ip netip.Addr) {
defer wg.Done()
prefix := netip.PrefixFrom(ip, 32)
if err := removeFromRouteTable(prefix, netip.Addr{}, intf); err != nil {
if err := r.removeFromRouteTable(prefix, Nexthop{netip.Addr{}, intf}); err != nil {
t.Errorf("Failed to remove route for %s: %v", prefix, err)
}
}(baseIP)
@ -67,6 +70,53 @@ func TestConcurrentRoutes(t *testing.T) {
wg.Wait()
}
func TestBits(t *testing.T) {
tests := []struct {
name string
addr route.Addr
want int
wantErr bool
}{
{
name: "IPv4 all ones",
addr: &route.Inet4Addr{IP: [4]byte{255, 255, 255, 255}},
want: 32,
},
{
name: "IPv4 normal mask",
addr: &route.Inet4Addr{IP: [4]byte{255, 255, 255, 0}},
want: 24,
},
{
name: "IPv6 all ones",
addr: &route.Inet6Addr{IP: [16]byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}},
want: 128,
},
{
name: "IPv6 normal mask",
addr: &route.Inet6Addr{IP: [16]byte{255, 255, 255, 255, 255, 255, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0}},
want: 64,
},
{
name: "Unsupported type",
addr: &route.LinkAddr{},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ones(tt.addr)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.want, got)
}
})
}
}
func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string {
t.Helper()

View File

@ -0,0 +1,473 @@
//go:build !android && !ios
package systemops
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"runtime"
"strconv"
"github.com/hashicorp/go-multierror"
"github.com/libp2p/go-netroute"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
"github.com/netbirdio/netbird/iface"
nbnet "github.com/netbirdio/netbird/util/net"
)
var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1)
var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
var ErrRoutingIsSeparate = errors.New("routing is separate")
func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
log.Errorf("Unable to get initial v4 default next hop: %v", err)
}
initialNextHopV6, err := GetNextHop(netip.IPv6Unspecified())
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
log.Errorf("Unable to get initial v6 default next hop: %v", err)
}
refCounter := refcounter.New(
func(prefix netip.Prefix, _ any) (Nexthop, error) {
initialNexthop := initialNextHopV4
if prefix.Addr().Is6() {
initialNexthop = initialNextHopV6
}
nexthop, err := r.addRouteToNonVPNIntf(prefix, r.wgInterface, initialNexthop)
if errors.Is(err, vars.ErrRouteNotAllowed) || errors.Is(err, vars.ErrRouteNotFound) {
log.Tracef("Adding for prefix %s: %v", prefix, err)
// These errors are not critical but also we should not track and try to remove the routes either.
return nexthop, refcounter.ErrIgnore
}
return nexthop, err
},
r.removeFromRouteTable,
)
r.refCounter = refCounter
return r.setupHooks(initAddresses)
}
func (r *SysOps) cleanupRefCounter() error {
if r.refCounter == nil {
return nil
}
// TODO: Remove hooks selectively
nbnet.RemoveDialerHooks()
nbnet.RemoveListenerHooks()
if err := r.refCounter.Flush(); err != nil {
return fmt.Errorf("flush route manager: %w", err)
}
return nil
}
// TODO: fix: for default our wg address now appears as the default gw
func (r *SysOps) addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
addr := netip.IPv4Unspecified()
if prefix.Addr().Is6() {
addr = netip.IPv6Unspecified()
}
nexthop, err := GetNextHop(addr)
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
return fmt.Errorf("get existing route gateway: %s", err)
}
if !prefix.Contains(nexthop.IP) {
log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", nexthop.IP, prefix)
return nil
}
gatewayPrefix := netip.PrefixFrom(nexthop.IP, 32)
if nexthop.IP.Is6() {
gatewayPrefix = netip.PrefixFrom(nexthop.IP, 128)
}
ok, err := existsInRouteTable(gatewayPrefix)
if err != nil {
return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err)
}
if ok {
log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix)
return nil
}
nexthop, err = GetNextHop(nexthop.IP)
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
}
log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, nexthop.IP)
return r.addToRouteTable(gatewayPrefix, nexthop)
}
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
// If the next hop or interface is pointing to the VPN interface, it will return the initial values.
func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIface, initialNextHop Nexthop) (Nexthop, error) {
addr := prefix.Addr()
switch {
case addr.IsLoopback(),
addr.IsLinkLocalUnicast(),
addr.IsLinkLocalMulticast(),
addr.IsInterfaceLocalMulticast(),
addr.IsUnspecified(),
addr.IsMulticast():
return Nexthop{}, vars.ErrRouteNotAllowed
}
// Determine the exit interface and next hop for the prefix, so we can add a specific route
nexthop, err := GetNextHop(addr)
if err != nil {
return Nexthop{}, fmt.Errorf("get next hop: %w", err)
}
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop.IP, prefix, nexthop.IP)
exitNextHop := Nexthop{
IP: nexthop.IP,
Intf: nexthop.Intf,
}
vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP)
if !ok {
return Nexthop{}, fmt.Errorf("failed to convert vpn address to netip.Addr")
}
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() {
log.Debugf("Route for prefix %s is pointing to the VPN interface, using initial next hop %v", prefix, initialNextHop)
exitNextHop = initialNextHop
}
log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop.IP)
if err := r.addToRouteTable(prefix, exitNextHop); err != nil {
return Nexthop{}, fmt.Errorf("add route to table: %w", err)
}
return exitNextHop, nil
}
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
// in two /1 prefixes to avoid replacing the existing default route
func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
nextHop := Nexthop{netip.Addr{}, intf}
if prefix == vars.Defaultv4 {
if err := r.addToRouteTable(splitDefaultv4_1, nextHop); err != nil {
return err
}
if err := r.addToRouteTable(splitDefaultv4_2, nextHop); err != nil {
if err2 := r.removeFromRouteTable(splitDefaultv4_1, nextHop); err2 != nil {
log.Warnf("Failed to rollback route addition: %s", err2)
}
return err
}
// TODO: remove once IPv6 is supported on the interface
if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil {
return fmt.Errorf("add unreachable route split 1: %w", err)
}
if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil {
if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil {
log.Warnf("Failed to rollback route addition: %s", err2)
}
return fmt.Errorf("add unreachable route split 2: %w", err)
}
return nil
} else if prefix == vars.Defaultv6 {
if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil {
return fmt.Errorf("add unreachable route split 1: %w", err)
}
if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil {
if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil {
log.Warnf("Failed to rollback route addition: %s", err2)
}
return fmt.Errorf("add unreachable route split 2: %w", err)
}
return nil
}
return r.addNonExistingRoute(prefix, intf)
}
// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table
func (r *SysOps) addNonExistingRoute(prefix netip.Prefix, intf *net.Interface) error {
ok, err := existsInRouteTable(prefix)
if err != nil {
return fmt.Errorf("exists in route table: %w", err)
}
if ok {
log.Warnf("Skipping adding a new route for network %s because it already exists", prefix)
return nil
}
ok, err = isSubRange(prefix)
if err != nil {
return fmt.Errorf("sub range: %w", err)
}
if ok {
if err := r.addRouteForCurrentDefaultGateway(prefix); err != nil {
log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err)
}
}
return r.addToRouteTable(prefix, Nexthop{netip.Addr{}, intf})
}
// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given,
// it will remove the split /1 prefixes
func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
nextHop := Nexthop{netip.Addr{}, intf}
if prefix == vars.Defaultv4 {
var result *multierror.Error
if err := r.removeFromRouteTable(splitDefaultv4_1, nextHop); err != nil {
result = multierror.Append(result, err)
}
if err := r.removeFromRouteTable(splitDefaultv4_2, nextHop); err != nil {
result = multierror.Append(result, err)
}
// TODO: remove once IPv6 is supported on the interface
if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil {
result = multierror.Append(result, err)
}
if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil {
result = multierror.Append(result, err)
}
return nberrors.FormatErrorOrNil(result)
} else if prefix == vars.Defaultv6 {
var result *multierror.Error
if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil {
result = multierror.Append(result, err)
}
if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil {
result = multierror.Append(result, err)
}
return nberrors.FormatErrorOrNil(result)
}
return r.removeFromRouteTable(prefix, nextHop)
}
func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
prefix, err := util.GetPrefixFromIP(ip)
if err != nil {
return fmt.Errorf("convert ip to prefix: %w", err)
}
if _, err := r.refCounter.IncrementWithID(string(connID), prefix, nil); err != nil {
return fmt.Errorf("adding route reference: %v", err)
}
return nil
}
afterHook := func(connID nbnet.ConnectionID) error {
if err := r.refCounter.DecrementWithID(string(connID)); err != nil {
return fmt.Errorf("remove route reference: %w", err)
}
return nil
}
for _, ip := range initAddresses {
if err := beforeHook("init", ip); err != nil {
log.Errorf("Failed to add route reference: %v", err)
}
}
nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error {
if ctx.Err() != nil {
return ctx.Err()
}
var result *multierror.Error
for _, ip := range resolvedIPs {
result = multierror.Append(result, beforeHook(connID, ip.IP))
}
return nberrors.FormatErrorOrNil(result)
})
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
return afterHook(connID)
})
nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error {
return beforeHook(connID, ip.IP)
})
nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error {
return afterHook(connID)
})
return beforeHook, afterHook, nil
}
func GetNextHop(ip netip.Addr) (Nexthop, error) {
r, err := netroute.New()
if err != nil {
return Nexthop{}, fmt.Errorf("new netroute: %w", err)
}
intf, gateway, preferredSrc, err := r.Route(ip.AsSlice())
if err != nil {
log.Debugf("Failed to get route for %s: %v", ip, err)
return Nexthop{}, vars.ErrRouteNotFound
}
log.Debugf("Route for %s: interface %v nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc)
if gateway == nil {
if runtime.GOOS == "freebsd" {
return Nexthop{Intf: intf}, nil
}
if preferredSrc == nil {
return Nexthop{}, vars.ErrRouteNotFound
}
log.Debugf("No next hop found for IP %s, using preferred source %s", ip, preferredSrc)
addr, err := ipToAddr(preferredSrc, intf)
if err != nil {
return Nexthop{}, fmt.Errorf("convert preferred source to address: %w", err)
}
return Nexthop{
IP: addr,
Intf: intf,
}, nil
}
addr, err := ipToAddr(gateway, intf)
if err != nil {
return Nexthop{}, fmt.Errorf("convert gateway to address: %w", err)
}
return Nexthop{
IP: addr,
Intf: intf,
}, nil
}
// converts a net.IP to a netip.Addr including the zone based on the passed interface
func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
addr, ok := netip.AddrFromSlice(ip)
if !ok {
return netip.Addr{}, fmt.Errorf("failed to convert IP address to netip.Addr: %s", ip)
}
if intf != nil && (addr.IsLinkLocalMulticast() || addr.IsLinkLocalUnicast()) {
zone := intf.Name
if runtime.GOOS == "windows" {
zone = strconv.Itoa(intf.Index)
}
log.Tracef("Adding zone %s to address %s", zone, addr)
addr = addr.WithZone(zone)
}
return addr.Unmap(), nil
}
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
routes, err := getRoutesFromTable()
if err != nil {
return false, fmt.Errorf("get routes from table: %w", err)
}
for _, tableRoute := range routes {
if tableRoute == prefix {
return true, nil
}
}
return false, nil
}
func isSubRange(prefix netip.Prefix) (bool, error) {
routes, err := getRoutesFromTable()
if err != nil {
return false, fmt.Errorf("get routes from table: %w", err)
}
for _, tableRoute := range routes {
if tableRoute.Bits() > vars.MinRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() {
return true, nil
}
}
return false, nil
}
// IsAddrRouted checks if the candidate address would route to the vpn, in which case it returns true and the matched prefix.
func IsAddrRouted(addr netip.Addr, vpnRoutes []netip.Prefix) (bool, netip.Prefix) {
localRoutes, err := hasSeparateRouting()
if err != nil {
if !errors.Is(err, ErrRoutingIsSeparate) {
log.Errorf("Failed to get routes: %v", err)
}
return false, netip.Prefix{}
}
return isVpnRoute(addr, vpnRoutes, localRoutes)
}
func isVpnRoute(addr netip.Addr, vpnRoutes []netip.Prefix, localRoutes []netip.Prefix) (bool, netip.Prefix) {
vpnPrefixMap := map[netip.Prefix]struct{}{}
for _, prefix := range vpnRoutes {
vpnPrefixMap[prefix] = struct{}{}
}
// remove vpnRoute duplicates
for _, prefix := range localRoutes {
delete(vpnPrefixMap, prefix)
}
var longestPrefix netip.Prefix
var isVpn bool
combinedRoutes := make([]netip.Prefix, len(vpnRoutes)+len(localRoutes))
copy(combinedRoutes, vpnRoutes)
copy(combinedRoutes[len(vpnRoutes):], localRoutes)
for _, prefix := range combinedRoutes {
// Ignore the default route, it has special handling
if prefix.Bits() == 0 {
continue
}
if prefix.Contains(addr) {
// Longest prefix match
if !longestPrefix.IsValid() || prefix.Bits() > longestPrefix.Bits() {
longestPrefix = prefix
_, isVpn = vpnPrefixMap[prefix]
}
}
}
if !longestPrefix.IsValid() {
// No route matched
return false, netip.Prefix{}
}
// Return true if the longest matching prefix is from vpnRoutes
return isVpn, longestPrefix
}

View File

@ -1,6 +1,6 @@
//go:build !android && !ios
package routemanager
package systemops
import (
"bytes"
@ -49,6 +49,10 @@ func TestAddRemoveRoutes(t *testing.T) {
}
for n, testCase := range testCases {
// todo resolve test execution on freebsd
if runtime.GOOS == "freebsd" {
t.Skip("skipping ", testCase.name, " on freebsd")
}
t.Run(testCase.name, func(t *testing.T) {
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
@ -57,23 +61,26 @@ func TestAddRemoveRoutes(t *testing.T) {
if err != nil {
t.Fatal(err)
}
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil)
require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close()
err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface")
_, _, err = setupRouting(nil, wgInterface)
r := NewSysOps(wgInterface)
_, _, err = r.SetupRouting(nil)
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, cleanupRouting())
assert.NoError(t, r.CleanupRouting())
})
index, err := net.InterfaceByName(wgInterface.Name())
require.NoError(t, err, "InterfaceByName should not return err")
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
err = addVPNRoute(testCase.prefix, intf)
err = r.AddVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "genericAddVPNRoute should not return err")
if testCase.shouldRouteToWireguard {
@ -84,19 +91,19 @@ func TestAddRemoveRoutes(t *testing.T) {
exists, err := existsInRouteTable(testCase.prefix)
require.NoError(t, err, "existsInRouteTable should not return err")
if exists && testCase.shouldRouteToWireguard {
err = removeVPNRoute(testCase.prefix, intf)
err = r.RemoveVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "genericRemoveVPNRoute should not return err")
prefixGateway, _, err := GetNextHop(testCase.prefix.Addr())
prefixNexthop, err := GetNextHop(testCase.prefix.Addr())
require.NoError(t, err, "GetNextHop should not return err")
internetGateway, _, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
internetNexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
require.NoError(t, err)
if testCase.shouldBeRemoved {
require.Equal(t, internetGateway, prefixGateway, "route should be pointing to default internet gateway")
require.Equal(t, internetNexthop.IP, prefixNexthop.IP, "route should be pointing to default internet gateway")
} else {
require.NotEqual(t, internetGateway, prefixGateway, "route should be pointing to a different gateway than the internet gateway")
require.NotEqual(t, internetNexthop.IP, prefixNexthop.IP, "route should be pointing to a different gateway than the internet gateway")
}
}
})
@ -104,11 +111,14 @@ func TestAddRemoveRoutes(t *testing.T) {
}
func TestGetNextHop(t *testing.T) {
gateway, _, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
if runtime.GOOS == "freebsd" {
t.Skip("skipping on freebsd")
}
nexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
if err != nil {
t.Fatal("shouldn't return error when fetching the gateway: ", err)
}
if !gateway.IsValid() {
if !nexthop.IP.IsValid() {
t.Fatal("should return a gateway")
}
addresses, err := net.InterfaceAddrs()
@ -130,24 +140,24 @@ func TestGetNextHop(t *testing.T) {
}
}
localIP, _, err := GetNextHop(testingPrefix.Addr())
localIP, err := GetNextHop(testingPrefix.Addr())
if err != nil {
t.Fatal("shouldn't return error: ", err)
}
if !localIP.IsValid() {
if !localIP.IP.IsValid() {
t.Fatal("should return a gateway for local network")
}
if localIP.String() == gateway.String() {
t.Fatal("local ip should not match with gateway IP")
if localIP.IP.String() == nexthop.IP.String() {
t.Fatal("local IP should not match with gateway IP")
}
if localIP.String() != testingIP {
t.Fatalf("local ip should match with testing IP: want %s got %s", testingIP, localIP.String())
if localIP.IP.String() != testingIP {
t.Fatalf("local IP should match with testing IP: want %s got %s", testingIP, localIP.IP.String())
}
}
func TestAddExistAndRemoveRoute(t *testing.T) {
defaultGateway, _, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
t.Log("defaultGateway: ", defaultGateway)
defaultNexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
t.Log("defaultNexthop: ", defaultNexthop)
if err != nil {
t.Fatal("shouldn't return error when fetching the gateway: ", err)
}
@ -164,7 +174,7 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
},
{
name: "Should Not Add Route if overlaps with default gateway",
prefix: netip.MustParsePrefix(defaultGateway.String() + "/31"),
prefix: netip.MustParsePrefix(defaultNexthop.IP.String() + "/31"),
shouldAddRoute: false,
},
{
@ -203,7 +213,7 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
if err != nil {
t.Fatal(err)
}
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil)
require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close()
@ -214,14 +224,16 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
require.NoError(t, err, "InterfaceByName should not return err")
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
r := NewSysOps(wgInterface)
// Prepare the environment
if testCase.preExistingPrefix.IsValid() {
err := addVPNRoute(testCase.preExistingPrefix, intf)
err := r.AddVPNRoute(testCase.preExistingPrefix, intf)
require.NoError(t, err, "should not return err when adding pre-existing route")
}
// Add the route
err = addVPNRoute(testCase.prefix, intf)
err = r.AddVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "should not return err when adding route")
if testCase.shouldAddRoute {
@ -231,7 +243,7 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
require.True(t, ok, "route should exist")
// remove route again if added
err = removeVPNRoute(testCase.prefix, intf)
err = r.RemoveVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "should not return err")
}
@ -295,19 +307,22 @@ func TestExistsInRouteTable(t *testing.T) {
var addressPrefixes []netip.Prefix
for _, address := range addresses {
p := netip.MustParsePrefix(address.String())
if p.Addr().Is6() {
continue
}
// Windows sometimes has hidden interface link local addrs that don't turn up on any interface
if runtime.GOOS == "windows" && p.Addr().IsLinkLocalUnicast() {
continue
}
// Linux loopback 127/8 is in the local table, not in the main table and always takes precedence
if runtime.GOOS == "linux" && p.Addr().IsLoopback() {
continue
}
addressPrefixes = append(addressPrefixes, p.Masked())
switch {
case p.Addr().Is6():
continue
// Windows sometimes has hidden interface link local addrs that don't turn up on any interface
case runtime.GOOS == "windows" && p.Addr().IsLinkLocalUnicast():
continue
// Linux loopback 127/8 is in the local table, not in the main table and always takes precedence
case runtime.GOOS == "linux" && p.Addr().IsLoopback():
continue
// FreeBSD loopback 127/8 is not added to the routing table
case runtime.GOOS == "freebsd" && p.Addr().IsLoopback():
continue
default:
addressPrefixes = append(addressPrefixes, p.Masked())
}
}
for _, prefix := range addressPrefixes {
@ -330,7 +345,7 @@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen
newNet, err := stdnet.NewNet()
require.NoError(t, err)
wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil)
require.NoError(t, err, "should create testing WireGuard interface")
err = wgInterface.Create()
@ -343,65 +358,52 @@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen
return wgInterface
}
func setupRouteAndCleanup(t *testing.T, r *SysOps, prefix netip.Prefix, intf *net.Interface) {
t.Helper()
err := r.AddVPNRoute(prefix, intf)
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = r.RemoveVPNRoute(prefix, intf)
assert.NoError(t, err, "removeVPNRoute should not return err")
})
}
func setupTestEnv(t *testing.T) {
t.Helper()
setupDummyInterfacesAndRoutes(t)
wgIface := createWGInterface(t, expectedVPNint, "100.64.0.1/24", 51820)
wgInterface := createWGInterface(t, expectedVPNint, "100.64.0.1/24", 51820)
t.Cleanup(func() {
assert.NoError(t, wgIface.Close())
assert.NoError(t, wgInterface.Close())
})
_, _, err := setupRouting(nil, wgIface)
r := NewSysOps(wgInterface)
_, _, err := r.SetupRouting(nil)
require.NoError(t, err, "setupRouting should not return err")
t.Cleanup(func() {
assert.NoError(t, cleanupRouting())
assert.NoError(t, r.CleanupRouting())
})
index, err := net.InterfaceByName(wgIface.Name())
index, err := net.InterfaceByName(wgInterface.Name())
require.NoError(t, err, "InterfaceByName should not return err")
intf := &net.Interface{Index: index.Index, Name: wgIface.Name()}
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
// default route exists in main table and vpn table
err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), intf)
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), intf)
assert.NoError(t, err, "removeVPNRoute should not return err")
})
setupRouteAndCleanup(t, r, netip.MustParsePrefix("0.0.0.0/0"), intf)
// 10.0.0.0/8 route exists in main table and vpn table
err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), intf)
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), intf)
assert.NoError(t, err, "removeVPNRoute should not return err")
})
setupRouteAndCleanup(t, r, netip.MustParsePrefix("10.0.0.0/8"), intf)
// 10.10.0.0/24 more specific route exists in vpn table
err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), intf)
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), intf)
assert.NoError(t, err, "removeVPNRoute should not return err")
})
setupRouteAndCleanup(t, r, netip.MustParsePrefix("10.10.0.0/24"), intf)
// 127.0.10.0/24 more specific route exists in vpn table
err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), intf)
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), intf)
assert.NoError(t, err, "removeVPNRoute should not return err")
})
setupRouteAndCleanup(t, r, netip.MustParsePrefix("127.0.10.0/24"), intf)
// unique route in vpn table
err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), intf)
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), intf)
assert.NoError(t, err, "removeVPNRoute should not return err")
})
setupRouteAndCleanup(t, r, netip.MustParsePrefix("172.16.0.0/12"), intf)
}
func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) {
@ -410,11 +412,133 @@ func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIf
return
}
prefixGateway, _, err := GetNextHop(prefix.Addr())
prefixNexthop, err := GetNextHop(prefix.Addr())
require.NoError(t, err, "GetNextHop should not return err")
if invert {
assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP")
assert.NotEqual(t, wgIface.Address().IP.String(), prefixNexthop.IP.String(), "route should not point to wireguard interface IP")
} else {
assert.Equal(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP")
assert.Equal(t, wgIface.Address().IP.String(), prefixNexthop.IP.String(), "route should point to wireguard interface IP")
}
}
func TestIsVpnRoute(t *testing.T) {
tests := []struct {
name string
addr string
vpnRoutes []string
localRoutes []string
expectedVpn bool
expectedPrefix netip.Prefix
}{
{
name: "Match in VPN routes",
addr: "192.168.1.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: true,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
{
name: "Match in local routes",
addr: "10.1.1.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: false,
expectedPrefix: netip.MustParsePrefix("10.0.0.0/8"),
},
{
name: "No match",
addr: "172.16.0.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: false,
expectedPrefix: netip.Prefix{},
},
{
name: "Default route ignored",
addr: "192.168.1.1",
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: true,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
{
name: "Default route matches but ignored",
addr: "172.16.1.1",
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: false,
expectedPrefix: netip.Prefix{},
},
{
name: "Longest prefix match local",
addr: "192.168.1.1",
vpnRoutes: []string{"192.168.0.0/16"},
localRoutes: []string{"192.168.1.0/24"},
expectedVpn: false,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
{
name: "Longest prefix match local multiple",
addr: "192.168.0.1",
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26", "192.168.0.0/28"},
expectedVpn: false,
expectedPrefix: netip.MustParsePrefix("192.168.0.0/28"),
},
{
name: "Longest prefix match vpn",
addr: "192.168.1.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"192.168.0.0/16"},
expectedVpn: true,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
{
name: "Longest prefix match vpn multiple",
addr: "192.168.0.1",
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26"},
expectedVpn: true,
expectedPrefix: netip.MustParsePrefix("192.168.0.0/27"),
},
{
name: "Duplicate prefix in both",
addr: "192.168.1.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"192.168.1.0/24"},
expectedVpn: false,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
addr, err := netip.ParseAddr(tt.addr)
if err != nil {
t.Fatalf("Failed to parse address %s: %v", tt.addr, err)
}
var vpnRoutes, localRoutes []netip.Prefix
for _, route := range tt.vpnRoutes {
prefix, err := netip.ParsePrefix(route)
if err != nil {
t.Fatalf("Failed to parse VPN route %s: %v", route, err)
}
vpnRoutes = append(vpnRoutes, prefix)
}
for _, route := range tt.localRoutes {
prefix, err := netip.ParsePrefix(route)
if err != nil {
t.Fatalf("Failed to parse local route %s: %v", route, err)
}
localRoutes = append(localRoutes, prefix)
}
isVpn, matchedPrefix := isVpnRoute(addr, vpnRoutes, localRoutes)
assert.Equal(t, tt.expectedVpn, isVpn, "isVpnRoute should return expectedVpn value")
assert.Equal(t, tt.expectedPrefix, matchedPrefix, "isVpnRoute should return expectedVpn prefix")
})
}
}

View File

@ -1,6 +1,6 @@
//go:build !android
package routemanager
package systemops
import (
"bufio"
@ -9,16 +9,15 @@ import (
"net"
"net/netip"
"os"
"strconv"
"strings"
"syscall"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/routemanager/sysctl"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
nbnet "github.com/netbirdio/netbird/util/net"
)
@ -33,16 +32,10 @@ const (
// ipv4ForwardingPath is the path to the file containing the IP forwarding setting.
ipv4ForwardingPath = "net.ipv4.ip_forward"
rpFilterPath = "net.ipv4.conf.all.rp_filter"
rpFilterInterfacePath = "net.ipv4.conf.%s.rp_filter"
srcValidMarkPath = "net.ipv4.conf.all.src_valid_mark"
)
var ErrTableIDExists = errors.New("ID exists with different name")
var routeManager = &RouteManager{}
// originalSysctl stores the original sysctl values before they are modified
var originalSysctl map[string]int
@ -82,7 +75,7 @@ func getSetupRules() []ruleParams {
}
}
// setupRouting establishes the routing configuration for the VPN, including essential rules
// SetupRouting establishes the routing configuration for the VPN, including essential rules
// to ensure proper traffic flow for management, locally configured routes, and VPN traffic.
//
// Rule 1 (Main Route Precedence): Safeguards locally installed routes by giving them precedence over
@ -92,17 +85,17 @@ func getSetupRules() []ruleParams {
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
// This table is where a default route or other specific routes received from the management server are configured,
// enabling VPN connectivity.
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) {
func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ nbnet.AddHookFunc, _ nbnet.RemoveHookFunc, err error) {
if isLegacy() {
log.Infof("Using legacy routing setup")
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
return r.setupRefCounter(initAddresses)
}
if err = addRoutingTableName(); err != nil {
log.Errorf("Error adding routing table name: %v", err)
}
originalValues, err := setupSysctl(wgIface)
originalValues, err := sysctl.Setup(r.wgInterface)
if err != nil {
log.Errorf("Error setting up sysctl: %v", err)
sysctlFailed = true
@ -111,7 +104,7 @@ func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.Before
defer func() {
if err != nil {
if cleanErr := cleanupRouting(); cleanErr != nil {
if cleanErr := r.CleanupRouting(); cleanErr != nil {
log.Errorf("Error cleaning up routing: %v", cleanErr)
}
}
@ -123,7 +116,7 @@ func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.Before
if errors.Is(err, syscall.EOPNOTSUPP) {
log.Warnf("Rule operations are not supported, falling back to the legacy routing setup")
setIsLegacy(true)
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
return r.setupRefCounter(initAddresses)
}
return nil, nil, fmt.Errorf("%s: %w", rule.description, err)
}
@ -132,12 +125,12 @@ func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.Before
return nil, nil, nil
}
// cleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.
// CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.
// It systematically removes the three rules and any associated routing table entries to ensure a clean state.
// The function uses error aggregation to report any errors encountered during the cleanup process.
func cleanupRouting() error {
func (r *SysOps) CleanupRouting() error {
if isLegacy() {
return cleanupRoutingWithRouteManager(routeManager)
return r.cleanupRefCounter()
}
var result *multierror.Error
@ -156,58 +149,58 @@ func cleanupRouting() error {
}
}
if err := cleanupSysctl(originalSysctl); err != nil {
if err := sysctl.Cleanup(originalSysctl); err != nil {
result = multierror.Append(result, fmt.Errorf("cleanup sysctl: %w", err))
}
originalSysctl = nil
sysctlFailed = false
return result.ErrorOrNil()
return nberrors.FormatErrorOrNil(result)
}
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
return addRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN)
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
return addRoute(prefix, nexthop, syscall.RT_TABLE_MAIN)
}
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
return removeRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN)
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
return removeRoute(prefix, nexthop, syscall.RT_TABLE_MAIN)
}
func addVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if isLegacy() {
return genericAddVPNRoute(prefix, intf)
return r.genericAddVPNRoute(prefix, intf)
}
if sysctlFailed && (prefix == defaultv4 || prefix == defaultv6) {
if sysctlFailed && (prefix == vars.Defaultv4 || prefix == vars.Defaultv6) {
log.Warnf("Default route is configured but sysctl operations failed, VPN traffic may not be routed correctly, consider using NB_USE_LEGACY_ROUTING=true or setting net.ipv4.conf.*.rp_filter to 2 (loose) or 0 (off)")
}
// No need to check if routes exist as main table takes precedence over the VPN table via Rule 1
// TODO remove this once we have ipv6 support
if prefix == defaultv4 {
if err := addUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil {
if prefix == vars.Defaultv4 {
if err := addUnreachableRoute(vars.Defaultv6, NetbirdVPNTableID); err != nil {
return fmt.Errorf("add blackhole: %w", err)
}
}
if err := addRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil {
if err := addRoute(prefix, Nexthop{netip.Addr{}, intf}, NetbirdVPNTableID); err != nil {
return fmt.Errorf("add route: %w", err)
}
return nil
}
func removeVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if isLegacy() {
return genericRemoveVPNRoute(prefix, intf)
return r.genericRemoveVPNRoute(prefix, intf)
}
// TODO remove this once we have ipv6 support
if prefix == defaultv4 {
if err := removeUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil {
if prefix == vars.Defaultv4 {
if err := removeUnreachableRoute(vars.Defaultv6, NetbirdVPNTableID); err != nil {
return fmt.Errorf("remove unreachable route: %w", err)
}
}
if err := removeRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil {
if err := removeRoute(prefix, Nexthop{netip.Addr{}, intf}, NetbirdVPNTableID); err != nil {
return fmt.Errorf("remove route: %w", err)
}
return nil
@ -255,7 +248,7 @@ func getRoutes(tableID, family int) ([]netip.Prefix, error) {
}
// addRoute adds a route to a specific routing table identified by tableID.
func addRoute(prefix netip.Prefix, addr netip.Addr, intf *net.Interface, tableID int) error {
func addRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error {
route := &netlink.Route{
Scope: netlink.SCOPE_UNIVERSE,
Table: tableID,
@ -268,7 +261,7 @@ func addRoute(prefix netip.Prefix, addr netip.Addr, intf *net.Interface, tableID
}
route.Dst = ipNet
if err := addNextHop(addr, intf, route); err != nil {
if err := addNextHop(nexthop, route); err != nil {
return fmt.Errorf("add gateway and device: %w", err)
}
@ -327,7 +320,7 @@ func removeUnreachableRoute(prefix netip.Prefix, tableID int) error {
}
// removeRoute removes a route from a specific routing table identified by tableID.
func removeRoute(prefix netip.Prefix, addr netip.Addr, intf *net.Interface, tableID int) error {
func removeRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error {
_, ipNet, err := net.ParseCIDR(prefix.String())
if err != nil {
return fmt.Errorf("parse prefix %s: %w", prefix, err)
@ -340,7 +333,7 @@ func removeRoute(prefix netip.Prefix, addr netip.Addr, intf *net.Interface, tabl
Dst: ipNet,
}
if err := addNextHop(addr, intf, route); err != nil {
if err := addNextHop(nexthop, route); err != nil {
return fmt.Errorf("add gateway and device: %w", err)
}
@ -373,11 +366,11 @@ func flushRoutes(tableID, family int) error {
}
}
return result.ErrorOrNil()
return nberrors.FormatErrorOrNil(result)
}
func enableIPForwarding() error {
_, err := setSysctl(ipv4ForwardingPath, 1, false)
func EnableIPForwarding() error {
_, err := sysctl.Set(ipv4ForwardingPath, 1, false)
return err
}
@ -481,19 +474,19 @@ func removeRule(params ruleParams) error {
}
// addNextHop adds the gateway and device to the route.
func addNextHop(addr netip.Addr, intf *net.Interface, route *netlink.Route) error {
if intf != nil {
route.LinkIndex = intf.Index
func addNextHop(nexthop Nexthop, route *netlink.Route) error {
if nexthop.Intf != nil {
route.LinkIndex = nexthop.Intf.Index
}
if addr.IsValid() {
route.Gw = addr.AsSlice()
if nexthop.IP.IsValid() {
route.Gw = nexthop.IP.AsSlice()
// if zone is set, it means the gateway is a link-local address, so we set the link index
if addr.Zone() != "" && intf == nil {
link, err := netlink.LinkByName(addr.Zone())
if nexthop.IP.Zone() != "" && nexthop.Intf == nil {
link, err := netlink.LinkByName(nexthop.IP.Zone())
if err != nil {
return fmt.Errorf("get link by name for zone %s: %w", addr.Zone(), err)
return fmt.Errorf("get link by name for zone %s: %w", nexthop.IP.Zone(), err)
}
route.LinkIndex = link.Attrs().Index
}
@ -509,82 +502,9 @@ func getAddressFamily(prefix netip.Prefix) int {
return netlink.FAMILY_V6
}
// setupSysctl configures sysctl settings for RP filtering and source validation.
func setupSysctl(wgIface *iface.WGIface) (map[string]int, error) {
keys := map[string]int{}
var result *multierror.Error
oldVal, err := setSysctl(srcValidMarkPath, 1, false)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[srcValidMarkPath] = oldVal
func hasSeparateRouting() ([]netip.Prefix, error) {
if isLegacy() {
return getRoutesFromTable()
}
oldVal, err = setSysctl(rpFilterPath, 2, true)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[rpFilterPath] = oldVal
}
interfaces, err := net.Interfaces()
if err != nil {
result = multierror.Append(result, fmt.Errorf("list interfaces: %w", err))
}
for _, intf := range interfaces {
if intf.Name == "lo" || wgIface != nil && intf.Name == wgIface.Name() {
continue
}
i := fmt.Sprintf(rpFilterInterfacePath, intf.Name)
oldVal, err := setSysctl(i, 2, true)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[i] = oldVal
}
}
return keys, result.ErrorOrNil()
}
// setSysctl sets a sysctl configuration, if onlyIfOne is true it will only set the new value if it's set to 1
func setSysctl(key string, desiredValue int, onlyIfOne bool) (int, error) {
path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/"))
currentValue, err := os.ReadFile(path)
if err != nil {
return -1, fmt.Errorf("read sysctl %s: %w", key, err)
}
currentV, err := strconv.Atoi(strings.TrimSpace(string(currentValue)))
if err != nil && len(currentValue) > 0 {
return -1, fmt.Errorf("convert current desiredValue to int: %w", err)
}
if currentV == desiredValue || onlyIfOne && currentV != 1 {
return currentV, nil
}
//nolint:gosec
if err := os.WriteFile(path, []byte(strconv.Itoa(desiredValue)), 0644); err != nil {
return currentV, fmt.Errorf("write sysctl %s: %w", key, err)
}
log.Debugf("Set sysctl %s from %d to %d", key, currentV, desiredValue)
return currentV, nil
}
func cleanupSysctl(originalSettings map[string]int) error {
var result *multierror.Error
for key, value := range originalSettings {
_, err := setSysctl(key, value, false)
if err != nil {
result = multierror.Append(result, err)
}
}
return result.ErrorOrNil()
return nil, ErrRoutingIsSeparate
}

View File

@ -1,6 +1,6 @@
//go:build !android
package routemanager
package systemops
import (
"errors"
@ -14,6 +14,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/vishvananda/netlink"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
)
var expectedVPNint = "wgtest0"
@ -138,7 +140,7 @@ func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, intf string) {
if dstIPNet.String() == "0.0.0.0/0" {
var err error
originalNexthop, originalLinkIndex, err = fetchOriginalGateway(netlink.FAMILY_V4)
if err != nil && !errors.Is(err, ErrRouteNotFound) {
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
t.Logf("Failed to fetch original gateway: %v", err)
}
@ -193,7 +195,7 @@ func fetchOriginalGateway(family int) (net.IP, int, error) {
}
}
return nil, 0, ErrRouteNotFound
return nil, 0, vars.ErrRouteNotFound
}
func setupDummyInterfacesAndRoutes(t *testing.T) {

View File

@ -0,0 +1,38 @@
//go:build ios || android
package systemops
import (
"net"
"net/netip"
"runtime"
log "github.com/sirupsen/logrus"
nbnet "github.com/netbirdio/netbird/util/net"
)
func (r *SysOps) SetupRouting([]net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
return nil, nil, nil
}
func (r *SysOps) CleanupRouting() error {
return nil
}
func (r *SysOps) AddVPNRoute(netip.Prefix, *net.Interface) error {
return nil
}
func (r *SysOps) RemoveVPNRoute(netip.Prefix, *net.Interface) error {
return nil
}
func EnableIPForwarding() error {
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil
}
func IsAddrRouted(netip.Addr, []netip.Prefix) (bool, netip.Prefix) {
return false, netip.Prefix{}
}

View File

@ -0,0 +1,28 @@
//go:build !linux && !ios
package systemops
import (
"net"
"net/netip"
"runtime"
log "github.com/sirupsen/logrus"
)
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
return r.genericAddVPNRoute(prefix, intf)
}
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
return r.genericRemoveVPNRoute(prefix, intf)
}
func EnableIPForwarding() error {
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil
}
func hasSeparateRouting() ([]netip.Prefix, error) {
return getRoutesFromTable()
}

View File

@ -1,6 +1,6 @@
//go:build darwin && !ios
//go:build (darwin && !ios) || dragonfly || freebsd || netbsd || openbsd
package routemanager
package systemops
import (
"fmt"
@ -13,43 +13,41 @@ import (
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
nbnet "github.com/netbirdio/netbird/util/net"
)
var routeManager *RouteManager
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
func (r *SysOps) SetupRouting(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
return r.setupRefCounter(initAddresses)
}
func cleanupRouting() error {
return cleanupRoutingWithRouteManager(routeManager)
func (r *SysOps) CleanupRouting() error {
return r.cleanupRefCounter()
}
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
return routeCmd("add", prefix, nexthop, intf)
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
return r.routeCmd("add", prefix, nexthop)
}
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
return routeCmd("delete", prefix, nexthop, intf)
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
return r.routeCmd("delete", prefix, nexthop)
}
func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
func (r *SysOps) routeCmd(action string, prefix netip.Prefix, nexthop Nexthop) error {
inet := "-inet"
network := prefix.String()
if prefix.IsSingleIP() {
network = prefix.Addr().String()
}
if prefix.Addr().Is6() {
inet = "-inet6"
}
network := prefix.String()
if prefix.IsSingleIP() {
network = prefix.Addr().String()
}
args := []string{"-n", action, inet, network}
if nexthop.IsValid() {
args = append(args, nexthop.Unmap().String())
} else if intf != nil {
args = append(args, "-interface", intf.Name)
if nexthop.IP.IsValid() {
args = append(args, nexthop.IP.Unmap().String())
} else if nexthop.Intf != nil {
args = append(args, "-interface", nexthop.Intf.Name)
}
if err := retryRouteCmd(args); err != nil {

View File

@ -1,10 +1,11 @@
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly
package routemanager
package systemops
import (
"fmt"
"net"
"runtime"
"strings"
"testing"
"time"
@ -85,6 +86,10 @@ var testCases = []testCase{
func TestRouting(t *testing.T) {
for _, tc := range testCases {
// todo resolve test execution on freebsd
if runtime.GOOS == "freebsd" {
t.Skip("skipping ", tc.name, " on freebsd")
}
t.Run(tc.name, func(t *testing.T) {
setupTestEnv(t)

View File

@ -1,6 +1,6 @@
//go:build windows
package routemanager
package systemops
import (
"fmt"
@ -17,8 +17,7 @@ import (
"github.com/yusufpapurcu/wmi"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
nbnet "github.com/netbirdio/netbird/util/net"
)
type MSFT_NetRoute struct {
@ -57,14 +56,42 @@ var prefixList []netip.Prefix
var lastUpdate time.Time
var mux = sync.Mutex{}
var routeManager *RouteManager
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
func (r *SysOps) SetupRouting(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
return r.setupRefCounter(initAddresses)
}
func cleanupRouting() error {
return cleanupRoutingWithRouteManager(routeManager)
func (r *SysOps) CleanupRouting() error {
return r.cleanupRefCounter()
}
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
if nexthop.IP.Zone() != "" && nexthop.Intf == nil {
zone, err := strconv.Atoi(nexthop.IP.Zone())
if err != nil {
return fmt.Errorf("invalid zone: %w", err)
}
nexthop.Intf = &net.Interface{Index: zone}
}
return addRouteCmd(prefix, nexthop)
}
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
args := []string{"delete", prefix.String()}
if nexthop.IP.IsValid() {
ip := nexthop.IP.WithZone("")
args = append(args, ip.Unmap().String())
}
routeCmd := uspfilter.GetSystem32Command("route")
out, err := exec.Command(routeCmd, args...).CombinedOutput()
log.Tracef("route %s: %s", strings.Join(args, " "), out)
if err != nil {
return fmt.Errorf("remove route: %w", err)
}
return nil
}
func getRoutesFromTable() ([]netip.Prefix, error) {
@ -93,7 +120,7 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
func GetRoutes() ([]Route, error) {
var entries []MSFT_NetRoute
query := `SELECT DestinationPrefix, NextHop, InterfaceIndex, InterfaceAlias, AddressFamily FROM MSFT_NetRoute`
query := `SELECT DestinationPrefix, Nexthop, InterfaceIndex, InterfaceAlias, AddressFamily FROM MSFT_NetRoute`
if err := wmi.QueryNamespace(query, &entries, `ROOT\StandardCimv2`); err != nil {
return nil, fmt.Errorf("get routes: %w", err)
}
@ -118,6 +145,10 @@ func GetRoutes() ([]Route, error) {
Index: int(entry.InterfaceIndex),
Name: entry.InterfaceAlias,
}
if nexthop.Is6() && (nexthop.IsLinkLocalUnicast() || nexthop.IsLinkLocalMulticast()) {
nexthop = nexthop.WithZone(strconv.Itoa(int(entry.InterfaceIndex)))
}
}
routes = append(routes, Route{
@ -157,11 +188,12 @@ func GetNeighbors() ([]Neighbor, error) {
return neighbors, nil
}
func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
func addRouteCmd(prefix netip.Prefix, nexthop Nexthop) error {
args := []string{"add", prefix.String()}
if nexthop.IsValid() {
args = append(args, nexthop.Unmap().String())
if nexthop.IP.IsValid() {
ip := nexthop.IP.WithZone("")
args = append(args, ip.Unmap().String())
} else {
addr := "0.0.0.0"
if prefix.Addr().Is6() {
@ -170,8 +202,8 @@ func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) e
args = append(args, addr)
}
if intf != nil {
args = append(args, "if", strconv.Itoa(intf.Index))
if nexthop.Intf != nil {
args = append(args, "if", strconv.Itoa(nexthop.Intf.Index))
}
routeCmd := uspfilter.GetSystem32Command("route")
@ -185,37 +217,6 @@ func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) e
return nil
}
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
if nexthop.Zone() != "" && intf == nil {
zone, err := strconv.Atoi(nexthop.Zone())
if err != nil {
return fmt.Errorf("invalid zone: %w", err)
}
intf = &net.Interface{Index: zone}
nexthop.WithZone("")
}
return addRouteCmd(prefix, nexthop, intf)
}
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ *net.Interface) error {
args := []string{"delete", prefix.String()}
if nexthop.IsValid() {
nexthop.WithZone("")
args = append(args, nexthop.Unmap().String())
}
routeCmd := uspfilter.GetSystem32Command("route")
out, err := exec.Command(routeCmd, args...).CombinedOutput()
log.Tracef("route %s: %s", strings.Join(args, " "), out)
if err != nil {
return fmt.Errorf("remove route: %w", err)
}
return nil
}
func isCacheDisabled() bool {
return os.Getenv("NB_DISABLE_ROUTE_CACHE") == "true"
}

View File

@ -1,4 +1,4 @@
package routemanager
package systemops
import (
"context"
@ -29,7 +29,7 @@ type FindNetRouteOutput struct {
InterfaceIndex int `json:"InterfaceIndex"`
InterfaceAlias string `json:"InterfaceAlias"`
AddressFamily int `json:"AddressFamily"`
NextHop string `json:"NextHop"`
NextHop string `json:"Nexthop"`
DestinationPrefix string `json:"DestinationPrefix"`
}
@ -166,7 +166,7 @@ func testRoute(t *testing.T, destination string, dialer dialer) *FindNetRouteOut
host, _, err := net.SplitHostPort(destination)
require.NoError(t, err)
script := fmt.Sprintf(`Find-NetRoute -RemoteIPAddress "%s" | Select-Object -Property IPAddress, InterfaceIndex, InterfaceAlias, AddressFamily, NextHop, DestinationPrefix | ConvertTo-Json`, host)
script := fmt.Sprintf(`Find-NetRoute -RemoteIPAddress "%s" | Select-Object -Property IPAddress, InterfaceIndex, InterfaceAlias, AddressFamily, Nexthop, DestinationPrefix | ConvertTo-Json`, host)
out, err := exec.Command("powershell", "-Command", script).Output()
require.NoError(t, err, "Failed to execute Find-NetRoute")
@ -207,7 +207,7 @@ func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR str
}
func fetchOriginalGateway() (*RouteInfo, error) {
cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object NextHop, RouteMetric, InterfaceAlias | ConvertTo-Json")
cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object Nexthop, RouteMetric, InterfaceAlias | ConvertTo-Json")
output, err := cmd.CombinedOutput()
if err != nil {
return nil, fmt.Errorf("failed to execute Get-NetRoute: %w", err)

View File

@ -1,33 +0,0 @@
package routemanager
import (
"net"
"net/netip"
"runtime"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
)
func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
return nil, nil, nil
}
func cleanupRouting() error {
return nil
}
func enableIPForwarding() error {
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil
}
func addVPNRoute(netip.Prefix, *net.Interface) error {
return nil
}
func removeVPNRoute(netip.Prefix, *net.Interface) error {
return nil
}

View File

@ -1,57 +0,0 @@
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
package routemanager
import (
"testing"
"github.com/stretchr/testify/assert"
"golang.org/x/net/route"
)
func TestBits(t *testing.T) {
tests := []struct {
name string
addr route.Addr
want int
wantErr bool
}{
{
name: "IPv4 all ones",
addr: &route.Inet4Addr{IP: [4]byte{255, 255, 255, 255}},
want: 32,
},
{
name: "IPv4 normal mask",
addr: &route.Inet4Addr{IP: [4]byte{255, 255, 255, 0}},
want: 24,
},
{
name: "IPv6 all ones",
addr: &route.Inet6Addr{IP: [16]byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}},
want: 128,
},
{
name: "IPv6 normal mask",
addr: &route.Inet6Addr{IP: [16]byte{255, 255, 255, 255, 255, 255, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0}},
want: 64,
},
{
name: "Unsupported type",
addr: &route.LinkAddr{},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ones(tt.addr)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.want, got)
}
})
}
}

View File

@ -1,33 +0,0 @@
package routemanager
import (
"net"
"net/netip"
"runtime"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
)
func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
return nil, nil, nil
}
func cleanupRouting() error {
return nil
}
func enableIPForwarding() error {
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil
}
func addVPNRoute(netip.Prefix, *net.Interface) error {
return nil
}
func removeVPNRoute(netip.Prefix, *net.Interface) error {
return nil
}

View File

@ -1,24 +0,0 @@
//go:build !linux && !ios
package routemanager
import (
"net"
"net/netip"
"runtime"
log "github.com/sirupsen/logrus"
)
func enableIPForwarding() error {
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil
}
func addVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
return genericAddVPNRoute(prefix, intf)
}
func removeVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
return genericRemoveVPNRoute(prefix, intf)
}

View File

@ -0,0 +1,29 @@
package util
import (
"fmt"
"net"
"net/netip"
)
// GetPrefixFromIP returns a netip.Prefix from a net.IP address.
func GetPrefixFromIP(ip net.IP) (netip.Prefix, error) {
addr, ok := netip.AddrFromSlice(ip)
if !ok {
return netip.Prefix{}, fmt.Errorf("parse IP address: %s", ip)
}
addr = addr.Unmap()
var prefixLength int
switch {
case addr.Is4():
prefixLength = 32
case addr.Is6():
prefixLength = 128
default:
return netip.Prefix{}, fmt.Errorf("invalid IP address: %s", addr)
}
prefix := netip.PrefixFrom(addr, prefixLength)
return prefix, nil
}

View File

@ -0,0 +1,16 @@
package vars
import (
"errors"
"net/netip"
)
const MinRangeBits = 7
var (
ErrRouteNotFound = errors.New("route not found")
ErrRouteNotAllowed = errors.New("route not allowed")
Defaultv4 = netip.PrefixFrom(netip.IPv4Unspecified(), 0)
Defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
)

View File

@ -3,11 +3,11 @@ package routeselector
import (
"fmt"
"slices"
"strings"
"github.com/hashicorp/go-multierror"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/errors"
route "github.com/netbirdio/netbird/route"
)
@ -30,10 +30,10 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al
rs.selectedRoutes = map[route.NetID]struct{}{}
}
var multiErr *multierror.Error
var err *multierror.Error
for _, route := range routes {
if !slices.Contains(allRoutes, route) {
multiErr = multierror.Append(multiErr, fmt.Errorf("route '%s' is not available", route))
err = multierror.Append(err, fmt.Errorf("route '%s' is not available", route))
continue
}
@ -41,11 +41,7 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al
}
rs.selectAll = false
if multiErr != nil {
multiErr.ErrorFormat = formatError
}
return multiErr.ErrorOrNil()
return errors.FormatErrorOrNil(err)
}
// SelectAllRoutes sets the selector to select all routes.
@ -65,21 +61,17 @@ func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.
}
}
var multiErr *multierror.Error
var err *multierror.Error
for _, route := range routes {
if !slices.Contains(allRoutes, route) {
multiErr = multierror.Append(multiErr, fmt.Errorf("route '%s' is not available", route))
err = multierror.Append(err, fmt.Errorf("route '%s' is not available", route))
continue
}
delete(rs.selectedRoutes, route)
}
if multiErr != nil {
multiErr.ErrorFormat = formatError
}
return multiErr.ErrorOrNil()
return errors.FormatErrorOrNil(err)
}
// DeselectAllRoutes deselects all routes, effectively disabling route selection.
@ -111,18 +103,3 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
}
return filtered
}
func formatError(es []error) string {
if len(es) == 1 {
return fmt.Sprintf("1 error occurred:\n\t* %s", es[0])
}
points := make([]string, len(es))
for i, err := range es {
points[i] = fmt.Sprintf("* %s", err)
}
return fmt.Sprintf(
"%d errors occurred:\n\t%s",
len(es), strings.Join(points, "\n\t"))
}

View File

@ -261,15 +261,15 @@ func TestRouteSelector_FilterSelected(t *testing.T) {
require.NoError(t, err)
routes := route.HAMap{
"route1-10.0.0.0/8": {},
"route2-192.168.0.0/16": {},
"route3-172.16.0.0/12": {},
"route1|10.0.0.0/8": {},
"route2|192.168.0.0/16": {},
"route3|172.16.0.0/12": {},
}
filtered := rs.FilterSelected(routes)
assert.Equal(t, route.HAMap{
"route1-10.0.0.0/8": {},
"route2-192.168.0.0/16": {},
"route1|10.0.0.0/8": {},
"route2|192.168.0.0/16": {},
}, filtered)
}

View File

@ -4,21 +4,24 @@ package wgproxy
import (
"context"
log "github.com/sirupsen/logrus"
)
func NewFactory(ctx context.Context, wgPort int) *Factory {
func NewFactory(ctx context.Context, userspace bool, wgPort int) *Factory {
f := &Factory{wgPort: wgPort}
// todo: put it back
/*
ebpfProxy := NewWGEBPFProxy(ctx, wgPort)
err := ebpfProxy.listen()
if err != nil {
log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err)
return f
}
f.ebpfProxy = ebpfProxy
if userspace {
return f
}
*/
ebpfProxy := NewWGEBPFProxy(ctx, wgPort)
err := ebpfProxy.listen()
if err != nil {
log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err)
return f
}
f.ebpfProxy = ebpfProxy
return f
}

View File

@ -4,6 +4,6 @@ package wgproxy
import "context"
func NewFactory(ctx context.Context, wgPort int) *Factory {
func NewFactory(ctx context.Context, _ bool, wgPort int) *Factory {
return &Factory{wgPort: wgPort}
}

File diff suppressed because it is too large Load Diff

View File

@ -92,6 +92,8 @@ message LoginRequest {
repeated string extraIFaceBlacklist = 17;
optional bool networkMonitor = 18;
optional google.protobuf.Duration dnsRouteInterval = 19;
}
message LoginResponse {
@ -145,6 +147,18 @@ message GetConfigResponse {
// adminURL settings value.
string adminURL = 5;
string interfaceName = 6;
int64 wireguardPort = 7;
bool disableAutoConnect = 9;
bool serverSSHAllowed = 10;
bool rosenpassEnabled = 11;
bool rosenpassPermissive = 12;
}
// PeerState contains the latest state of a peer
@ -233,10 +247,17 @@ message SelectRoutesRequest {
message SelectRoutesResponse {
}
message IPList {
repeated string ips = 1;
}
message Route {
string ID = 1;
string network = 2;
bool selected = 3;
repeated string domains = 4;
map<string, IPList> resolvedIPs = 5;
}
message DebugBundleRequest {

View File

@ -9,17 +9,19 @@ import (
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route"
)
type selectRoute struct {
NetID route.NetID
Network netip.Prefix
Domains domain.List
Selected bool
}
// ListRoutes returns a list of all available routes.
func (s *Server) ListRoutes(ctx context.Context, req *proto.ListRoutesRequest) (*proto.ListRoutesResponse, error) {
func (s *Server) ListRoutes(context.Context, *proto.ListRoutesRequest) (*proto.ListRoutesResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
@ -43,6 +45,7 @@ func (s *Server) ListRoutes(ctx context.Context, req *proto.ListRoutesRequest) (
route := &selectRoute{
NetID: id,
Network: rt[0].Network,
Domains: rt[0].Domains,
Selected: routeSelector.IsSelected(id),
}
routes = append(routes, route)
@ -63,13 +66,29 @@ func (s *Server) ListRoutes(ctx context.Context, req *proto.ListRoutesRequest) (
return iPrefix < jPrefix
})
resolvedDomains := s.statusRecorder.GetResolvedDomainsStates()
var pbRoutes []*proto.Route
for _, route := range routes {
pbRoutes = append(pbRoutes, &proto.Route{
ID: string(route.NetID),
Network: route.Network.String(),
Selected: route.Selected,
})
pbRoute := &proto.Route{
ID: string(route.NetID),
Network: route.Network.String(),
Domains: route.Domains.ToSafeStringList(),
ResolvedIPs: map[string]*proto.IPList{},
Selected: route.Selected,
}
for _, domain := range route.Domains {
if prefixes, exists := resolvedDomains[domain]; exists {
var ipStrings []string
for _, prefix := range prefixes {
ipStrings = append(ipStrings, prefix.Addr().String())
}
pbRoute.ResolvedIPs[string(domain)] = &proto.IPList{
Ips: ipStrings,
}
}
}
pbRoutes = append(pbRoutes, pbRoute)
}
return &proto.ListRoutesResponse{

View File

@ -365,6 +365,12 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
s.latestConfigInput.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist
}
if msg.DnsRouteInterval != nil {
duration := msg.DnsRouteInterval.AsDuration()
inputConfig.DNSRouteInterval = &duration
s.latestConfigInput.DNSRouteInterval = &duration
}
s.mutex.Unlock()
if msg.OptionalPreSharedKey != nil {
@ -662,11 +668,17 @@ func (s *Server) GetConfig(_ context.Context, _ *proto.GetConfigRequest) (*proto
}
return &proto.GetConfigResponse{
ManagementUrl: managementURL,
AdminURL: adminURL,
ConfigFile: s.latestConfigInput.ConfigPath,
LogFile: s.logFile,
PreSharedKey: preSharedKey,
ManagementUrl: managementURL,
ConfigFile: s.latestConfigInput.ConfigPath,
LogFile: s.logFile,
PreSharedKey: preSharedKey,
AdminURL: adminURL,
InterfaceName: s.config.WgIface,
WireguardPort: int64(s.config.WgPort),
DisableAutoConnect: s.config.DisableAutoConnect,
ServerSSHAllowed: *s.config.ServerSSHAllowed,
RosenpassEnabled: s.config.RosenpassEnabled,
RosenpassPermissive: s.config.RosenpassPermissive,
}, nil
}
func (s *Server) onSessionExpire() {

View File

@ -7,6 +7,8 @@ import (
"time"
"github.com/netbirdio/management-integrations/integrations"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
@ -39,7 +41,7 @@ var (
// we will use a management server started via to simulate the server and capture the number of retries
func TestConnectWithRetryRuns(t *testing.T) {
// start the signal server
_, signalAddr, err := startSignal()
_, signalAddr, err := startSignal(t)
if err != nil {
t.Fatalf("failed to start signal server: %v", err)
}
@ -106,7 +108,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
return nil, "", err
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, cleanUp, err := server.NewTestStoreFromJson(config.Datadir)
store, cleanUp, err := server.NewTestStoreFromJson(context.Background(), config.Datadir)
if err != nil {
return nil, "", err
}
@ -117,13 +119,13 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
if err != nil {
return nil, "", err
}
ia, _ := integrations.NewIntegratedValidator(eventStore)
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
if err != nil {
return nil, "", err
}
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, "")
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil {
return nil, "", err
}
@ -141,7 +143,9 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
return s, lis.Addr().String(), nil
}
func startSignal() (*grpc.Server, string, error) {
func startSignal(t *testing.T) (*grpc.Server, string, error) {
t.Helper()
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
lis, err := net.Listen("tcp", "localhost:0")
@ -149,7 +153,9 @@ func startSignal() (*grpc.Server, string, error) {
log.Fatalf("failed to listen: %v", err)
}
proto.RegisterSignalExchangeServer(s, signalServer.NewServer())
srv, err := signalServer.NewServer(otel.Meter(""))
require.NoError(t, err)
proto.RegisterSignalExchangeServer(s, srv)
go func() {
if err = s.Serve(lis); err != nil {

View File

@ -0,0 +1,10 @@
//go:build freebsd
package ssh
import (
"os"
)
func setWinSize(file *os.File, width, height int) {
}

View File

@ -8,6 +8,7 @@ import (
"google.golang.org/grpc/metadata"
"github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/version"
)
@ -33,6 +34,12 @@ type Environment struct {
Platform string
}
type File struct {
Path string
Exist bool
ProcessIsRunning bool
}
// Info is an object that contains machine information
// Most of the code is taken from https://github.com/matishsiao/goInfo
type Info struct {
@ -51,6 +58,7 @@ type Info struct {
SystemProductName string
SystemManufacturer string
Environment Environment
Files []File // for posture checks
}
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
@ -132,3 +140,21 @@ func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
}
return false
}
// GetInfoWithChecks retrieves and parses the system information with applied checks.
func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) {
processCheckPaths := make([]string, 0)
for _, check := range checks {
processCheckPaths = append(processCheckPaths, check.GetFiles()...)
}
files, err := checkFileAndProcess(processCheckPaths)
if err != nil {
return nil, err
}
info := GetInfo(ctx)
info.Files = files
return info, nil
}

View File

@ -32,7 +32,7 @@ func GetInfo(ctx context.Context) *Info {
GoOS: runtime.GOOS,
Kernel: kernel,
Platform: "unknown",
OS: "android",
OS: "Android",
OSVersion: osVersion(),
Hostname: extractDeviceName(ctx, "android"),
CPUs: runtime.NumCPU(),
@ -44,6 +44,11 @@ func GetInfo(ctx context.Context) *Info {
return gio
}
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
func checkFileAndProcess(paths []string) ([]File, error) {
return []File{}, nil
}
func uname() []string {
res := run("/system/bin/uname", "-a")
return strings.Split(res, " ")
@ -72,5 +77,6 @@ func run(name string, arg ...string) string {
if err != nil {
log.Errorf("getInfo: %s", err)
}
return out.String()
return strings.TrimSpace(out.String())
}

View File

@ -1,15 +1,18 @@
//go:build freebsd
package system
import (
"bytes"
"context"
"fmt"
"os"
"os/exec"
"runtime"
"strings"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/system/detect_cloud"
"github.com/netbirdio/netbird/client/system/detect_platform"
"github.com/netbirdio/netbird/version"
@ -22,8 +25,8 @@ func GetInfo(ctx context.Context) *Info {
out = _getInfo()
time.Sleep(500 * time.Millisecond)
}
osStr := strings.Replace(out, "\n", "", -1)
osStr = strings.Replace(osStr, "\r\n", "", -1)
osStr := strings.ReplaceAll(out, "\n", "")
osStr = strings.ReplaceAll(osStr, "\r\n", "")
osInfo := strings.Split(osStr, " ")
env := Environment{
@ -31,14 +34,23 @@ func GetInfo(ctx context.Context) *Info {
Platform: detect_platform.Detect(ctx),
}
gio := &Info{Kernel: osInfo[0], Platform: runtime.GOARCH, OS: osInfo[2], GoOS: runtime.GOOS, CPUs: runtime.NumCPU(), KernelVersion: osInfo[1], Environment: env}
osName, osVersion := readOsReleaseFile()
systemHostname, _ := os.Hostname()
gio.Hostname = extractDeviceName(ctx, systemHostname)
gio.WiretrusteeVersion = version.NetbirdVersion()
gio.UIVersion = extractUserAgent(ctx)
return gio
return &Info{
GoOS: runtime.GOOS,
Kernel: osInfo[0],
Platform: runtime.GOARCH,
OS: osName,
OSVersion: osVersion,
Hostname: extractDeviceName(ctx, systemHostname),
CPUs: runtime.NumCPU(),
WiretrusteeVersion: version.NetbirdVersion(),
UIVersion: extractUserAgent(ctx),
KernelVersion: osInfo[1],
Environment: env,
}
}
func _getInfo() string {
@ -50,7 +62,8 @@ func _getInfo() string {
cmd.Stderr = &stderr
err := cmd.Run()
if err != nil {
fmt.Println("getInfo:", err)
log.Warnf("getInfo: %s", err)
}
return out.String()
}

View File

@ -25,6 +25,11 @@ func GetInfo(ctx context.Context) *Info {
return gio
}
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
func checkFileAndProcess(paths []string) ([]File, error) {
return []File{}, nil
}
// extractOsVersion extracts operating system version from context or returns the default
func extractOsVersion(ctx context.Context, defaultName string) string {
v, ok := ctx.Value(OsVersionCtxKey).(string)

View File

@ -28,28 +28,11 @@ func GetInfo(ctx context.Context) *Info {
time.Sleep(500 * time.Millisecond)
}
releaseInfo := _getReleaseInfo()
for strings.Contains(info, "broken pipe") {
releaseInfo = _getReleaseInfo()
time.Sleep(500 * time.Millisecond)
}
osRelease := strings.Split(releaseInfo, "\n")
var osName string
var osVer string
for _, s := range osRelease {
if strings.HasPrefix(s, "NAME=") {
osName = strings.Split(s, "=")[1]
osName = strings.ReplaceAll(osName, "\"", "")
} else if strings.HasPrefix(s, "VERSION_ID=") {
osVer = strings.Split(s, "=")[1]
osVer = strings.ReplaceAll(osVer, "\"", "")
}
}
osStr := strings.ReplaceAll(info, "\n", "")
osStr = strings.ReplaceAll(osStr, "\r\n", "")
osInfo := strings.Split(osStr, " ")
osName, osVersion := readOsReleaseFile()
if osName == "" {
osName = osInfo[3]
}
@ -72,7 +55,7 @@ func GetInfo(ctx context.Context) *Info {
Kernel: osInfo[0],
Platform: osInfo[2],
OS: osName,
OSVersion: osVer,
OSVersion: osVersion,
Hostname: extractDeviceName(ctx, systemHostname),
GoOS: runtime.GOOS,
CPUs: runtime.NumCPU(),
@ -103,22 +86,12 @@ func _getInfo() string {
return out.String()
}
func _getReleaseInfo() string {
cmd := exec.Command("cat", "/etc/os-release")
cmd.Stdin = strings.NewReader("some")
var out bytes.Buffer
var stderr bytes.Buffer
cmd.Stdout = &out
cmd.Stderr = &stderr
err := cmd.Run()
if err != nil {
log.Warnf("geucwReleaseInfo: %s", err)
}
return out.String()
}
func sysInfo() (serialNumber string, productName string, manufacturer string) {
var si sysinfo.SysInfo
si.GetSysInfo()
return si.Chassis.Serial, si.Product.Name, si.Product.Vendor
serial := si.Chassis.Serial
if (serial == "Default string" || serial == "") && si.Product.Serial != "" {
serial = si.Product.Serial
}
return serial, si.Product.Name, si.Product.Vendor
}

View File

@ -0,0 +1,38 @@
//go:build (linux && !android) || freebsd
package system
import (
"bufio"
"os"
"strings"
log "github.com/sirupsen/logrus"
)
func readOsReleaseFile() (osName string, osVer string) {
file, err := os.Open("/etc/os-release")
if err != nil {
log.Warnf("failed to open file /etc/os-release: %s", err)
return "", ""
}
defer file.Close()
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := scanner.Text()
if strings.HasPrefix(line, "NAME=") {
osName = strings.ReplaceAll(strings.Split(line, "=")[1], "\"", "")
continue
}
if strings.HasPrefix(line, "VERSION_ID=") {
osVer = strings.ReplaceAll(strings.Split(line, "=")[1], "\"", "")
continue
}
if osName != "" && osVer != "" {
break
}
}
return
}

58
client/system/process.go Normal file
View File

@ -0,0 +1,58 @@
//go:build windows || (linux && !android) || (darwin && !ios) || freebsd
package system
import (
"os"
"slices"
"github.com/shirou/gopsutil/v3/process"
)
// getRunningProcesses returns a list of running process paths.
func getRunningProcesses() ([]string, error) {
processes, err := process.Processes()
if err != nil {
return nil, err
}
processMap := make(map[string]bool)
for _, p := range processes {
path, _ := p.Exe()
if path != "" {
processMap[path] = true
}
}
uniqueProcesses := make([]string, 0, len(processMap))
for p := range processMap {
uniqueProcesses = append(uniqueProcesses, p)
}
return uniqueProcesses, nil
}
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
func checkFileAndProcess(paths []string) ([]File, error) {
files := make([]File, len(paths))
if len(paths) == 0 {
return files, nil
}
runningProcesses, err := getRunningProcesses()
if err != nil {
return nil, err
}
for i, path := range paths {
file := File{Path: path}
_, err := os.Stat(path)
file.Exist = !os.IsNotExist(err)
file.ProcessIsRunning = slices.Contains(runningProcesses, path)
files[i] = file
}
return files, nil
}

View File

@ -1,10 +1,11 @@
//go:build !(linux && 386)
//go:build !(linux && 386) && !freebsd
package main
import (
"context"
_ "embed"
"errors"
"flag"
"fmt"
"os"
@ -79,6 +80,7 @@ func main() {
log.Errorf("check PID file: %v", err)
return
}
client.setDefaultFonts()
systray.Run(client.onTrayReady, client.onTrayExit)
}
}
@ -125,44 +127,55 @@ type serviceClient struct {
icUpdateCloud []byte
// systray menu items
mStatus *systray.MenuItem
mUp *systray.MenuItem
mDown *systray.MenuItem
mAdminPanel *systray.MenuItem
mSettings *systray.MenuItem
mAbout *systray.MenuItem
mVersionUI *systray.MenuItem
mVersionDaemon *systray.MenuItem
mUpdate *systray.MenuItem
mQuit *systray.MenuItem
mRoutes *systray.MenuItem
mStatus *systray.MenuItem
mUp *systray.MenuItem
mDown *systray.MenuItem
mAdminPanel *systray.MenuItem
mSettings *systray.MenuItem
mAbout *systray.MenuItem
mVersionUI *systray.MenuItem
mVersionDaemon *systray.MenuItem
mUpdate *systray.MenuItem
mQuit *systray.MenuItem
mRoutes *systray.MenuItem
mAllowSSH *systray.MenuItem
mAutoConnect *systray.MenuItem
mEnableRosenpass *systray.MenuItem
mAdvancedSettings *systray.MenuItem
// application with main windows.
app fyne.App
wSettings fyne.Window
showSettings bool
sendNotification bool
app fyne.App
wSettings fyne.Window
showAdvancedSettings bool
sendNotification bool
// input elements for settings form
iMngURL *widget.Entry
iAdminURL *widget.Entry
iConfigFile *widget.Entry
iLogFile *widget.Entry
iPreSharedKey *widget.Entry
iMngURL *widget.Entry
iAdminURL *widget.Entry
iConfigFile *widget.Entry
iLogFile *widget.Entry
iPreSharedKey *widget.Entry
iInterfaceName *widget.Entry
iInterfacePort *widget.Entry
// switch elements for settings form
sRosenpassPermissive *widget.Check
// observable settings over corresponding iMngURL and iPreSharedKey values.
managementURL string
preSharedKey string
adminURL string
managementURL string
preSharedKey string
adminURL string
RosenpassPermissive bool
interfaceName string
interfacePort int
connected bool
update *version.Update
daemonVersion string
updateIndicationLock sync.Mutex
isUpdateIconActive bool
showRoutes bool
wRoutes fyne.Window
showRoutes bool
wRoutes fyne.Window
}
// newServiceClient instance constructor
@ -175,9 +188,9 @@ func newServiceClient(addr string, a fyne.App, showSettings bool, showRoutes boo
app: a,
sendNotification: false,
showSettings: showSettings,
showRoutes: showRoutes,
update: version.NewUpdate(),
showAdvancedSettings: showSettings,
showRoutes: showRoutes,
update: version.NewUpdate(),
}
if runtime.GOOS == "windows" {
@ -215,8 +228,13 @@ func (s *serviceClient) showSettingsUI() {
s.iLogFile = widget.NewEntry()
s.iLogFile.Disable()
s.iPreSharedKey = widget.NewPasswordEntry()
s.iInterfaceName = widget.NewEntry()
s.iInterfacePort = widget.NewEntry()
s.sRosenpassPermissive = widget.NewCheck("Enable Rosenpass permissive mode", nil)
s.wSettings.SetContent(s.getSettingsForm())
s.wSettings.Resize(fyne.NewSize(600, 100))
s.wSettings.Resize(fyne.NewSize(600, 400))
s.wSettings.SetFixedSize(true)
s.getSrvConfig()
@ -239,6 +257,9 @@ func showErrorMSG(msg string) {
func (s *serviceClient) getSettingsForm() *widget.Form {
return &widget.Form{
Items: []*widget.FormItem{
{Text: "Quantum-Resistance", Widget: s.sRosenpassPermissive},
{Text: "Interface Name", Widget: s.iInterfaceName},
{Text: "Interface Port", Widget: s.iInterfacePort},
{Text: "Management URL", Widget: s.iMngURL},
{Text: "Admin URL", Widget: s.iAdminURL},
{Text: "Pre-shared Key", Widget: s.iPreSharedKey},
@ -255,45 +276,45 @@ func (s *serviceClient) getSettingsForm() *widget.Form {
}
}
port, err := strconv.ParseInt(s.iInterfacePort.Text, 10, 64)
if err != nil {
dialog.ShowError(errors.New("Invalid interface port"), s.wSettings)
return
}
iAdminURL := strings.TrimSpace(s.iAdminURL.Text)
iMngURL := strings.TrimSpace(s.iMngURL.Text)
defer s.wSettings.Close()
// if management URL or Pre-shared key changed, we try to re-login with new settings.
if s.managementURL != s.iMngURL.Text || s.preSharedKey != s.iPreSharedKey.Text ||
s.adminURL != s.iAdminURL.Text {
s.managementURL = s.iMngURL.Text
// If the management URL, pre-shared key, admin URL, Rosenpass permissive mode,
// interface name, or interface port have changed, we attempt to re-login with the new settings.
if s.managementURL != iMngURL || s.preSharedKey != s.iPreSharedKey.Text ||
s.adminURL != iAdminURL || s.RosenpassPermissive != s.sRosenpassPermissive.Checked ||
s.interfaceName != s.iInterfaceName.Text || s.interfacePort != int(port) {
s.managementURL = iMngURL
s.preSharedKey = s.iPreSharedKey.Text
s.adminURL = s.iAdminURL.Text
client, err := s.getSrvClient(failFastTimeout)
if err != nil {
log.Errorf("get daemon client: %v", err)
return
}
s.adminURL = iAdminURL
loginRequest := proto.LoginRequest{
ManagementUrl: s.iMngURL.Text,
AdminURL: s.iAdminURL.Text,
ManagementUrl: iMngURL,
AdminURL: iAdminURL,
IsLinuxDesktopClient: runtime.GOOS == "linux",
RosenpassPermissive: &s.sRosenpassPermissive.Checked,
InterfaceName: &s.iInterfaceName.Text,
WireguardPort: &port,
}
if s.iPreSharedKey.Text != "**********" {
loginRequest.OptionalPreSharedKey = &s.iPreSharedKey.Text
}
_, err = client.Login(s.ctx, &loginRequest)
if err != nil {
log.Errorf("login to management URL: %v", err)
if err := s.restartClient(&loginRequest); err != nil {
log.Errorf("restarting client connection: %v", err)
return
}
_, err = client.Up(s.ctx, &proto.UpRequest{})
if err != nil {
log.Errorf("login to management URL: %v", err)
return
}
}
s.wSettings.Close()
},
OnCancel: func() {
s.wSettings.Close()
@ -499,7 +520,14 @@ func (s *serviceClient) onTrayReady() {
s.mDown.Disable()
s.mAdminPanel = systray.AddMenuItem("Admin Panel", "Netbird Admin Panel")
systray.AddSeparator()
s.mSettings = systray.AddMenuItem("Settings", "Settings of the application")
s.mAllowSSH = s.mSettings.AddSubMenuItemCheckbox("Allow SSH", "Allow SSH connections", false)
s.mAutoConnect = s.mSettings.AddSubMenuItemCheckbox("Connect on Startup", "Connect automatically when the service starts", false)
s.mEnableRosenpass = s.mSettings.AddSubMenuItemCheckbox("Enable Quantum-Resistance", "Enable post-quantum security via Rosenpass", false)
s.mAdvancedSettings = s.mSettings.AddSubMenuItem("Advanced Settings", "Advanced settings of the application")
s.loadSettings()
s.mRoutes = systray.AddMenuItem("Network Routes", "Open the routes management window")
s.mRoutes.Disable()
systray.AddSeparator()
@ -539,7 +567,7 @@ func (s *serviceClient) onTrayReady() {
case <-s.mAdminPanel.ClickedCh:
err = open.Run(s.adminURL)
case <-s.mUp.ClickedCh:
s.mUp.Disabled()
s.mUp.Disable()
go func() {
defer s.mUp.Enable()
err := s.menuUpClick()
@ -558,10 +586,40 @@ func (s *serviceClient) onTrayReady() {
return
}
}()
case <-s.mSettings.ClickedCh:
s.mSettings.Disable()
case <-s.mAllowSSH.ClickedCh:
if s.mAllowSSH.Checked() {
s.mAllowSSH.Uncheck()
} else {
s.mAllowSSH.Check()
}
if err := s.updateConfig(); err != nil {
log.Errorf("failed to update config: %v", err)
return
}
case <-s.mAutoConnect.ClickedCh:
if s.mAutoConnect.Checked() {
s.mAutoConnect.Uncheck()
} else {
s.mAutoConnect.Check()
}
if err := s.updateConfig(); err != nil {
log.Errorf("failed to update config: %v", err)
return
}
case <-s.mEnableRosenpass.ClickedCh:
if s.mEnableRosenpass.Checked() {
s.mEnableRosenpass.Uncheck()
} else {
s.mEnableRosenpass.Check()
}
if err := s.updateConfig(); err != nil {
log.Errorf("failed to update config: %v", err)
return
}
case <-s.mAdvancedSettings.ClickedCh:
s.mAdvancedSettings.Disable()
go func() {
defer s.mSettings.Enable()
defer s.mAdvancedSettings.Enable()
defer s.getSrvConfig()
s.runSelfCommand("settings", "true")
}()
@ -663,13 +721,23 @@ func (s *serviceClient) getSrvConfig() {
s.adminURL = cfg.AdminURL
}
s.preSharedKey = cfg.PreSharedKey
s.RosenpassPermissive = cfg.RosenpassPermissive
s.interfaceName = cfg.InterfaceName
s.interfacePort = int(cfg.WireguardPort)
if s.showSettings {
if s.showAdvancedSettings {
s.iMngURL.SetText(s.managementURL)
s.iAdminURL.SetText(s.adminURL)
s.iConfigFile.SetText(cfg.ConfigFile)
s.iLogFile.SetText(cfg.LogFile)
s.iPreSharedKey.SetText(cfg.PreSharedKey)
s.iInterfaceName.SetText(cfg.InterfaceName)
s.iInterfacePort.SetText(strconv.Itoa(int(cfg.WireguardPort)))
s.sRosenpassPermissive.SetChecked(cfg.RosenpassPermissive)
if !cfg.RosenpassEnabled {
s.sRosenpassPermissive.Disable()
}
}
}
@ -704,6 +772,81 @@ func (s *serviceClient) onSessionExpire() {
}
}
// loadSettings loads the settings from the config file and updates the UI elements accordingly.
func (s *serviceClient) loadSettings() {
conn, err := s.getSrvClient(failFastTimeout)
if err != nil {
log.Errorf("get client: %v", err)
return
}
cfg, err := conn.GetConfig(s.ctx, &proto.GetConfigRequest{})
if err != nil {
log.Errorf("get config settings from server: %v", err)
return
}
if cfg.ServerSSHAllowed {
s.mAllowSSH.Check()
} else {
s.mAllowSSH.Uncheck()
}
if cfg.DisableAutoConnect {
s.mAutoConnect.Uncheck()
} else {
s.mAutoConnect.Check()
}
if cfg.RosenpassEnabled {
s.mEnableRosenpass.Check()
} else {
s.mEnableRosenpass.Uncheck()
}
}
// updateConfig updates the configuration parameters
// based on the values selected in the settings window.
func (s *serviceClient) updateConfig() error {
disableAutoStart := !s.mAutoConnect.Checked()
sshAllowed := s.mAllowSSH.Checked()
rosenpassEnabled := s.mEnableRosenpass.Checked()
loginRequest := proto.LoginRequest{
IsLinuxDesktopClient: runtime.GOOS == "linux",
ServerSSHAllowed: &sshAllowed,
RosenpassEnabled: &rosenpassEnabled,
DisableAutoConnect: &disableAutoStart,
}
if err := s.restartClient(&loginRequest); err != nil {
log.Errorf("restarting client connection: %v", err)
return err
}
return nil
}
// restartClient restarts the client connection.
func (s *serviceClient) restartClient(loginRequest *proto.LoginRequest) error {
client, err := s.getSrvClient(failFastTimeout)
if err != nil {
return err
}
_, err = client.Login(s.ctx, loginRequest)
if err != nil {
return err
}
_, err = client.Up(s.ctx, &proto.UpRequest{})
if err != nil {
return err
}
return nil
}
func openURL(url string) error {
var err error
switch runtime.GOOS {
@ -734,3 +877,88 @@ func checkPIDFile() error {
return os.WriteFile(pidFile, []byte(fmt.Sprintf("%d", os.Getpid())), 0o664) //nolint:gosec
}
func (s *serviceClient) setDefaultFonts() {
var (
defaultFontPath string
)
//TODO: Linux Multiple Language Support
switch runtime.GOOS {
case "darwin":
defaultFontPath = "/Library/Fonts/Arial Unicode.ttf"
case "windows":
fontPath := s.getWindowsFontFilePath()
defaultFontPath = fontPath
}
_, err := os.Stat(defaultFontPath)
if err == nil {
os.Setenv("FYNE_FONT", defaultFontPath)
}
}
func (s *serviceClient) getWindowsFontFilePath() (fontPath string) {
/*
https://learn.microsoft.com/en-us/windows/apps/design/globalizing/loc-international-fonts
https://learn.microsoft.com/en-us/typography/fonts/windows_11_font_list
*/
var (
fontFolder string = "C:/Windows/Fonts"
fontMapping = map[string]string{
"default": "Segoeui.ttf",
"zh-CN": "Msyh.ttc",
"am-ET": "Ebrima.ttf",
"nirmala": "Nirmala.ttf",
"chr-CHER-US": "Gadugi.ttf",
"zh-HK": "Msjh.ttc",
"zh-TW": "Msjh.ttc",
"ja-JP": "Yugothm.ttc",
"km-KH": "Leelawui.ttf",
"ko-KR": "Malgun.ttf",
"th-TH": "Leelawui.ttf",
"ti-ET": "Ebrima.ttf",
}
nirMalaLang = []string{
"as-IN",
"bn-BD",
"bn-IN",
"gu-IN",
"hi-IN",
"kn-IN",
"kok-IN",
"ml-IN",
"mr-IN",
"ne-NP",
"or-IN",
"pa-IN",
"si-LK",
"ta-IN",
"te-IN",
}
)
cmd := exec.Command("powershell", "-Command", "(Get-Culture).Name")
output, err := cmd.Output()
if err != nil {
log.Errorf("Failed to get Windows default language setting: %v", err)
fontPath = path.Join(fontFolder, fontMapping["default"])
return
}
defaultLanguage := strings.TrimSpace(string(output))
for _, lang := range nirMalaLang {
if defaultLanguage == lang {
fontPath = path.Join(fontFolder, fontMapping["nirmala"])
return
}
}
if font, ok := fontMapping[defaultLanguage]; ok {
fontPath = path.Join(fontFolder, font)
} else {
fontPath = path.Join(fontFolder, fontMapping["default"])
}
return
}

View File

@ -1,9 +1,10 @@
//go:build !(linux && 386)
//go:build !(linux && 386) && !freebsd
package main
import (
"fmt"
"sort"
"strings"
"time"
@ -17,28 +18,57 @@ import (
"github.com/netbirdio/netbird/client/proto"
)
const (
allRoutesText = "All routes"
overlappingRoutesText = "Overlapping routes"
exitNodeRoutesText = "Exit-node routes"
allRoutes filter = "all"
overlappingRoutes filter = "overlapping"
exitNodeRoutes filter = "exit-node"
getClientFMT = "get client: %v"
)
type filter string
func (s *serviceClient) showRoutesUI() {
s.wRoutes = s.app.NewWindow("NetBird Routes")
grid := container.New(layout.NewGridLayout(2))
go s.updateRoutes(grid)
allGrid := container.New(layout.NewGridLayout(3))
go s.updateRoutes(allGrid, allRoutes)
overlappingGrid := container.New(layout.NewGridLayout(3))
exitNodeGrid := container.New(layout.NewGridLayout(3))
routeCheckContainer := container.NewVBox()
routeCheckContainer.Add(grid)
tabs := container.NewAppTabs(
container.NewTabItem(allRoutesText, allGrid),
container.NewTabItem(overlappingRoutesText, overlappingGrid),
container.NewTabItem(exitNodeRoutesText, exitNodeGrid),
)
tabs.OnSelected = func(item *container.TabItem) {
s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid)
}
tabs.OnUnselected = func(item *container.TabItem) {
grid, _ := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodeGrid)
grid.Objects = nil
}
routeCheckContainer.Add(tabs)
scrollContainer := container.NewVScroll(routeCheckContainer)
scrollContainer.SetMinSize(fyne.NewSize(200, 300))
buttonBox := container.NewHBox(
layout.NewSpacer(),
widget.NewButton("Refresh", func() {
s.updateRoutes(grid)
s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid)
}),
widget.NewButton("Select all", func() {
s.selectAllRoutes()
s.updateRoutes(grid)
_, f := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodeGrid)
s.selectAllFilteredRoutes(f)
s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid)
}),
widget.NewButton("Deselect All", func() {
s.deselectAllRoutes()
s.updateRoutes(grid)
_, f := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodeGrid)
s.deselectAllFilteredRoutes(f)
s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid)
}),
layout.NewSpacer(),
)
@ -48,27 +78,31 @@ func (s *serviceClient) showRoutesUI() {
s.wRoutes.SetContent(content)
s.wRoutes.Show()
s.startAutoRefresh(5*time.Second, grid)
s.startAutoRefresh(10*time.Second, tabs, allGrid, overlappingGrid, exitNodeGrid)
}
func (s *serviceClient) updateRoutes(grid *fyne.Container) {
routes, err := s.fetchRoutes()
if err != nil {
log.Errorf("get client: %v", err)
s.showError(fmt.Errorf("get client: %v", err))
return
}
func (s *serviceClient) updateRoutes(grid *fyne.Container, f filter) {
grid.Objects = nil
grid.Refresh()
idHeader := widget.NewLabelWithStyle(" ID", fyne.TextAlignLeading, fyne.TextStyle{Bold: true})
networkHeader := widget.NewLabelWithStyle("Network", fyne.TextAlignLeading, fyne.TextStyle{Bold: true})
networkHeader := widget.NewLabelWithStyle("Network/Domains", fyne.TextAlignLeading, fyne.TextStyle{Bold: true})
resolvedIPsHeader := widget.NewLabelWithStyle("Resolved IPs", fyne.TextAlignLeading, fyne.TextStyle{Bold: true})
grid.Add(idHeader)
grid.Add(networkHeader)
for _, route := range routes {
grid.Add(resolvedIPsHeader)
filteredRoutes, err := s.getFilteredRoutes(f)
if err != nil {
return
}
sortRoutesByIDs(filteredRoutes)
for _, route := range filteredRoutes {
r := route
checkBox := widget.NewCheck(r.ID, func(checked bool) {
checkBox := widget.NewCheck(r.GetID(), func(checked bool) {
s.selectRoute(r.ID, checked)
})
checkBox.Checked = route.Selected
@ -76,16 +110,106 @@ func (s *serviceClient) updateRoutes(grid *fyne.Container) {
checkBox.Refresh()
grid.Add(checkBox)
grid.Add(widget.NewLabel(r.Network))
network := r.GetNetwork()
domains := r.GetDomains()
if len(domains) == 0 {
grid.Add(widget.NewLabel(network))
grid.Add(widget.NewLabel(""))
continue
}
// our selectors are only for display
noopFunc := func(_ string) {
// do nothing
}
domainsSelector := widget.NewSelect(domains, noopFunc)
domainsSelector.Selected = domains[0]
grid.Add(domainsSelector)
var resolvedIPsList []string
for _, domain := range domains {
if ipList, exists := r.GetResolvedIPs()[domain]; exists {
resolvedIPsList = append(resolvedIPsList, fmt.Sprintf("%s: %s", domain, strings.Join(ipList.GetIps(), ", ")))
}
}
if len(resolvedIPsList) == 0 {
grid.Add(widget.NewLabel(""))
continue
}
// TODO: limit width within the selector display
resolvedIPsSelector := widget.NewSelect(resolvedIPsList, noopFunc)
resolvedIPsSelector.Selected = resolvedIPsList[0]
resolvedIPsSelector.Resize(fyne.NewSize(100, 100))
grid.Add(resolvedIPsSelector)
}
s.wRoutes.Content().Refresh()
grid.Refresh()
}
func (s *serviceClient) getFilteredRoutes(f filter) ([]*proto.Route, error) {
routes, err := s.fetchRoutes()
if err != nil {
log.Errorf(getClientFMT, err)
s.showError(fmt.Errorf(getClientFMT, err))
return nil, err
}
switch f {
case overlappingRoutes:
return getOverlappingRoutes(routes), nil
case exitNodeRoutes:
return getExitNodeRoutes(routes), nil
default:
}
return routes, nil
}
func getOverlappingRoutes(routes []*proto.Route) []*proto.Route {
var filteredRoutes []*proto.Route
existingRange := make(map[string][]*proto.Route)
for _, route := range routes {
if len(route.Domains) > 0 {
continue
}
if r, exists := existingRange[route.GetNetwork()]; exists {
r = append(r, route)
existingRange[route.GetNetwork()] = r
} else {
existingRange[route.GetNetwork()] = []*proto.Route{route}
}
}
for _, r := range existingRange {
if len(r) > 1 {
filteredRoutes = append(filteredRoutes, r...)
}
}
return filteredRoutes
}
func getExitNodeRoutes(routes []*proto.Route) []*proto.Route {
var filteredRoutes []*proto.Route
for _, route := range routes {
if route.Network == "0.0.0.0/0" {
filteredRoutes = append(filteredRoutes, route)
}
}
return filteredRoutes
}
func sortRoutesByIDs(routes []*proto.Route) {
sort.Slice(routes, func(i, j int) bool {
return strings.ToLower(routes[i].GetID()) < strings.ToLower(routes[j].GetID())
})
}
func (s *serviceClient) fetchRoutes() ([]*proto.Route, error) {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
return nil, fmt.Errorf("get client: %v", err)
return nil, fmt.Errorf(getClientFMT, err)
}
resp, err := conn.ListRoutes(s.ctx, &proto.ListRoutesRequest{})
@ -99,8 +223,8 @@ func (s *serviceClient) fetchRoutes() ([]*proto.Route, error) {
func (s *serviceClient) selectRoute(id string, checked bool) {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
log.Errorf("get client: %v", err)
s.showError(fmt.Errorf("get client: %v", err))
log.Errorf(getClientFMT, err)
s.showError(fmt.Errorf(getClientFMT, err))
return
}
@ -126,16 +250,14 @@ func (s *serviceClient) selectRoute(id string, checked bool) {
}
}
func (s *serviceClient) selectAllRoutes() {
func (s *serviceClient) selectAllFilteredRoutes(f filter) {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
log.Errorf("get client: %v", err)
log.Errorf(getClientFMT, err)
return
}
req := &proto.SelectRoutesRequest{
All: true,
}
req := s.getRoutesRequest(f, true)
if _, err := conn.SelectRoutes(s.ctx, req); err != nil {
log.Errorf("failed to select all routes: %v", err)
s.showError(fmt.Errorf("failed to select all routes: %v", err))
@ -145,16 +267,14 @@ func (s *serviceClient) selectAllRoutes() {
log.Debug("All routes selected")
}
func (s *serviceClient) deselectAllRoutes() {
func (s *serviceClient) deselectAllFilteredRoutes(f filter) {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
log.Errorf("get client: %v", err)
log.Errorf(getClientFMT, err)
return
}
req := &proto.SelectRoutesRequest{
All: true,
}
req := s.getRoutesRequest(f, false)
if _, err := conn.DeselectRoutes(s.ctx, req); err != nil {
log.Errorf("failed to deselect all routes: %v", err)
s.showError(fmt.Errorf("failed to deselect all routes: %v", err))
@ -164,17 +284,34 @@ func (s *serviceClient) deselectAllRoutes() {
log.Debug("All routes deselected")
}
func (s *serviceClient) getRoutesRequest(f filter, appendRoute bool) *proto.SelectRoutesRequest {
req := &proto.SelectRoutesRequest{}
if f == allRoutes {
req.All = true
} else {
routes, err := s.getFilteredRoutes(f)
if err != nil {
return nil
}
for _, route := range routes {
req.RouteIDs = append(req.RouteIDs, route.GetID())
}
req.Append = appendRoute
}
return req
}
func (s *serviceClient) showError(err error) {
wrappedMessage := wrapText(err.Error(), 50)
dialog.ShowError(fmt.Errorf("%s", wrappedMessage), s.wRoutes)
}
func (s *serviceClient) startAutoRefresh(interval time.Duration, grid *fyne.Container) {
func (s *serviceClient) startAutoRefresh(interval time.Duration, tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) {
ticker := time.NewTicker(interval)
go func() {
for range ticker.C {
s.updateRoutes(grid)
s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodesGrid)
}
}()
@ -183,6 +320,23 @@ func (s *serviceClient) startAutoRefresh(interval time.Duration, grid *fyne.Cont
})
}
func (s *serviceClient) updateRoutesBasedOnDisplayTab(tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) {
grid, f := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodesGrid)
s.wRoutes.Content().Refresh()
s.updateRoutes(grid, f)
}
func getGridAndFilterFromTab(tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) (*fyne.Container, filter) {
switch tabs.Selected().Text {
case overlappingRoutesText:
return overlappingGrid, overlappingRoutes
case exitNodeRoutesText:
return exitNodesGrid, exitNodeRoutes
default:
return allGrid, allRoutes
}
}
// wrapText inserts newlines into the text to ensure that each line is
// no longer than 'lineLength' runes.
func wrapText(text string, lineLength int) string {

Some files were not shown because too many files have changed in this diff Show More