mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-13 10:21:10 +01:00
- Revert typos in turnCfg string
- merge main
This commit is contained in:
commit
b3715b5fad
2
.github/workflows/golang-test-darwin.yml
vendored
2
.github/workflows/golang-test-darwin.yml
vendored
@ -14,7 +14,7 @@ jobs:
|
||||
test:
|
||||
strategy:
|
||||
matrix:
|
||||
store: ['jsonfile', 'sqlite']
|
||||
store: ['sqlite']
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Install Go
|
||||
|
39
.github/workflows/golang-test-freebsd.yml
vendored
Normal file
39
.github/workflows/golang-test-freebsd.yml
vendored
Normal file
@ -0,0 +1,39 @@
|
||||
|
||||
name: Test Code FreeBSD
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Test in FreeBSD
|
||||
id: test
|
||||
uses: vmactions/freebsd-vm@v1
|
||||
with:
|
||||
usesh: true
|
||||
prepare: |
|
||||
pkg install -y curl
|
||||
pkg install -y git
|
||||
|
||||
run: |
|
||||
set -x
|
||||
curl -o go.tar.gz https://go.dev/dl/go1.21.11.freebsd-amd64.tar.gz -L
|
||||
tar zxf go.tar.gz
|
||||
mv go /usr/local/go
|
||||
ln -s /usr/local/go/bin/go /usr/local/bin/go
|
||||
go mod tidy
|
||||
go test -timeout 5m -p 1 ./iface/...
|
||||
go test -timeout 5m -p 1 ./client/...
|
||||
cd client
|
||||
go build .
|
||||
cd ..
|
10
.github/workflows/golang-test-linux.yml
vendored
10
.github/workflows/golang-test-linux.yml
vendored
@ -15,7 +15,7 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
store: [ 'jsonfile', 'sqlite', 'postgres']
|
||||
store: [ 'sqlite', 'postgres']
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Install Go
|
||||
@ -86,7 +86,10 @@ jobs:
|
||||
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
|
||||
|
||||
- name: Generate RouteManager Test bin
|
||||
run: CGO_ENABLED=1 go test -c -o routemanager-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/...
|
||||
run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager
|
||||
|
||||
- name: Generate SystemOps Test bin
|
||||
run: CGO_ENABLED=1 go test -c -o systemops-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/systemops
|
||||
|
||||
- name: Generate nftables Manager Test bin
|
||||
run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/...
|
||||
@ -108,6 +111,9 @@ jobs:
|
||||
- name: Run RouteManager tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
- name: Run SystemOps tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager/systemops --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/systemops-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
- name: Run nftables Manager tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/firewall --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/nftablesmanager-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
|
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@ -173,7 +173,7 @@ jobs:
|
||||
retention-days: 3
|
||||
|
||||
release_ui_darwin:
|
||||
runs-on: macos-11
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
|
59
.github/workflows/test-infrastructure-files.yml
vendored
59
.github/workflows/test-infrastructure-files.yml
vendored
@ -178,34 +178,79 @@ jobs:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: run script
|
||||
- name: run script with Zitadel PostgreSQL
|
||||
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh
|
||||
|
||||
- name: test Caddy file gen
|
||||
- name: test Caddy file gen postgres
|
||||
run: test -f Caddyfile
|
||||
- name: test docker-compose file gen
|
||||
|
||||
- name: test docker-compose file gen postgres
|
||||
run: test -f docker-compose.yml
|
||||
- name: test management.json file gen
|
||||
|
||||
- name: test management.json file gen postgres
|
||||
run: test -f management.json
|
||||
- name: test turnserver.conf file gen
|
||||
|
||||
- name: test turnserver.conf file gen postgres
|
||||
run: |
|
||||
set -x
|
||||
test -f turnserver.conf
|
||||
grep external-ip turnserver.conf
|
||||
- name: test zitadel.env file gen
|
||||
|
||||
- name: test zitadel.env file gen postgres
|
||||
run: test -f zitadel.env
|
||||
- name: test dashboard.env file gen
|
||||
|
||||
- name: test dashboard.env file gen postgres
|
||||
run: test -f dashboard.env
|
||||
|
||||
- name: test zdb.env file gen postgres
|
||||
run: test -f zdb.env
|
||||
|
||||
- name: Postgres run cleanup
|
||||
run: |
|
||||
docker-compose down --volumes --rmi all
|
||||
rm -rf docker-compose.yml Caddyfile zitadel.env dashboard.env machinekey/zitadel-admin-sa.token turnserver.conf management.json zdb.env
|
||||
|
||||
- name: run script with Zitadel CockroachDB
|
||||
run: bash -x infrastructure_files/getting-started-with-zitadel.sh
|
||||
env:
|
||||
NETBIRD_DOMAIN: use-ip
|
||||
ZITADEL_DATABASE: cockroach
|
||||
|
||||
- name: test Caddy file gen CockroachDB
|
||||
run: test -f Caddyfile
|
||||
|
||||
- name: test docker-compose file gen CockroachDB
|
||||
run: test -f docker-compose.yml
|
||||
|
||||
- name: test management.json file gen CockroachDB
|
||||
run: test -f management.json
|
||||
|
||||
- name: test turnserver.conf file gen CockroachDB
|
||||
run: |
|
||||
set -x
|
||||
test -f turnserver.conf
|
||||
grep external-ip turnserver.conf
|
||||
|
||||
- name: test zitadel.env file gen CockroachDB
|
||||
run: test -f zitadel.env
|
||||
|
||||
- name: test dashboard.env file gen CockroachDB
|
||||
run: test -f dashboard.env
|
||||
|
||||
test-download-geolite2-script:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Install jq
|
||||
run: sudo apt-get update && sudo apt-get install -y unzip sqlite3
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: test script
|
||||
run: bash -x infrastructure_files/download-geolite2.sh
|
||||
|
||||
- name: test mmdb file exists
|
||||
run: test -f GeoLite2-City.mmdb
|
||||
|
||||
- name: test geonames file exists
|
||||
run: test -f geonames.db
|
||||
|
@ -3,8 +3,10 @@ builds:
|
||||
- id: netbird-ui-darwin
|
||||
dir: client/ui
|
||||
binary: netbird-ui
|
||||
env: [CGO_ENABLED=1]
|
||||
|
||||
env:
|
||||
- CGO_ENABLED=1
|
||||
- MACOSX_DEPLOYMENT_TARGET=11.0
|
||||
- MACOS_DEPLOYMENT_TARGET=11.0
|
||||
goos:
|
||||
- darwin
|
||||
goarch:
|
||||
|
@ -1,4 +1,4 @@
|
||||
FROM alpine:3.18.5
|
||||
FROM alpine:3.19
|
||||
RUN apk add --no-cache ca-certificates iptables ip6tables
|
||||
ENV NB_FOREGROUND_MODE=true
|
||||
ENTRYPOINT [ "/usr/local/bin/netbird","up"]
|
||||
|
@ -36,6 +36,7 @@ const (
|
||||
disableAutoConnectFlag = "disable-auto-connect"
|
||||
serverSSHAllowedFlag = "allow-server-ssh"
|
||||
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
||||
dnsRouteIntervalFlag = "dns-router-interval"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -68,7 +69,9 @@ var (
|
||||
autoConnectDisabled bool
|
||||
extraIFaceBlackList []string
|
||||
anonymizeFlag bool
|
||||
rootCmd = &cobra.Command{
|
||||
dnsRouteInterval time.Duration
|
||||
|
||||
rootCmd = &cobra.Command{
|
||||
Use: "netbird",
|
||||
Short: "",
|
||||
Long: "",
|
||||
|
@ -2,6 +2,7 @@ package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
@ -66,18 +67,60 @@ func routesList(cmd *cobra.Command, _ []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
cmd.Println("Available Routes:")
|
||||
for _, route := range resp.Routes {
|
||||
selectedStatus := "Not Selected"
|
||||
if route.GetSelected() {
|
||||
selectedStatus = "Selected"
|
||||
}
|
||||
cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetNetwork(), selectedStatus)
|
||||
}
|
||||
printRoutes(cmd, resp)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func printRoutes(cmd *cobra.Command, resp *proto.ListRoutesResponse) {
|
||||
cmd.Println("Available Routes:")
|
||||
for _, route := range resp.Routes {
|
||||
printRoute(cmd, route)
|
||||
}
|
||||
}
|
||||
|
||||
func printRoute(cmd *cobra.Command, route *proto.Route) {
|
||||
selectedStatus := getSelectedStatus(route)
|
||||
domains := route.GetDomains()
|
||||
|
||||
if len(domains) > 0 {
|
||||
printDomainRoute(cmd, route, domains, selectedStatus)
|
||||
} else {
|
||||
printNetworkRoute(cmd, route, selectedStatus)
|
||||
}
|
||||
}
|
||||
|
||||
func getSelectedStatus(route *proto.Route) string {
|
||||
if route.GetSelected() {
|
||||
return "Selected"
|
||||
}
|
||||
return "Not Selected"
|
||||
}
|
||||
|
||||
func printDomainRoute(cmd *cobra.Command, route *proto.Route, domains []string, selectedStatus string) {
|
||||
cmd.Printf("\n - ID: %s\n Domains: %s\n Status: %s\n", route.GetID(), strings.Join(domains, ", "), selectedStatus)
|
||||
resolvedIPs := route.GetResolvedIPs()
|
||||
|
||||
if len(resolvedIPs) > 0 {
|
||||
printResolvedIPs(cmd, domains, resolvedIPs)
|
||||
} else {
|
||||
cmd.Printf(" Resolved IPs: -\n")
|
||||
}
|
||||
}
|
||||
|
||||
func printNetworkRoute(cmd *cobra.Command, route *proto.Route, selectedStatus string) {
|
||||
cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetNetwork(), selectedStatus)
|
||||
}
|
||||
|
||||
func printResolvedIPs(cmd *cobra.Command, domains []string, resolvedIPs map[string]*proto.IPList) {
|
||||
cmd.Printf(" Resolved IPs:\n")
|
||||
for _, domain := range domains {
|
||||
if ipList, exists := resolvedIPs[domain]; exists {
|
||||
cmd.Printf(" [%s]: %s\n", domain, strings.Join(ipList.GetIps(), ", "))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func routesSelect(cmd *cobra.Command, args []string) error {
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
|
@ -807,11 +807,7 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
|
||||
}
|
||||
|
||||
for i, route := range peer.Routes {
|
||||
prefix, err := netip.ParsePrefix(route)
|
||||
if err == nil {
|
||||
ip := a.AnonymizeIPString(prefix.Addr().String())
|
||||
peer.Routes[i] = fmt.Sprintf("%s/%d", ip, prefix.Bits())
|
||||
}
|
||||
peer.Routes[i] = anonymizeRoute(a, route)
|
||||
}
|
||||
}
|
||||
|
||||
@ -847,12 +843,21 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview)
|
||||
}
|
||||
|
||||
for i, route := range overview.Routes {
|
||||
prefix, err := netip.ParsePrefix(route)
|
||||
if err == nil {
|
||||
ip := a.AnonymizeIPString(prefix.Addr().String())
|
||||
overview.Routes[i] = fmt.Sprintf("%s/%d", ip, prefix.Bits())
|
||||
}
|
||||
overview.Routes[i] = anonymizeRoute(a, route)
|
||||
}
|
||||
|
||||
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
|
||||
}
|
||||
|
||||
func anonymizeRoute(a *anonymize.Anonymizer, route string) string {
|
||||
prefix, err := netip.ParsePrefix(route)
|
||||
if err == nil {
|
||||
ip := a.AnonymizeIPString(prefix.Addr().String())
|
||||
return fmt.Sprintf("%s/%d", ip, prefix.Bits())
|
||||
}
|
||||
domains := strings.Split(route, ", ")
|
||||
for i, domain := range domains {
|
||||
domains[i] = a.AnonymizeDomain(domain)
|
||||
}
|
||||
return strings.Join(domains, ", ")
|
||||
}
|
||||
|
@ -7,6 +7,9 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
@ -53,7 +56,10 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s := grpc.NewServer()
|
||||
sigProto.RegisterSignalExchangeServer(s, sig.NewServer())
|
||||
srv, err := sig.NewServer(otel.Meter(""))
|
||||
require.NoError(t, err)
|
||||
|
||||
sigProto.RegisterSignalExchangeServer(s, srv)
|
||||
go func() {
|
||||
if err := s.Serve(lis); err != nil {
|
||||
panic(err)
|
||||
@ -70,7 +76,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
|
||||
t.Fatal(err)
|
||||
}
|
||||
s := grpc.NewServer()
|
||||
store, cleanUp, err := mgmt.NewTestStoreFromJson(config.Datadir)
|
||||
store, cleanUp, err := mgmt.NewTestStoreFromJson(context.Background(), config.Datadir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -81,13 +87,13 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
iv, _ := integrations.NewIntegratedValidator(eventStore)
|
||||
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv)
|
||||
iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, "")
|
||||
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -102,7 +108,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
|
||||
}
|
||||
|
||||
func startClientDaemon(
|
||||
t *testing.T, ctx context.Context, managementURL, configPath string,
|
||||
t *testing.T, ctx context.Context, _, configPath string,
|
||||
) (*grpc.Server, net.Listener) {
|
||||
t.Helper()
|
||||
lis, err := net.Listen("tcp", "127.0.0.1:0")
|
@ -7,11 +7,13 @@ import (
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
@ -40,8 +42,12 @@ func init() {
|
||||
upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground")
|
||||
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
|
||||
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
|
||||
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", false, "Enable network monitoring")
|
||||
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor,
|
||||
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux. `+
|
||||
`E.g. --network-monitor=false to disable or --network-monitor=true to enable.`,
|
||||
)
|
||||
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
|
||||
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
|
||||
}
|
||||
|
||||
func upFunc(cmd *cobra.Command, args []string) error {
|
||||
@ -137,6 +143,10 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
}
|
||||
}
|
||||
|
||||
if cmd.Flag(dnsRouteIntervalFlag).Changed {
|
||||
ic.DNSRouteInterval = &dnsRouteInterval
|
||||
}
|
||||
|
||||
config, err := internal.UpdateOrCreateConfig(ic)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get config file: %v", err)
|
||||
@ -237,6 +247,10 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
loginRequest.NetworkMonitor = &networkMonitor
|
||||
}
|
||||
|
||||
if cmd.Flag(dnsRouteIntervalFlag).Changed {
|
||||
loginRequest.DnsRouteInterval = durationpb.New(dnsRouteInterval)
|
||||
}
|
||||
|
||||
var loginErr error
|
||||
|
||||
var loginResp *proto.LoginResponse
|
||||
|
30
client/errors/errors.go
Normal file
30
client/errors/errors.go
Normal file
@ -0,0 +1,30 @@
|
||||
package errors
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
)
|
||||
|
||||
func formatError(es []error) string {
|
||||
if len(es) == 0 {
|
||||
return fmt.Sprintf("0 error occurred:\n\t* %s", es[0])
|
||||
}
|
||||
|
||||
points := make([]string, len(es))
|
||||
for i, err := range es {
|
||||
points[i] = fmt.Sprintf("* %s", err)
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"%d errors occurred:\n\t%s",
|
||||
len(es), strings.Join(points, "\n\t"))
|
||||
}
|
||||
|
||||
func FormatErrorOrNil(err *multierror.Error) error {
|
||||
if err != nil {
|
||||
err.ErrorFormat = formatError
|
||||
}
|
||||
return err.ErrorOrNil()
|
||||
}
|
@ -74,12 +74,12 @@ func (i *routerManager) InsertRoutingRules(pair firewall.RouterPair) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = i.insertRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair)
|
||||
err = i.addNATRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = i.insertRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair))
|
||||
err = i.addNATRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -101,6 +101,7 @@ func (i *routerManager) insertRoutingRule(keyFormat, table, chain, jump string,
|
||||
}
|
||||
delete(i.rules, ruleKey)
|
||||
}
|
||||
|
||||
err = i.iptablesClient.Insert(table, chain, 1, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
|
||||
@ -317,6 +318,13 @@ func (i *routerManager) createChain(table, newChain string) error {
|
||||
return fmt.Errorf("couldn't create chain %s in %s table, error: %v", newChain, table, err)
|
||||
}
|
||||
|
||||
// Add the loopback return rule to the NAT chain
|
||||
loopbackRule := []string{"-o", "lo", "-j", "RETURN"}
|
||||
err = i.iptablesClient.Insert(table, newChain, 1, loopbackRule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add loopback return rule to %s: %v", chainRTNAT, err)
|
||||
}
|
||||
|
||||
err = i.iptablesClient.Append(table, newChain, "-j", "RETURN")
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't create chain %s default rule, error: %v", newChain, err)
|
||||
@ -326,6 +334,30 @@ func (i *routerManager) createChain(table, newChain string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// addNATRule appends an iptables rule pair to the nat chain
|
||||
func (i *routerManager) addNATRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(keyFormat, pair.ID)
|
||||
rule := genRuleSpec(jump, pair.Source, pair.Destination)
|
||||
existingRule, found := i.rules[ruleKey]
|
||||
if found {
|
||||
err := i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err)
|
||||
}
|
||||
delete(i.rules, ruleKey)
|
||||
}
|
||||
|
||||
// inserting after loopback ignore rule
|
||||
err := i.iptablesClient.Insert(table, chain, 2, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err)
|
||||
}
|
||||
|
||||
i.rules[ruleKey] = rule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// genRuleSpec generates rule specification
|
||||
func genRuleSpec(jump, source, destination string) []string {
|
||||
return []string{"-s", source, "-d", destination, "-j", jump}
|
||||
|
@ -95,7 +95,7 @@ func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.InsertRoutingRules(pair)
|
||||
return m.router.AddRoutingRules(pair)
|
||||
}
|
||||
|
||||
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
|
||||
|
@ -22,6 +22,8 @@ const (
|
||||
|
||||
userDataAcceptForwardRuleSrc = "frwacceptsrc"
|
||||
userDataAcceptForwardRuleDst = "frwacceptdst"
|
||||
|
||||
loopbackInterface = "lo\x00"
|
||||
)
|
||||
|
||||
// some presets for building nftable rules
|
||||
@ -126,6 +128,22 @@ func (r *router) createContainers() error {
|
||||
Type: nftables.ChainTypeNAT,
|
||||
})
|
||||
|
||||
// Add RETURN rule for loopback interface
|
||||
loRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingNat],
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte(loopbackInterface),
|
||||
},
|
||||
&expr.Verdict{Kind: expr.VerdictReturn},
|
||||
},
|
||||
}
|
||||
r.conn.InsertRule(loRule)
|
||||
|
||||
err := r.refreshRulesMap()
|
||||
if err != nil {
|
||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||
@ -138,28 +156,28 @@ func (r *router) createContainers() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// InsertRoutingRules inserts a nftable rule pair to the forwarding chain and if enabled, to the nat chain
|
||||
func (r *router) InsertRoutingRules(pair manager.RouterPair) error {
|
||||
// AddRoutingRules appends a nftable rule pair to the forwarding chain and if enabled, to the nat chain
|
||||
func (r *router) AddRoutingRules(pair manager.RouterPair) error {
|
||||
err := r.refreshRulesMap()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = r.insertRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false)
|
||||
err = r.addRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = r.insertRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false)
|
||||
err = r.addRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if pair.Masquerade {
|
||||
err = r.insertRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true)
|
||||
err = r.addRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = r.insertRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true)
|
||||
err = r.addRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -177,8 +195,8 @@ func (r *router) InsertRoutingRules(pair manager.RouterPair) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// insertRoutingRule inserts a nftable rule to the conn client flush queue
|
||||
func (r *router) insertRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error {
|
||||
// addRoutingRule inserts a nftable rule to the conn client flush queue
|
||||
func (r *router) addRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error {
|
||||
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
||||
|
||||
@ -199,7 +217,7 @@ func (r *router) insertRoutingRule(format, chainName string, pair manager.Router
|
||||
}
|
||||
}
|
||||
|
||||
r.rules[ruleKey] = r.conn.InsertRule(&nftables.Rule{
|
||||
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainName],
|
||||
Exprs: expression,
|
||||
|
@ -47,7 +47,7 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
|
||||
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
err = manager.InsertRoutingRules(testCase.InputPair)
|
||||
err = manager.AddRoutingRules(testCase.InputPair)
|
||||
defer func() {
|
||||
_ = manager.RemoveRoutingRules(testCase.InputPair)
|
||||
}()
|
||||
|
@ -6,13 +6,16 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
mgm "github.com/netbirdio/netbird/management/client"
|
||||
@ -53,6 +56,7 @@ type ConfigInput struct {
|
||||
NetworkMonitor *bool
|
||||
DisableAutoConnect *bool
|
||||
ExtraIFaceBlackList []string
|
||||
DNSRouteInterval *time.Duration
|
||||
}
|
||||
|
||||
// Config Configuration type
|
||||
@ -64,7 +68,7 @@ type Config struct {
|
||||
AdminURL *url.URL
|
||||
WgIface string
|
||||
WgPort int
|
||||
NetworkMonitor bool
|
||||
NetworkMonitor *bool
|
||||
IFaceBlackList []string
|
||||
DisableIPv6Discovery bool
|
||||
RosenpassEnabled bool
|
||||
@ -95,6 +99,9 @@ type Config struct {
|
||||
// DisableAutoConnect determines whether the client should not start with the service
|
||||
// it's set to false by default due to backwards compatibility
|
||||
DisableAutoConnect bool
|
||||
|
||||
// DNSRouteInterval is the interval in which the DNS routes are updated
|
||||
DNSRouteInterval time.Duration
|
||||
}
|
||||
|
||||
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
||||
@ -304,12 +311,21 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.NetworkMonitor != nil && *input.NetworkMonitor != config.NetworkMonitor {
|
||||
if input.NetworkMonitor != nil && input.NetworkMonitor != config.NetworkMonitor {
|
||||
log.Infof("switching Network Monitor to %t", *input.NetworkMonitor)
|
||||
config.NetworkMonitor = *input.NetworkMonitor
|
||||
config.NetworkMonitor = input.NetworkMonitor
|
||||
updated = true
|
||||
}
|
||||
|
||||
if config.NetworkMonitor == nil {
|
||||
// enable network monitoring by default on windows and darwin clients
|
||||
if runtime.GOOS == "windows" || runtime.GOOS == "darwin" {
|
||||
enabled := true
|
||||
config.NetworkMonitor = &enabled
|
||||
updated = true
|
||||
}
|
||||
}
|
||||
|
||||
if input.CustomDNSAddress != nil && string(input.CustomDNSAddress) != config.CustomDNSAddress {
|
||||
log.Infof("updating custom DNS address %#v (old value %#v)",
|
||||
string(input.CustomDNSAddress), config.CustomDNSAddress)
|
||||
@ -357,6 +373,18 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.DNSRouteInterval != nil && *input.DNSRouteInterval != config.DNSRouteInterval {
|
||||
log.Infof("updating DNS route interval to %s (old value %s)",
|
||||
input.DNSRouteInterval.String(), config.DNSRouteInterval.String())
|
||||
config.DNSRouteInterval = *input.DNSRouteInterval
|
||||
updated = true
|
||||
} else if config.DNSRouteInterval == 0 {
|
||||
config.DNSRouteInterval = dynamic.DefaultInterval
|
||||
log.Infof("using default DNS route interval %s", config.DNSRouteInterval)
|
||||
updated = true
|
||||
|
||||
}
|
||||
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
|
@ -264,8 +264,10 @@ func (c *ConnectClient) run(
|
||||
return wrapErr(err)
|
||||
}
|
||||
|
||||
checks := loginResp.GetChecks()
|
||||
|
||||
c.engineMutex.Lock()
|
||||
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe)
|
||||
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe, checks)
|
||||
c.engineMutex.Unlock()
|
||||
|
||||
err = c.engine.Start()
|
||||
@ -342,6 +344,10 @@ func (c *ConnectClient) Engine() *Engine {
|
||||
|
||||
// createEngineConfig converts configuration received from Management Service to EngineConfig
|
||||
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
|
||||
nm := false
|
||||
if config.NetworkMonitor != nil {
|
||||
nm = *config.NetworkMonitor
|
||||
}
|
||||
engineConf := &EngineConfig{
|
||||
WgIfaceName: config.WgIface,
|
||||
WgAddr: peerConfig.Address,
|
||||
@ -349,13 +355,14 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
|
||||
DisableIPv6Discovery: config.DisableIPv6Discovery,
|
||||
WgPrivateKey: key,
|
||||
WgPort: config.WgPort,
|
||||
NetworkMonitor: config.NetworkMonitor,
|
||||
NetworkMonitor: nm,
|
||||
SSHKey: []byte(config.SSHKey),
|
||||
NATExternalIPs: config.NATExternalIPs,
|
||||
CustomDNSAddress: config.CustomDNSAddress,
|
||||
RosenpassEnabled: config.RosenpassEnabled,
|
||||
RosenpassPermissive: config.RosenpassPermissive,
|
||||
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
|
||||
DNSRouteInterval: config.DNSRouteInterval,
|
||||
}
|
||||
|
||||
if config.PreSharedKey != "" {
|
||||
|
6
client/internal/dns/consts_freebsd.go
Normal file
6
client/internal/dns/consts_freebsd.go
Normal file
@ -0,0 +1,6 @@
|
||||
package dns
|
||||
|
||||
const (
|
||||
fileUncleanShutdownResolvConfLocation = "/var/db/netbird/resolv.conf"
|
||||
fileUncleanShutdownManagerTypeLocation = "/var/db/netbird/manager"
|
||||
)
|
8
client/internal/dns/consts_linux.go
Normal file
8
client/internal/dns/consts_linux.go
Normal file
@ -0,0 +1,8 @@
|
||||
//go:build !android
|
||||
|
||||
package dns
|
||||
|
||||
const (
|
||||
fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf"
|
||||
fileUncleanShutdownManagerTypeLocation = "/var/lib/netbird/manager"
|
||||
)
|
@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
||||
@ -108,7 +108,7 @@ func getOSDNSManagerType() (osManagerType, error) {
|
||||
if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
|
||||
return networkManager, nil
|
||||
}
|
||||
if strings.Contains(text, "systemd-resolved") && isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) {
|
||||
if strings.Contains(text, "systemd-resolved") && isSystemdResolvedRunning() {
|
||||
if checkStub() {
|
||||
return systemdManager, nil
|
||||
} else {
|
||||
@ -116,16 +116,10 @@ func getOSDNSManagerType() (osManagerType, error) {
|
||||
}
|
||||
}
|
||||
if strings.Contains(text, "resolvconf") {
|
||||
if isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) {
|
||||
var value string
|
||||
err = getSystemdDbusProperty(systemdDbusResolvConfModeProperty, &value)
|
||||
if err == nil {
|
||||
if value == systemdDbusResolvConfModeForeign {
|
||||
return systemdManager, nil
|
||||
}
|
||||
}
|
||||
log.Errorf("got an error while checking systemd resolv conf mode, error: %s", err)
|
||||
if isSystemdResolveConfMode() {
|
||||
return systemdManager, nil
|
||||
}
|
||||
|
||||
return resolvConfManager, nil
|
||||
}
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
@ -39,6 +39,10 @@ func (w *mocWGIface) Address() iface.WGAddress {
|
||||
}
|
||||
}
|
||||
|
||||
func (w *mocWGIface) ToInterface() *net.Interface {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (w *mocWGIface) GetFilter() iface.PacketFilter {
|
||||
return w.filter
|
||||
}
|
||||
@ -261,7 +265,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), 33100, privKey.String(), iface.DefaultMTU, newNet, nil)
|
||||
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -339,7 +343,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
}
|
||||
|
||||
privKey, _ := wgtypes.GeneratePrivateKey()
|
||||
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", 33100, privKey.String(), iface.DefaultMTU, newNet, nil)
|
||||
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil)
|
||||
if err != nil {
|
||||
t.Errorf("build interface wireguard: %v", err)
|
||||
return
|
||||
@ -797,7 +801,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
||||
}
|
||||
|
||||
privKey, _ := wgtypes.GeneratePrivateKey()
|
||||
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", 33100, privKey.String(), iface.DefaultMTU, newNet, nil)
|
||||
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("build interface wireguard: %v", err)
|
||||
return nil, err
|
||||
|
@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
20
client/internal/dns/systemd_freebsd.go
Normal file
20
client/internal/dns/systemd_freebsd.go
Normal file
@ -0,0 +1,20 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
var errNotImplemented = errors.New("not implemented")
|
||||
|
||||
func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) {
|
||||
return nil, fmt.Errorf("systemd dns management: %w on freebsd", errNotImplemented)
|
||||
}
|
||||
|
||||
func isSystemdResolvedRunning() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func isSystemdResolveConfMode() bool {
|
||||
return false
|
||||
}
|
@ -242,3 +242,25 @@ func getSystemdDbusProperty(property string, store any) error {
|
||||
|
||||
return v.Store(store)
|
||||
}
|
||||
|
||||
func isSystemdResolvedRunning() bool {
|
||||
return isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode)
|
||||
}
|
||||
|
||||
func isSystemdResolveConfMode() bool {
|
||||
if !isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) {
|
||||
return false
|
||||
}
|
||||
|
||||
var value string
|
||||
if err := getSystemdDbusProperty(systemdDbusResolvConfModeProperty, &value); err != nil {
|
||||
log.Errorf("got an error while checking systemd resolv conf mode, error: %s", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if value == systemdDbusResolvConfModeForeign {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
||||
@ -14,11 +14,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf"
|
||||
fileUncleanShutdownManagerTypeLocation = "/var/lib/netbird/manager"
|
||||
)
|
||||
|
||||
func CheckUncleanShutdown(wgIface string) error {
|
||||
if _, err := os.Stat(fileUncleanShutdownResolvConfLocation); err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
@ -78,6 +78,11 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
}()
|
||||
|
||||
log.WithField("question", r.Question[0]).Trace("received an upstream question")
|
||||
// set the AuthenticatedData flag and the EDNS0 buffer size to 4096 bytes to support larger dns records
|
||||
if r.Extra == nil {
|
||||
r.SetEdns0(4096, false)
|
||||
r.MsgHdr.AuthenticatedData = true
|
||||
}
|
||||
|
||||
select {
|
||||
case <-u.ctx.Done():
|
||||
|
@ -2,12 +2,17 @@
|
||||
|
||||
package dns
|
||||
|
||||
import "github.com/netbirdio/netbird/iface"
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
// WGIface defines subset methods of interface required for manager
|
||||
type WGIface interface {
|
||||
Name() string
|
||||
Address() iface.WGAddress
|
||||
ToInterface() *net.Interface
|
||||
IsUserspaceBind() bool
|
||||
GetFilter() iface.PacketFilter
|
||||
GetDevice() *iface.DeviceWrapper
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@ -30,12 +31,15 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/relay"
|
||||
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/wgproxy"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/iface/bind"
|
||||
mgm "github.com/netbirdio/netbird/management/client"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
auth "github.com/netbirdio/netbird/relay/auth/hmac"
|
||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||
@ -43,6 +47,7 @@ import (
|
||||
signal "github.com/netbirdio/netbird/signal/client"
|
||||
sProto "github.com/netbirdio/netbird/signal/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
|
||||
@ -93,6 +98,8 @@ type EngineConfig struct {
|
||||
RosenpassPermissive bool
|
||||
|
||||
ServerSSHAllowed bool
|
||||
|
||||
DNSRouteInterval time.Duration
|
||||
}
|
||||
|
||||
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
|
||||
@ -105,8 +112,8 @@ type Engine struct {
|
||||
// peerConns is a map that holds all the peers that are known to this peer
|
||||
peerConns map[string]*peer.Conn
|
||||
|
||||
beforePeerHook peer.BeforeAddPeerHookFunc
|
||||
afterPeerHook peer.AfterRemovePeerHookFunc
|
||||
beforePeerHook nbnet.AddHookFunc
|
||||
afterPeerHook nbnet.RemoveHookFunc
|
||||
|
||||
// rpManager is a Rosenpass manager
|
||||
rpManager *rosenpass.Manager
|
||||
@ -159,6 +166,9 @@ type Engine struct {
|
||||
relayProbe *Probe
|
||||
wgProbe *Probe
|
||||
|
||||
// checks are the client-applied posture checks that need to be evaluated on the client
|
||||
checks []*mgmProto.Checks
|
||||
|
||||
relayManager *relayClient.Manager
|
||||
}
|
||||
|
||||
@ -178,6 +188,7 @@ func NewEngine(
|
||||
config *EngineConfig,
|
||||
mobileDep MobileDependency,
|
||||
statusRecorder *peer.Status,
|
||||
checks []*mgmProto.Checks,
|
||||
) *Engine {
|
||||
return NewEngineWithProbes(
|
||||
clientCtx,
|
||||
@ -192,6 +203,7 @@ func NewEngine(
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
checks,
|
||||
)
|
||||
}
|
||||
|
||||
@ -209,6 +221,7 @@ func NewEngineWithProbes(
|
||||
signalProbe *Probe,
|
||||
relayProbe *Probe,
|
||||
wgProbe *Probe,
|
||||
checks []*mgmProto.Checks,
|
||||
) *Engine {
|
||||
return &Engine{
|
||||
clientCtx: clientCtx,
|
||||
@ -230,6 +243,7 @@ func NewEngineWithProbes(
|
||||
signalProbe: signalProbe,
|
||||
relayProbe: relayProbe,
|
||||
wgProbe: wgProbe,
|
||||
checks: checks,
|
||||
}
|
||||
}
|
||||
|
||||
@ -277,8 +291,6 @@ func (e *Engine) Start() error {
|
||||
}
|
||||
e.ctx, e.cancel = context.WithCancel(e.clientCtx)
|
||||
|
||||
e.wgProxyFactory = wgproxy.NewFactory(e.ctx, e.config.WgPort)
|
||||
|
||||
wgIface, err := e.newWgIface()
|
||||
if err != nil {
|
||||
log.Errorf("failed creating wireguard interface instance %s: [%s]", e.config.WgIfaceName, err)
|
||||
@ -286,6 +298,9 @@ func (e *Engine) Start() error {
|
||||
}
|
||||
e.wgInterface = wgIface
|
||||
|
||||
userspace := e.wgInterface.IsUserspaceBind()
|
||||
e.wgProxyFactory = wgproxy.NewFactory(e.ctx, userspace, e.config.WgPort)
|
||||
|
||||
if e.config.RosenpassEnabled {
|
||||
log.Infof("rosenpass is enabled")
|
||||
if e.config.RosenpassPermissive {
|
||||
@ -310,7 +325,7 @@ func (e *Engine) Start() error {
|
||||
}
|
||||
e.dnsServer = dnsServer
|
||||
|
||||
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, initialRoutes)
|
||||
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, initialRoutes)
|
||||
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
|
||||
if err != nil {
|
||||
log.Errorf("Failed to initialize route manager: %s", err)
|
||||
@ -498,6 +513,10 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
// todo update signal
|
||||
}
|
||||
|
||||
if err := e.updateChecksIfNew(update.Checks); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if update.GetNetworkMap() != nil {
|
||||
// only apply new changes and ignore old ones
|
||||
err := e.updateNetworkMap(update.GetNetworkMap())
|
||||
@ -505,7 +524,27 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateChecksIfNew updates checks if there are changes and sync new meta with management
|
||||
func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
||||
// if checks are equal, we skip the update
|
||||
if isChecksEqual(e.checks, checks) {
|
||||
return nil
|
||||
}
|
||||
e.checks = checks
|
||||
|
||||
info, err := system.GetInfoWithChecks(e.ctx, checks)
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system info with checks: %v", err)
|
||||
info = system.GetInfo(e.ctx)
|
||||
}
|
||||
|
||||
if err := e.mgmClient.SyncMeta(info); err != nil {
|
||||
log.Errorf("could not sync meta: error %s", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -521,8 +560,8 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
||||
} else {
|
||||
|
||||
if sshConf.GetSshEnabled() {
|
||||
if runtime.GOOS == "windows" {
|
||||
log.Warnf("running SSH server on Windows is not supported")
|
||||
if runtime.GOOS == "windows" || runtime.GOOS == "freebsd" {
|
||||
log.Warnf("running SSH server on %s is not supported", runtime.GOOS)
|
||||
return nil
|
||||
}
|
||||
// start SSH server if it wasn't running
|
||||
@ -595,7 +634,14 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
// E.g. when a new peer has been registered and we are allowed to connect to it.
|
||||
func (e *Engine) receiveManagementEvents() {
|
||||
go func() {
|
||||
err := e.mgmClient.Sync(e.ctx, e.handleSync)
|
||||
info, err := system.GetInfoWithChecks(e.ctx, e.checks)
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system info with checks: %v", err)
|
||||
info = system.GetInfo(e.ctx)
|
||||
}
|
||||
|
||||
// err = e.mgmClient.Sync(info, e.handleSync)
|
||||
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
|
||||
if err != nil {
|
||||
// happens if management is unavailable for a long time.
|
||||
// We want to cancel the operation of the whole client
|
||||
@ -662,6 +708,20 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
protoRoutes := networkMap.GetRoutes()
|
||||
if protoRoutes == nil {
|
||||
protoRoutes = []*mgmProto.Route{}
|
||||
}
|
||||
|
||||
_, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes))
|
||||
if err != nil {
|
||||
log.Errorf("failed to update clientRoutes, err: %v", err)
|
||||
}
|
||||
|
||||
e.clientRoutesMu.Lock()
|
||||
e.clientRoutes = clientRoutes
|
||||
e.clientRoutesMu.Unlock()
|
||||
|
||||
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
|
||||
|
||||
e.updateOfflinePeers(networkMap.GetOfflinePeers())
|
||||
@ -703,19 +763,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
}
|
||||
}
|
||||
}
|
||||
protoRoutes := networkMap.GetRoutes()
|
||||
if protoRoutes == nil {
|
||||
protoRoutes = []*mgmProto.Route{}
|
||||
}
|
||||
|
||||
_, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes))
|
||||
if err != nil {
|
||||
log.Errorf("failed to update clientRoutes, err: %v", err)
|
||||
}
|
||||
|
||||
e.clientRoutesMu.Lock()
|
||||
e.clientRoutes = clientRoutes
|
||||
e.clientRoutesMu.Unlock()
|
||||
|
||||
protoDNSConfig := networkMap.GetDNSConfig()
|
||||
if protoDNSConfig == nil {
|
||||
@ -743,15 +790,24 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
|
||||
routes := make([]*route.Route, 0)
|
||||
for _, protoRoute := range protoRoutes {
|
||||
_, prefix, _ := route.ParseNetwork(protoRoute.Network)
|
||||
var prefix netip.Prefix
|
||||
if len(protoRoute.Domains) == 0 {
|
||||
var err error
|
||||
if prefix, err = netip.ParsePrefix(protoRoute.Network); err != nil {
|
||||
log.Errorf("Failed to parse prefix %s: %v", protoRoute.Network, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
convertedRoute := &route.Route{
|
||||
ID: route.ID(protoRoute.ID),
|
||||
Network: prefix,
|
||||
Domains: domain.FromPunycodeList(protoRoute.Domains),
|
||||
NetID: route.NetID(protoRoute.NetID),
|
||||
NetworkType: route.NetworkType(protoRoute.NetworkType),
|
||||
Peer: protoRoute.Peer,
|
||||
Metric: int(protoRoute.Metric),
|
||||
Masquerade: protoRoute.Masquerade,
|
||||
KeepRoute: protoRoute.KeepRoute,
|
||||
}
|
||||
routes = append(routes, convertedRoute)
|
||||
}
|
||||
@ -1105,7 +1161,8 @@ func (e *Engine) close() {
|
||||
}
|
||||
|
||||
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
|
||||
netMap, err := e.mgmClient.GetNetworkMap()
|
||||
info := system.GetInfo(e.ctx)
|
||||
netMap, err := e.mgmClient.GetNetworkMap(info)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@ -1134,7 +1191,7 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) {
|
||||
default:
|
||||
}
|
||||
|
||||
return iface.NewWGIFace(e.config.WgIfaceName, e.config.WgAddr, e.config.WgPort, e.config.WgPrivateKey.String(), iface.DefaultMTU, transportNet, mArgs)
|
||||
return iface.NewWGIFace(e.config.WgIfaceName, e.config.WgAddr, e.config.WgPort, e.config.WgPrivateKey.String(), iface.DefaultMTU, transportNet, mArgs, e.addrViaRoutes)
|
||||
}
|
||||
|
||||
func (e *Engine) wgInterfaceCreate() (err error) {
|
||||
@ -1309,6 +1366,15 @@ func (e *Engine) probeTURNs() []relay.ProbeResult {
|
||||
return relay.ProbeAll(e.ctx, relay.ProbeTURN, e.TURNs)
|
||||
}
|
||||
|
||||
func (e *Engine) restartEngine() {
|
||||
if err := e.Stop(); err != nil {
|
||||
log.Errorf("Failed to stop engine: %v", err)
|
||||
}
|
||||
if err := e.Start(); err != nil {
|
||||
log.Errorf("Failed to start engine: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) startNetworkMonitor() {
|
||||
if !e.config.NetworkMonitor {
|
||||
log.Infof("Network monitor is disabled, not starting")
|
||||
@ -1317,17 +1383,54 @@ func (e *Engine) startNetworkMonitor() {
|
||||
|
||||
e.networkMonitor = networkmonitor.New()
|
||||
go func() {
|
||||
var mu sync.Mutex
|
||||
var debounceTimer *time.Timer
|
||||
|
||||
// Start the network monitor with a callback, Start will block until the monitor is stopped,
|
||||
// a network change is detected, or an error occurs on start up
|
||||
err := e.networkMonitor.Start(e.ctx, func() {
|
||||
log.Infof("Network monitor detected network change, restarting engine")
|
||||
if err := e.Stop(); err != nil {
|
||||
log.Errorf("Failed to stop engine: %v", err)
|
||||
}
|
||||
if err := e.Start(); err != nil {
|
||||
log.Errorf("Failed to start engine: %v", err)
|
||||
// This function is called when a network change is detected
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
if debounceTimer != nil {
|
||||
debounceTimer.Stop()
|
||||
}
|
||||
|
||||
// Set a new timer to debounce rapid network changes
|
||||
debounceTimer = time.AfterFunc(1*time.Second, func() {
|
||||
// This function is called after the debounce period
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
log.Infof("Network monitor detected network change, restarting engine")
|
||||
e.restartEngine()
|
||||
})
|
||||
})
|
||||
if err != nil && !errors.Is(err, networkmonitor.ErrStopped) {
|
||||
log.Errorf("Network monitor: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
|
||||
var vpnRoutes []netip.Prefix
|
||||
for _, routes := range e.GetClientRoutes() {
|
||||
if len(routes) > 0 && routes[0] != nil {
|
||||
vpnRoutes = append(vpnRoutes, routes[0].Network)
|
||||
}
|
||||
}
|
||||
|
||||
if isVpn, prefix := systemops.IsAddrRouted(addr, vpnRoutes); isVpn {
|
||||
return true, prefix, nil
|
||||
}
|
||||
|
||||
return false, netip.Prefix{}, nil
|
||||
}
|
||||
|
||||
// isChecksEqual checks if two slices of checks are equal.
|
||||
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
|
||||
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {
|
||||
return slices.Equal(checks.Files, oChecks.Files)
|
||||
})
|
||||
}
|
||||
|
@ -17,6 +17,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
@ -58,9 +59,9 @@ var (
|
||||
)
|
||||
|
||||
func TestEngine_SSH(t *testing.T) {
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("skipping TestEngine_SSH on Windows")
|
||||
// todo resolve test execution on freebsd
|
||||
if runtime.GOOS == "windows" || runtime.GOOS == "freebsd" {
|
||||
t.Skip("skipping TestEngine_SSH")
|
||||
}
|
||||
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
@ -80,7 +81,7 @@ func TestEngine_SSH(t *testing.T) {
|
||||
WgPort: 33100,
|
||||
ServerSSHAllowed: true,
|
||||
},
|
||||
MobileDependency{}, peer.NewRecorder("https://mgm"))
|
||||
MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||
@ -176,7 +177,7 @@ func TestEngine_SSH(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
//time.Sleep(250 * time.Millisecond)
|
||||
// time.Sleep(250 * time.Millisecond)
|
||||
assert.NotNil(t, engine.sshServer)
|
||||
assert.Contains(t, sshPeersRemoved, "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=")
|
||||
|
||||
@ -215,16 +216,16 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
WgAddr: "100.64.0.1/24",
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil)
|
||||
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder, nil)
|
||||
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, nil)
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||
}
|
||||
@ -397,7 +398,7 @@ func TestEngine_Sync(t *testing.T) {
|
||||
// feed updates to Engine via mocked Management client
|
||||
updates := make(chan *mgmtProto.SyncResponse)
|
||||
defer close(updates)
|
||||
syncFunc := func(ctx context.Context, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
|
||||
syncFunc := func(ctx context.Context, info *system.Info, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
|
||||
for msg := range updates {
|
||||
err := msgHandler(msg)
|
||||
if err != nil {
|
||||
@ -412,7 +413,7 @@ func TestEngine_Sync(t *testing.T) {
|
||||
WgAddr: "100.64.0.1/24",
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||
engine.ctx = ctx
|
||||
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
@ -572,13 +573,13 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
WgAddr: wgAddr,
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||
engine.ctx = ctx
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil)
|
||||
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil, nil)
|
||||
assert.NoError(t, err, "shouldn't return error")
|
||||
input := struct {
|
||||
inputSerial uint64
|
||||
@ -743,14 +744,14 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
WgAddr: wgAddr,
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||
engine.ctx = ctx
|
||||
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, 33100, key.String(), iface.DefaultMTU, newNet, nil)
|
||||
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, 33100, key.String(), iface.DefaultMTU, newNet, nil, nil)
|
||||
assert.NoError(t, err, "shouldn't return error")
|
||||
|
||||
mockRouteManager := &routemanager.MockManager{
|
||||
@ -816,13 +817,13 @@ func TestEngine_MultiplePeers(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
|
||||
defer cancel()
|
||||
|
||||
sigServer, signalAddr, err := startSignal()
|
||||
sigServer, signalAddr, err := startSignal(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
defer sigServer.Stop()
|
||||
mgmtServer, mgmtAddr, err := startManagement(dir)
|
||||
mgmtServer, mgmtAddr, err := startManagement(t, dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@ -1015,12 +1016,14 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
||||
}
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, "", key.PublicKey().String())
|
||||
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm")), nil
|
||||
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil
|
||||
e.ctx = ctx
|
||||
return e, err
|
||||
}
|
||||
|
||||
func startSignal() (*grpc.Server, string, error) {
|
||||
func startSignal(t *testing.T) (*grpc.Server, string, error) {
|
||||
t.Helper()
|
||||
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
@ -1028,7 +1031,9 @@ func startSignal() (*grpc.Server, string, error) {
|
||||
log.Fatalf("failed to listen: %v", err)
|
||||
}
|
||||
|
||||
proto.RegisterSignalExchangeServer(s, signalServer.NewServer())
|
||||
srv, err := signalServer.NewServer(otel.Meter(""))
|
||||
require.NoError(t, err)
|
||||
proto.RegisterSignalExchangeServer(s, srv)
|
||||
|
||||
go func() {
|
||||
if err = s.Serve(lis); err != nil {
|
||||
@ -1039,7 +1044,9 @@ func startSignal() (*grpc.Server, string, error) {
|
||||
return s, lis.Addr().String(), nil
|
||||
}
|
||||
|
||||
func startManagement(dataDir string) (*grpc.Server, string, error) {
|
||||
func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error) {
|
||||
t.Helper()
|
||||
|
||||
config := &server.Config{
|
||||
Stuns: []*server.Host{},
|
||||
TURNConfig: &server.TURNConfig{},
|
||||
@ -1056,23 +1063,25 @@ func startManagement(dataDir string) (*grpc.Server, string, error) {
|
||||
return nil, "", err
|
||||
}
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
store, _, err := server.NewTestStoreFromJson(config.Datadir)
|
||||
|
||||
store, cleanUp, err := server.NewTestStoreFromJson(context.Background(), config.Datadir)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
peersUpdateManager := server.NewPeersUpdateManager(nil)
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
ia, _ := integrations.NewIntegratedValidator(eventStore)
|
||||
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
|
||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, "")
|
||||
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
@ -5,8 +5,6 @@ package networkmonitor
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
@ -14,10 +12,10 @@ import (
|
||||
"golang.org/x/net/route"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
)
|
||||
|
||||
func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interface, nexthopv6 netip.Addr, intfv6 *net.Interface, callback func()) error {
|
||||
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error {
|
||||
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open routing socket: %v", err)
|
||||
@ -47,24 +45,6 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
|
||||
msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
|
||||
|
||||
switch msg.Type {
|
||||
|
||||
// handle interface state changes
|
||||
case unix.RTM_IFINFO:
|
||||
ifinfo, err := parseInterfaceMessage(buf[:n])
|
||||
if err != nil {
|
||||
log.Errorf("Network monitor: error parsing interface message: %v", err)
|
||||
continue
|
||||
}
|
||||
if msg.Flags&unix.IFF_UP != 0 {
|
||||
continue
|
||||
}
|
||||
if (intfv4 == nil || ifinfo.Index != intfv4.Index) && (intfv6 == nil || ifinfo.Index != intfv6.Index) {
|
||||
continue
|
||||
}
|
||||
|
||||
log.Infof("Network monitor: monitored interface (%s) is down.", ifinfo.Name)
|
||||
go callback()
|
||||
|
||||
// handle route changes
|
||||
case unix.RTM_ADD, syscall.RTM_DELETE:
|
||||
route, err := parseRouteMessage(buf[:n])
|
||||
@ -86,7 +66,7 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
|
||||
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
|
||||
go callback()
|
||||
case unix.RTM_DELETE:
|
||||
if intfv4 != nil && route.Gw.Compare(nexthopv4) == 0 || intfv6 != nil && route.Gw.Compare(nexthopv6) == 0 {
|
||||
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
|
||||
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
|
||||
go callback()
|
||||
}
|
||||
@ -96,25 +76,7 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
|
||||
}
|
||||
}
|
||||
|
||||
func parseInterfaceMessage(buf []byte) (*route.InterfaceMessage, error) {
|
||||
msgs, err := route.ParseRIB(route.RIBTypeInterface, buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse RIB: %v", err)
|
||||
}
|
||||
|
||||
if len(msgs) != 1 {
|
||||
return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
|
||||
}
|
||||
|
||||
msg, ok := msgs[0].(*route.InterfaceMessage)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func parseRouteMessage(buf []byte) (*routemanager.Route, error) {
|
||||
func parseRouteMessage(buf []byte) (*systemops.Route, error) {
|
||||
msgs, err := route.ParseRIB(route.RIBTypeRoute, buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse RIB: %v", err)
|
||||
@ -129,5 +91,5 @@ func parseRouteMessage(buf []byte) (*routemanager.Route, error) {
|
||||
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
|
||||
}
|
||||
|
||||
return routemanager.MsgToRoute(msg)
|
||||
return systemops.MsgToRoute(msg)
|
||||
}
|
||||
|
@ -6,14 +6,13 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime/debug"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
)
|
||||
|
||||
// Start begins monitoring network changes. When a change is detected, it calls the callback asynchronously and returns.
|
||||
@ -29,23 +28,22 @@ func (nw *NetworkMonitor) Start(ctx context.Context, callback func()) (err error
|
||||
nw.wg.Add(1)
|
||||
defer nw.wg.Done()
|
||||
|
||||
var nexthop4, nexthop6 netip.Addr
|
||||
var intf4, intf6 *net.Interface
|
||||
var nexthop4, nexthop6 systemops.Nexthop
|
||||
|
||||
operation := func() error {
|
||||
var errv4, errv6 error
|
||||
nexthop4, intf4, errv4 = routemanager.GetNextHop(netip.IPv4Unspecified())
|
||||
nexthop6, intf6, errv6 = routemanager.GetNextHop(netip.IPv6Unspecified())
|
||||
nexthop4, errv4 = systemops.GetNextHop(netip.IPv4Unspecified())
|
||||
nexthop6, errv6 = systemops.GetNextHop(netip.IPv6Unspecified())
|
||||
|
||||
if errv4 != nil && errv6 != nil {
|
||||
return errors.New("failed to get default next hops")
|
||||
}
|
||||
|
||||
if errv4 == nil {
|
||||
log.Debugf("Network monitor: IPv4 default route: %s, interface: %s", nexthop4, intf4.Name)
|
||||
log.Debugf("Network monitor: IPv4 default route: %s, interface: %s", nexthop4.IP, nexthop4.Intf.Name)
|
||||
}
|
||||
if errv6 == nil {
|
||||
log.Debugf("Network monitor: IPv6 default route: %s, interface: %s", nexthop6, intf6.Name)
|
||||
log.Debugf("Network monitor: IPv6 default route: %s, interface: %s", nexthop6.IP, nexthop6.Intf.Name)
|
||||
}
|
||||
|
||||
// continue if either route was found
|
||||
@ -65,7 +63,7 @@ func (nw *NetworkMonitor) Start(ctx context.Context, callback func()) (err error
|
||||
}
|
||||
}()
|
||||
|
||||
if err := checkChange(ctx, nexthop4, intf4, nexthop6, intf6, callback); err != nil {
|
||||
if err := checkChange(ctx, nexthop4, nexthop6, callback); err != nil {
|
||||
return fmt.Errorf("check change: %w", err)
|
||||
}
|
||||
|
||||
|
@ -6,27 +6,22 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/vishvananda/netlink"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
)
|
||||
|
||||
func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interface, nexthop6 netip.Addr, intfv6 *net.Interface, callback func()) error {
|
||||
if intfv4 == nil && intfv6 == nil {
|
||||
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error {
|
||||
if nexthopv4.Intf == nil && nexthopv6.Intf == nil {
|
||||
return errors.New("no interfaces available")
|
||||
}
|
||||
|
||||
linkChan := make(chan netlink.LinkUpdate)
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
|
||||
if err := netlink.LinkSubscribe(linkChan, done); err != nil {
|
||||
return fmt.Errorf("subscribe to link updates: %v", err)
|
||||
}
|
||||
|
||||
routeChan := make(chan netlink.RouteUpdate)
|
||||
if err := netlink.RouteSubscribe(routeChan, done); err != nil {
|
||||
return fmt.Errorf("subscribe to route updates: %v", err)
|
||||
@ -38,25 +33,6 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
|
||||
case <-ctx.Done():
|
||||
return ErrStopped
|
||||
|
||||
// handle interface state changes
|
||||
case update := <-linkChan:
|
||||
if (intfv4 == nil || update.Index != int32(intfv4.Index)) && (intfv6 == nil || update.Index != int32(intfv6.Index)) {
|
||||
continue
|
||||
}
|
||||
|
||||
switch update.Header.Type {
|
||||
case syscall.RTM_DELLINK:
|
||||
log.Infof("Network monitor: monitored interface (%s) is gone", update.Link.Attrs().Name)
|
||||
go callback()
|
||||
return nil
|
||||
case syscall.RTM_NEWLINK:
|
||||
if (update.IfInfomsg.Flags&syscall.IFF_RUNNING) == 0 && update.Link.Attrs().OperState == netlink.OperDown {
|
||||
log.Infof("Network monitor: monitored interface (%s) is down.", update.Link.Attrs().Name)
|
||||
go callback()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// handle route changes
|
||||
case route := <-routeChan:
|
||||
// default route and main table
|
||||
@ -70,7 +46,7 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
|
||||
go callback()
|
||||
return nil
|
||||
case syscall.RTM_DELROUTE:
|
||||
if intfv4 != nil && route.Gw.Equal(nexthopv4.AsSlice()) || intfv6 != nil && route.Gw.Equal(nexthop6.AsSlice()) {
|
||||
if nexthopv4.Intf != nil && route.Gw.Equal(nexthopv4.IP.AsSlice()) || nexthopv6.Intf != nil && route.Gw.Equal(nexthopv6.IP.AsSlice()) {
|
||||
log.Infof("Network monitor: default route removed: via %s, interface %d", route.Gw, route.LinkIndex)
|
||||
go callback()
|
||||
return nil
|
||||
|
@ -5,11 +5,12 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -25,20 +26,16 @@ const (
|
||||
|
||||
const interval = 10 * time.Second
|
||||
|
||||
func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interface, nexthopv6 netip.Addr, intfv6 *net.Interface, callback func()) error {
|
||||
var neighborv4, neighborv6 *routemanager.Neighbor
|
||||
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error {
|
||||
var neighborv4, neighborv6 *systemops.Neighbor
|
||||
{
|
||||
initialNeighbors, err := getNeighbors()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get neighbors: %w", err)
|
||||
}
|
||||
|
||||
if n, ok := initialNeighbors[nexthopv4]; ok {
|
||||
neighborv4 = &n
|
||||
}
|
||||
if n, ok := initialNeighbors[nexthopv6]; ok {
|
||||
neighborv6 = &n
|
||||
}
|
||||
neighborv4 = assignNeighbor(nexthopv4, initialNeighbors)
|
||||
neighborv6 = assignNeighbor(nexthopv6, initialNeighbors)
|
||||
}
|
||||
log.Debugf("Network monitor: initial IPv4 neighbor: %v, IPv6 neighbor: %v", neighborv4, neighborv6)
|
||||
|
||||
@ -50,7 +47,7 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
|
||||
case <-ctx.Done():
|
||||
return ErrStopped
|
||||
case <-ticker.C:
|
||||
if changed(nexthopv4, intfv4, neighborv4, nexthopv6, intfv6, neighborv6) {
|
||||
if changed(nexthopv4, neighborv4, nexthopv6, neighborv6) {
|
||||
go callback()
|
||||
return nil
|
||||
}
|
||||
@ -58,13 +55,21 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
|
||||
}
|
||||
}
|
||||
|
||||
func assignNeighbor(nexthop systemops.Nexthop, initialNeighbors map[netip.Addr]systemops.Neighbor) *systemops.Neighbor {
|
||||
if n, ok := initialNeighbors[nexthop.IP]; ok &&
|
||||
n.State != unreachable &&
|
||||
n.State != incomplete &&
|
||||
n.State != tbd {
|
||||
return &n
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func changed(
|
||||
nexthopv4 netip.Addr,
|
||||
intfv4 *net.Interface,
|
||||
neighborv4 *routemanager.Neighbor,
|
||||
nexthopv6 netip.Addr,
|
||||
intfv6 *net.Interface,
|
||||
neighborv6 *routemanager.Neighbor,
|
||||
nexthopv4 systemops.Nexthop,
|
||||
neighborv4 *systemops.Neighbor,
|
||||
nexthopv6 systemops.Nexthop,
|
||||
neighborv6 *systemops.Neighbor,
|
||||
) bool {
|
||||
neighbors, err := getNeighbors()
|
||||
if err != nil {
|
||||
@ -81,7 +86,7 @@ func changed(
|
||||
return false
|
||||
}
|
||||
|
||||
if routeChanged(nexthopv4, intfv4, routes) || routeChanged(nexthopv6, intfv6, routes) {
|
||||
if routeChanged(nexthopv4, nexthopv4.Intf, routes) || routeChanged(nexthopv6, nexthopv6.Intf, routes) {
|
||||
return true
|
||||
}
|
||||
|
||||
@ -89,44 +94,74 @@ func changed(
|
||||
}
|
||||
|
||||
// routeChanged checks if the default routes still point to our nexthop/interface
|
||||
func routeChanged(nexthop netip.Addr, intf *net.Interface, routes map[netip.Prefix]routemanager.Route) bool {
|
||||
if !nexthop.IsValid() {
|
||||
func routeChanged(nexthop systemops.Nexthop, intf *net.Interface, routes []systemops.Route) bool {
|
||||
if !nexthop.IP.IsValid() {
|
||||
return false
|
||||
}
|
||||
|
||||
var unspec netip.Prefix
|
||||
if nexthop.Is6() {
|
||||
unspec = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
|
||||
} else {
|
||||
unspec = netip.PrefixFrom(netip.IPv4Unspecified(), 0)
|
||||
}
|
||||
unspec := getUnspecifiedPrefix(nexthop.IP)
|
||||
defaultRoutes, foundMatchingRoute := processRoutes(nexthop, intf, routes, unspec)
|
||||
|
||||
if r, ok := routes[unspec]; ok {
|
||||
if r.Nexthop != nexthop || compareIntf(r.Interface, intf) != 0 {
|
||||
intf := "<nil>"
|
||||
if r.Interface != nil {
|
||||
intf = r.Interface.Name
|
||||
}
|
||||
log.Infof("network monitor: default route changed: %s via %s (%s)", r.Destination, r.Nexthop, intf)
|
||||
return true
|
||||
}
|
||||
} else {
|
||||
log.Infof("network monitor: default route is gone")
|
||||
log.Tracef("network monitor: all default routes:\n%s", strings.Join(defaultRoutes, "\n"))
|
||||
|
||||
if !foundMatchingRoute {
|
||||
logRouteChange(nexthop.IP, intf)
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
|
||||
}
|
||||
|
||||
func neighborChanged(nexthop netip.Addr, neighbor *routemanager.Neighbor, neighbors map[netip.Addr]routemanager.Neighbor) bool {
|
||||
func getUnspecifiedPrefix(ip netip.Addr) netip.Prefix {
|
||||
if ip.Is6() {
|
||||
return netip.PrefixFrom(netip.IPv6Unspecified(), 0)
|
||||
}
|
||||
return netip.PrefixFrom(netip.IPv4Unspecified(), 0)
|
||||
}
|
||||
|
||||
func processRoutes(nexthop systemops.Nexthop, intf *net.Interface, routes []systemops.Route, unspec netip.Prefix) ([]string, bool) {
|
||||
var defaultRoutes []string
|
||||
foundMatchingRoute := false
|
||||
|
||||
for _, r := range routes {
|
||||
if r.Destination == unspec {
|
||||
routeInfo := formatRouteInfo(r)
|
||||
defaultRoutes = append(defaultRoutes, routeInfo)
|
||||
|
||||
if r.Nexthop == nexthop.IP && compareIntf(r.Interface, intf) == 0 {
|
||||
foundMatchingRoute = true
|
||||
log.Debugf("network monitor: found matching default route: %s", routeInfo)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return defaultRoutes, foundMatchingRoute
|
||||
}
|
||||
|
||||
func formatRouteInfo(r systemops.Route) string {
|
||||
newIntf := "<nil>"
|
||||
if r.Interface != nil {
|
||||
newIntf = r.Interface.Name
|
||||
}
|
||||
return fmt.Sprintf("Nexthop: %s, Interface: %s", r.Nexthop, newIntf)
|
||||
}
|
||||
|
||||
func logRouteChange(ip netip.Addr, intf *net.Interface) {
|
||||
oldIntf := "<nil>"
|
||||
if intf != nil {
|
||||
oldIntf = intf.Name
|
||||
}
|
||||
log.Infof("network monitor: default route for %s (%s) is gone or changed", ip, oldIntf)
|
||||
}
|
||||
|
||||
func neighborChanged(nexthop systemops.Nexthop, neighbor *systemops.Neighbor, neighbors map[netip.Addr]systemops.Neighbor) bool {
|
||||
if neighbor == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// TODO: consider non-local nexthops, e.g. on point-to-point interfaces
|
||||
if n, ok := neighbors[nexthop]; ok {
|
||||
if n.State != reachable && n.State != permanent {
|
||||
if n, ok := neighbors[nexthop.IP]; ok {
|
||||
if n.State == unreachable || n.State == incomplete {
|
||||
log.Infof("network monitor: neighbor %s (%s) is not reachable: %s", neighbor.IPAddress, neighbor.LinkLayerAddress, stateFromInt(n.State))
|
||||
return true
|
||||
} else if n.InterfaceIndex != neighbor.InterfaceIndex {
|
||||
@ -150,13 +185,13 @@ func neighborChanged(nexthop netip.Addr, neighbor *routemanager.Neighbor, neighb
|
||||
return false
|
||||
}
|
||||
|
||||
func getNeighbors() (map[netip.Addr]routemanager.Neighbor, error) {
|
||||
entries, err := routemanager.GetNeighbors()
|
||||
func getNeighbors() (map[netip.Addr]systemops.Neighbor, error) {
|
||||
entries, err := systemops.GetNeighbors()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get neighbors: %w", err)
|
||||
}
|
||||
|
||||
neighbours := make(map[netip.Addr]routemanager.Neighbor, len(entries))
|
||||
neighbours := make(map[netip.Addr]systemops.Neighbor, len(entries))
|
||||
for _, entry := range entries {
|
||||
neighbours[entry.IPAddress] = entry
|
||||
}
|
||||
@ -164,18 +199,13 @@ func getNeighbors() (map[netip.Addr]routemanager.Neighbor, error) {
|
||||
return neighbours, nil
|
||||
}
|
||||
|
||||
func getRoutes() (map[netip.Prefix]routemanager.Route, error) {
|
||||
entries, err := routemanager.GetRoutes()
|
||||
func getRoutes() ([]systemops.Route, error) {
|
||||
entries, err := systemops.GetRoutes()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get routes: %w", err)
|
||||
}
|
||||
|
||||
routes := make(map[netip.Prefix]routemanager.Route, len(entries))
|
||||
for _, entry := range entries {
|
||||
routes[entry.Destination] = entry
|
||||
}
|
||||
|
||||
return routes, nil
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
func stateFromInt(state uint8) string {
|
||||
|
@ -62,9 +62,6 @@ type ConnConfig struct {
|
||||
ICEConfig ICEConfig
|
||||
}
|
||||
|
||||
type BeforeAddPeerHookFunc func(connID nbnet.ConnectionID, IP net.IP) error
|
||||
type AfterRemovePeerHookFunc func(connID nbnet.ConnectionID) error
|
||||
|
||||
type WorkerCallbacks struct {
|
||||
OnRelayReadyCallback func(info RelayConnInfo)
|
||||
OnRelayStatusChanged func(ConnStatus)
|
||||
@ -99,8 +96,8 @@ type Conn struct {
|
||||
workerRelay *WorkerRelay
|
||||
|
||||
connID nbnet.ConnectionID
|
||||
beforeAddPeerHooks []BeforeAddPeerHookFunc
|
||||
afterRemovePeerHooks []AfterRemovePeerHookFunc
|
||||
beforeAddPeerHooks []nbnet.AddHookFunc
|
||||
afterRemovePeerHooks []nbnet.RemoveHookFunc
|
||||
|
||||
endpointRelay *net.UDPAddr
|
||||
|
||||
@ -266,11 +263,10 @@ func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMa
|
||||
conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
|
||||
}
|
||||
|
||||
func (conn *Conn) AddBeforeAddPeerHook(hook BeforeAddPeerHookFunc) {
|
||||
func (conn *Conn) AddBeforeAddPeerHook(hook nbnet.AddHookFunc) {
|
||||
conn.beforeAddPeerHooks = append(conn.beforeAddPeerHooks, hook)
|
||||
}
|
||||
|
||||
func (conn *Conn) AddAfterRemovePeerHook(hook AfterRemovePeerHookFunc) {
|
||||
func (conn *Conn) AddAfterRemovePeerHook(hook nbnet.RemoveHookFunc) {
|
||||
conn.afterRemovePeerHooks = append(conn.afterRemovePeerHooks, hook)
|
||||
}
|
||||
|
||||
|
@ -46,7 +46,7 @@ func TestNewConn_interfaceFilter(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConn_GetKey(t *testing.T) {
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort)
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
|
||||
defer func() {
|
||||
_ = wgProxyFactory.Free()
|
||||
}()
|
||||
@ -61,7 +61,7 @@ func TestConn_GetKey(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConn_OnRemoteOffer(t *testing.T) {
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort)
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
|
||||
defer func() {
|
||||
_ = wgProxyFactory.Free()
|
||||
}()
|
||||
@ -98,7 +98,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort)
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
|
||||
defer func() {
|
||||
_ = wgProxyFactory.Free()
|
||||
}()
|
||||
@ -134,7 +134,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
func TestConn_Status(t *testing.T) {
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), connConf.LocalWgPort)
|
||||
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
|
||||
defer func() {
|
||||
_ = wgProxyFactory.Free()
|
||||
}()
|
||||
@ -172,7 +172,7 @@ func TestConn_Status(t *testing.T) {
|
||||
func TestConn_Switch(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
wgProxyFactory := wgproxy.NewFactory(ctx, connConf.LocalWgPort)
|
||||
wgProxyFactory := wgproxy.NewFactory(ctx, false, connConf.LocalWgPort)
|
||||
connConfAlice := ConnConfig{
|
||||
Key: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||
LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||
|
@ -2,14 +2,17 @@ package peer
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/exp/maps"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/relay"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
)
|
||||
|
||||
// State contains the latest state of a peer
|
||||
@ -37,25 +40,25 @@ type State struct {
|
||||
// AddRoute add a single route to routes map
|
||||
func (s *State) AddRoute(network string) {
|
||||
s.Mux.Lock()
|
||||
defer s.Mux.Unlock()
|
||||
if s.routes == nil {
|
||||
s.routes = make(map[string]struct{})
|
||||
}
|
||||
s.routes[network] = struct{}{}
|
||||
s.Mux.Unlock()
|
||||
}
|
||||
|
||||
// SetRoutes set state routes
|
||||
func (s *State) SetRoutes(routes map[string]struct{}) {
|
||||
s.Mux.Lock()
|
||||
defer s.Mux.Unlock()
|
||||
s.routes = routes
|
||||
s.Mux.Unlock()
|
||||
}
|
||||
|
||||
// DeleteRoute removes a route from the network amp
|
||||
func (s *State) DeleteRoute(network string) {
|
||||
s.Mux.Lock()
|
||||
defer s.Mux.Unlock()
|
||||
delete(s.routes, network)
|
||||
s.Mux.Unlock()
|
||||
}
|
||||
|
||||
// GetRoutes return routes map
|
||||
@ -117,22 +120,23 @@ type FullStatus struct {
|
||||
|
||||
// Status holds a state of peers, signal, management connections and relays
|
||||
type Status struct {
|
||||
mux sync.Mutex
|
||||
peers map[string]State
|
||||
changeNotify map[string]chan struct{}
|
||||
signalState bool
|
||||
signalError error
|
||||
managementState bool
|
||||
managementError error
|
||||
relayStates []relay.ProbeResult
|
||||
localPeer LocalPeerState
|
||||
offlinePeers []State
|
||||
mgmAddress string
|
||||
signalAddress string
|
||||
notifier *notifier
|
||||
rosenpassEnabled bool
|
||||
rosenpassPermissive bool
|
||||
nsGroupStates []NSGroupState
|
||||
mux sync.Mutex
|
||||
peers map[string]State
|
||||
changeNotify map[string]chan struct{}
|
||||
signalState bool
|
||||
signalError error
|
||||
managementState bool
|
||||
managementError error
|
||||
relayStates []relay.ProbeResult
|
||||
localPeer LocalPeerState
|
||||
offlinePeers []State
|
||||
mgmAddress string
|
||||
signalAddress string
|
||||
notifier *notifier
|
||||
rosenpassEnabled bool
|
||||
rosenpassPermissive bool
|
||||
nsGroupStates []NSGroupState
|
||||
resolvedDomainsStates map[domain.Domain][]netip.Prefix
|
||||
|
||||
// To reduce the number of notification invocation this bool will be true when need to call the notification
|
||||
// Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events
|
||||
@ -143,11 +147,12 @@ type Status struct {
|
||||
// NewRecorder returns a new Status instance
|
||||
func NewRecorder(mgmAddress string) *Status {
|
||||
return &Status{
|
||||
peers: make(map[string]State),
|
||||
changeNotify: make(map[string]chan struct{}),
|
||||
offlinePeers: make([]State, 0),
|
||||
notifier: newNotifier(),
|
||||
mgmAddress: mgmAddress,
|
||||
peers: make(map[string]State),
|
||||
changeNotify: make(map[string]chan struct{}),
|
||||
offlinePeers: make([]State, 0),
|
||||
notifier: newNotifier(),
|
||||
mgmAddress: mgmAddress,
|
||||
resolvedDomainsStates: make(map[domain.Domain][]netip.Prefix),
|
||||
}
|
||||
}
|
||||
|
||||
@ -188,7 +193,7 @@ func (d *Status) GetPeer(peerPubKey string) (State, error) {
|
||||
|
||||
state, ok := d.peers[peerPubKey]
|
||||
if !ok {
|
||||
return State{}, errors.New("peer not found")
|
||||
return State{}, iface.ErrPeerNotFound
|
||||
}
|
||||
return state, nil
|
||||
}
|
||||
@ -429,6 +434,18 @@ func (d *Status) UpdateDNSStates(dnsStates []NSGroupState) {
|
||||
d.nsGroupStates = dnsStates
|
||||
}
|
||||
|
||||
func (d *Status) UpdateResolvedDomainsStates(domain domain.Domain, prefixes []netip.Prefix) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.resolvedDomainsStates[domain] = prefixes
|
||||
}
|
||||
|
||||
func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
delete(d.resolvedDomainsStates, domain)
|
||||
}
|
||||
|
||||
func (d *Status) GetRosenpassState() RosenpassState {
|
||||
return RosenpassState{
|
||||
d.rosenpassEnabled,
|
||||
@ -493,6 +510,12 @@ func (d *Status) GetDNSStates() []NSGroupState {
|
||||
return d.nsGroupStates
|
||||
}
|
||||
|
||||
func (d *Status) GetResolvedDomainsStates() map[domain.Domain][]netip.Prefix {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
return maps.Clone(d.resolvedDomainsStates)
|
||||
}
|
||||
|
||||
// GetFullStatus gets full status
|
||||
func (d *Status) GetFullStatus() FullStatus {
|
||||
d.mux.Lock()
|
||||
|
@ -3,19 +3,20 @@ package routemanager
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/static"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
const minRangeBits = 7
|
||||
|
||||
type routerPeerStatus struct {
|
||||
connected bool
|
||||
relayed bool
|
||||
@ -28,33 +29,42 @@ type routesUpdate struct {
|
||||
routes []*route.Route
|
||||
}
|
||||
|
||||
// RouteHandler defines the interface for handling routes
|
||||
type RouteHandler interface {
|
||||
String() string
|
||||
AddRoute(ctx context.Context) error
|
||||
RemoveRoute() error
|
||||
AddAllowedIPs(peerKey string) error
|
||||
RemoveAllowedIPs() error
|
||||
}
|
||||
|
||||
type clientNetwork struct {
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
cancel context.CancelFunc
|
||||
statusRecorder *peer.Status
|
||||
wgInterface *iface.WGIface
|
||||
routes map[route.ID]*route.Route
|
||||
routeUpdate chan routesUpdate
|
||||
peerStateUpdate chan struct{}
|
||||
routePeersNotifiers map[string]chan struct{}
|
||||
chosenRoute *route.Route
|
||||
network netip.Prefix
|
||||
currentChosen *route.Route
|
||||
handler RouteHandler
|
||||
updateSerial uint64
|
||||
}
|
||||
|
||||
func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, statusRecorder *peer.Status, network netip.Prefix) *clientNetwork {
|
||||
func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration, wgInterface *iface.WGIface, statusRecorder *peer.Status, rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *clientNetwork {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
client := &clientNetwork{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
cancel: cancel,
|
||||
statusRecorder: statusRecorder,
|
||||
wgInterface: wgInterface,
|
||||
routes: make(map[route.ID]*route.Route),
|
||||
routePeersNotifiers: make(map[string]chan struct{}),
|
||||
routeUpdate: make(chan routesUpdate),
|
||||
peerStateUpdate: make(chan struct{}),
|
||||
network: network,
|
||||
handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder),
|
||||
}
|
||||
return client
|
||||
}
|
||||
@ -86,8 +96,8 @@ func (c *clientNetwork) getRouterPeerStatuses() map[route.ID]routerPeerStatus {
|
||||
// * Metric: Routes with lower metrics (better) are prioritized.
|
||||
// * Non-relayed: Routes without relays are preferred.
|
||||
// * Direct connections: Routes with direct peer connections are favored.
|
||||
// * Stability: In case of equal scores, the currently active route (if any) is maintained.
|
||||
// * Latency: Routes with lower latency are prioritized.
|
||||
// * Stability: In case of equal scores, the currently active route (if any) is maintained.
|
||||
//
|
||||
// It returns the ID of the selected optimal route.
|
||||
func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]routerPeerStatus) route.ID {
|
||||
@ -96,8 +106,8 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
|
||||
currScore := float64(0)
|
||||
|
||||
currID := route.ID("")
|
||||
if c.chosenRoute != nil {
|
||||
currID = c.chosenRoute.ID
|
||||
if c.currentChosen != nil {
|
||||
currID = c.currentChosen.ID
|
||||
}
|
||||
|
||||
for _, r := range c.routes {
|
||||
@ -151,18 +161,18 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
|
||||
peers = append(peers, r.Peer)
|
||||
}
|
||||
|
||||
log.Warnf("the network %s has not been assigned a routing peer as no peers from the list %s are currently connected", c.network, peers)
|
||||
log.Warnf("The network [%v] has not been assigned a routing peer as no peers from the list %s are currently connected", c.handler, peers)
|
||||
case chosen != currID:
|
||||
// we compare the current score + 10ms to the chosen score to avoid flapping between routes
|
||||
if currScore != 0 && currScore+0.01 > chosenScore {
|
||||
log.Debugf("keeping current routing peer because the score difference with latency is less than 0.01(10ms), current: %f, new: %f", currScore, chosenScore)
|
||||
log.Debugf("Keeping current routing peer because the score difference with latency is less than 0.01(10ms), current: %f, new: %f", currScore, chosenScore)
|
||||
return currID
|
||||
}
|
||||
var p string
|
||||
if rt := c.routes[chosen]; rt != nil {
|
||||
p = rt.Peer
|
||||
}
|
||||
log.Infof("new chosen route is %s with peer %s with score %f for network %s", chosen, p, chosenScore, c.network)
|
||||
log.Infof("New chosen route is %s with peer %s with score %f for network [%v]", chosen, p, chosenScore, c.handler)
|
||||
}
|
||||
|
||||
return chosen
|
||||
@ -196,98 +206,103 @@ func (c *clientNetwork) startPeersStatusChangeWatcher() {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
|
||||
state, err := c.statusRecorder.GetPeer(peerKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get peer state: %v", err)
|
||||
}
|
||||
func (c *clientNetwork) removeRouteFromWireguardPeer() error {
|
||||
c.removeStateRoute()
|
||||
|
||||
state.DeleteRoute(c.network.String())
|
||||
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
|
||||
log.Warnf("Failed to update peer state: %v", err)
|
||||
}
|
||||
|
||||
if state.ConnStatus != peer.StatusConnected {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = c.wgInterface.RemoveAllowedIP(peerKey, c.network.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("remove allowed IP %s removed for peer %s, err: %v",
|
||||
c.network, c.chosenRoute.Peer, err)
|
||||
if err := c.handler.RemoveAllowedIPs(); err != nil {
|
||||
return fmt.Errorf("remove allowed IPs: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
|
||||
if c.chosenRoute != nil {
|
||||
if err := removeVPNRoute(c.network, c.getAsInterface()); err != nil {
|
||||
return fmt.Errorf("remove route %s from system, err: %v", c.network, err)
|
||||
}
|
||||
|
||||
if err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer); err != nil {
|
||||
return fmt.Errorf("remove route: %v", err)
|
||||
}
|
||||
if c.currentChosen == nil {
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := c.removeRouteFromWireguardPeer(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err))
|
||||
}
|
||||
if err := c.handler.RemoveRoute(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove route: %w", err))
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
|
||||
routerPeerStatuses := c.getRouterPeerStatuses()
|
||||
|
||||
chosen := c.getBestRouteFromStatuses(routerPeerStatuses)
|
||||
newChosenID := c.getBestRouteFromStatuses(routerPeerStatuses)
|
||||
|
||||
// If no route is chosen, remove the route from the peer and system
|
||||
if chosen == "" {
|
||||
if newChosenID == "" {
|
||||
if err := c.removeRouteFromPeerAndSystem(); err != nil {
|
||||
return fmt.Errorf("remove route from peer and system: %v", err)
|
||||
return fmt.Errorf("remove route for peer %s: %w", c.currentChosen.Peer, err)
|
||||
}
|
||||
|
||||
c.chosenRoute = nil
|
||||
c.currentChosen = nil
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// If the chosen route is the same as the current route, do nothing
|
||||
if c.chosenRoute != nil && c.chosenRoute.ID == chosen {
|
||||
if c.chosenRoute.IsEqual(c.routes[chosen]) {
|
||||
return nil
|
||||
}
|
||||
if c.currentChosen != nil && c.currentChosen.ID == newChosenID &&
|
||||
c.currentChosen.IsEqual(c.routes[newChosenID]) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if c.chosenRoute != nil {
|
||||
// If a previous route exists, remove it from the peer
|
||||
if err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer); err != nil {
|
||||
return fmt.Errorf("remove route from peer: %v", err)
|
||||
if c.currentChosen == nil {
|
||||
// If they were not previously assigned to another peer, add routes to the system first
|
||||
if err := c.handler.AddRoute(c.ctx); err != nil {
|
||||
return fmt.Errorf("add route: %w", err)
|
||||
}
|
||||
} else {
|
||||
// otherwise add the route to the system
|
||||
if err := addVPNRoute(c.network, c.getAsInterface()); err != nil {
|
||||
return fmt.Errorf("route %s couldn't be added for peer %s, err: %v",
|
||||
c.network.String(), c.wgInterface.Address().IP.String(), err)
|
||||
// Otherwise, remove the allowed IPs from the previous peer first
|
||||
if err := c.removeRouteFromWireguardPeer(); err != nil {
|
||||
return fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
|
||||
}
|
||||
}
|
||||
|
||||
c.chosenRoute = c.routes[chosen]
|
||||
c.currentChosen = c.routes[newChosenID]
|
||||
|
||||
state, err := c.statusRecorder.GetPeer(c.chosenRoute.Peer)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get peer state: %v", err)
|
||||
} else {
|
||||
state.AddRoute(c.network.String())
|
||||
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
|
||||
log.Warnf("Failed to update peer state: %v", err)
|
||||
}
|
||||
if err := c.handler.AddAllowedIPs(c.currentChosen.Peer); err != nil {
|
||||
return fmt.Errorf("add allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
|
||||
}
|
||||
|
||||
if err := c.wgInterface.AddAllowedIP(c.chosenRoute.Peer, c.network.String()); err != nil {
|
||||
log.Errorf("couldn't add allowed IP %s added for peer %s, err: %v",
|
||||
c.network, c.chosenRoute.Peer, err)
|
||||
}
|
||||
c.addStateRoute()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientNetwork) addStateRoute() {
|
||||
state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get peer state: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
state.AddRoute(c.handler.String())
|
||||
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
|
||||
log.Warnf("Failed to update peer state: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientNetwork) removeStateRoute() {
|
||||
state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get peer state: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
state.DeleteRoute(c.handler.String())
|
||||
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
|
||||
log.Warnf("Failed to update peer state: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) {
|
||||
go func() {
|
||||
c.routeUpdate <- update
|
||||
@ -318,24 +333,23 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
log.Debugf("stopping watcher for network %s", c.network)
|
||||
err := c.removeRouteFromPeerAndSystem()
|
||||
if err != nil {
|
||||
log.Errorf("Couldn't remove route from peer and system for network %s: %v", c.network, err)
|
||||
log.Debugf("Stopping watcher for network [%v]", c.handler)
|
||||
if err := c.removeRouteFromPeerAndSystem(); err != nil {
|
||||
log.Errorf("Failed to remove routes for [%v]: %v", c.handler, err)
|
||||
}
|
||||
return
|
||||
case <-c.peerStateUpdate:
|
||||
err := c.recalculateRouteAndUpdatePeerAndSystem()
|
||||
if err != nil {
|
||||
log.Errorf("Couldn't recalculate route and update peer and system: %v", err)
|
||||
log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err)
|
||||
}
|
||||
case update := <-c.routeUpdate:
|
||||
if update.updateSerial < c.updateSerial {
|
||||
log.Warnf("Received a routes update with smaller serial number, ignoring it")
|
||||
log.Warnf("Received a routes update with smaller serial number (%d -> %d), ignoring it", c.updateSerial, update.updateSerial)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debugf("Received a new client network route update for %s", c.network)
|
||||
log.Debugf("Received a new client network route update for [%v]", c.handler)
|
||||
|
||||
c.handleUpdate(update)
|
||||
|
||||
@ -343,7 +357,7 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
|
||||
|
||||
err := c.recalculateRouteAndUpdatePeerAndSystem()
|
||||
if err != nil {
|
||||
log.Errorf("Couldn't recalculate route and update peer and system for network %s: %v", c.network, err)
|
||||
log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err)
|
||||
}
|
||||
|
||||
c.startPeersStatusChangeWatcher()
|
||||
@ -351,14 +365,9 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientNetwork) getAsInterface() *net.Interface {
|
||||
intf, err := net.InterfaceByName(c.wgInterface.Name())
|
||||
if err != nil {
|
||||
log.Warnf("Couldn't get interface by name %s: %v", c.wgInterface.Name(), err)
|
||||
intf = &net.Interface{
|
||||
Name: c.wgInterface.Name(),
|
||||
}
|
||||
func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status) RouteHandler {
|
||||
if rt.IsDynamic() {
|
||||
return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder)
|
||||
}
|
||||
|
||||
return intf
|
||||
return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter)
|
||||
}
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/static"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
@ -340,9 +341,9 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
||||
|
||||
// create new clientNetwork
|
||||
client := &clientNetwork{
|
||||
network: netip.MustParsePrefix("192.168.0.0/24"),
|
||||
routes: tc.existingRoutes,
|
||||
chosenRoute: currentRoute,
|
||||
handler: static.NewRoute(&route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")}, nil, nil),
|
||||
routes: tc.existingRoutes,
|
||||
currentChosen: currentRoute,
|
||||
}
|
||||
|
||||
chosenRoute := client.getBestRouteFromStatuses(tc.statuses)
|
||||
|
378
client/internal/routemanager/dynamic/route.go
Normal file
378
client/internal/routemanager/dynamic/route.go
Normal file
@ -0,0 +1,378 @@
|
||||
package dynamic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultInterval = time.Minute
|
||||
|
||||
minInterval = 2 * time.Second
|
||||
failureInterval = 5 * time.Second
|
||||
|
||||
addAllowedIP = "add allowed IP %s: %w"
|
||||
)
|
||||
|
||||
type domainMap map[domain.Domain][]netip.Prefix
|
||||
|
||||
type resolveResult struct {
|
||||
domain domain.Domain
|
||||
prefix netip.Prefix
|
||||
err error
|
||||
}
|
||||
|
||||
type Route struct {
|
||||
route *route.Route
|
||||
routeRefCounter *refcounter.RouteRefCounter
|
||||
allowedIPsRefcounter *refcounter.AllowedIPsRefCounter
|
||||
interval time.Duration
|
||||
dynamicDomains domainMap
|
||||
mu sync.Mutex
|
||||
currentPeerKey string
|
||||
cancel context.CancelFunc
|
||||
statusRecorder *peer.Status
|
||||
}
|
||||
|
||||
func NewRoute(
|
||||
rt *route.Route,
|
||||
routeRefCounter *refcounter.RouteRefCounter,
|
||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
||||
interval time.Duration,
|
||||
statusRecorder *peer.Status,
|
||||
) *Route {
|
||||
return &Route{
|
||||
route: rt,
|
||||
routeRefCounter: routeRefCounter,
|
||||
allowedIPsRefcounter: allowedIPsRefCounter,
|
||||
interval: interval,
|
||||
dynamicDomains: domainMap{},
|
||||
statusRecorder: statusRecorder,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Route) String() string {
|
||||
s, err := r.route.Domains.String()
|
||||
if err != nil {
|
||||
return r.route.Domains.PunycodeString()
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (r *Route) AddRoute(ctx context.Context) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if r.cancel != nil {
|
||||
r.cancel()
|
||||
}
|
||||
|
||||
ctx, r.cancel = context.WithCancel(ctx)
|
||||
|
||||
go r.startResolver(ctx)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveRoute will stop the dynamic resolver and remove all dynamic routes.
|
||||
// It doesn't touch allowed IPs, these should be removed separately and before calling this method.
|
||||
func (r *Route) RemoveRoute() error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if r.cancel != nil {
|
||||
r.cancel()
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
for domain, prefixes := range r.dynamicDomains {
|
||||
for _, prefix := range prefixes {
|
||||
if _, err := r.routeRefCounter.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %w", prefix, err))
|
||||
}
|
||||
}
|
||||
log.Debugf("Removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", "))
|
||||
|
||||
r.statusRecorder.DeleteResolvedDomainsStates(domain)
|
||||
}
|
||||
|
||||
r.dynamicDomains = domainMap{}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *Route) AddAllowedIPs(peerKey string) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
var merr *multierror.Error
|
||||
for domain, domainPrefixes := range r.dynamicDomains {
|
||||
for _, prefix := range domainPrefixes {
|
||||
if err := r.incrementAllowedIP(domain, prefix, peerKey); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf(addAllowedIP, prefix, err))
|
||||
}
|
||||
}
|
||||
}
|
||||
r.currentPeerKey = peerKey
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *Route) RemoveAllowedIPs() error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
var merr *multierror.Error
|
||||
for _, domainPrefixes := range r.dynamicDomains {
|
||||
for _, prefix := range domainPrefixes {
|
||||
if _, err := r.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %w", prefix, err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
r.currentPeerKey = ""
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *Route) startResolver(ctx context.Context) {
|
||||
log.Debugf("Starting dynamic route resolver for domains [%v]", r)
|
||||
|
||||
interval := r.interval
|
||||
if interval < minInterval {
|
||||
interval = minInterval
|
||||
log.Warnf("Dynamic route resolver interval %s is too low, setting to minimum value %s", r.interval, minInterval)
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
if err := r.update(ctx); err != nil {
|
||||
log.Errorf("Failed to resolve domains for route [%v]: %v", r, err)
|
||||
if interval > failureInterval {
|
||||
ticker.Reset(failureInterval)
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Debugf("Stopping dynamic route resolver for domains [%v]", r)
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := r.update(ctx); err != nil {
|
||||
log.Errorf("Failed to resolve domains for route [%v]: %v", r, err)
|
||||
// Use a lower ticker interval if the update fails
|
||||
if interval > failureInterval {
|
||||
ticker.Reset(failureInterval)
|
||||
}
|
||||
} else if interval > failureInterval {
|
||||
// Reset to the original interval if the update succeeds
|
||||
ticker.Reset(interval)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Route) update(ctx context.Context) error {
|
||||
if resolved, err := r.resolveDomains(); err != nil {
|
||||
return fmt.Errorf("resolve domains: %w", err)
|
||||
} else if err := r.updateDynamicRoutes(ctx, resolved); err != nil {
|
||||
return fmt.Errorf("update dynamic routes: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Route) resolveDomains() (domainMap, error) {
|
||||
results := make(chan resolveResult)
|
||||
go r.resolve(results)
|
||||
|
||||
resolved := domainMap{}
|
||||
var merr *multierror.Error
|
||||
|
||||
for result := range results {
|
||||
if result.err != nil {
|
||||
merr = multierror.Append(merr, result.err)
|
||||
} else {
|
||||
resolved[result.domain] = append(resolved[result.domain], result.prefix)
|
||||
}
|
||||
}
|
||||
|
||||
return resolved, nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *Route) resolve(results chan resolveResult) {
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for _, d := range r.route.Domains {
|
||||
wg.Add(1)
|
||||
go func(domain domain.Domain) {
|
||||
defer wg.Done()
|
||||
ips, err := net.LookupIP(string(domain))
|
||||
if err != nil {
|
||||
results <- resolveResult{domain: domain, err: fmt.Errorf("resolve d %s: %w", domain.SafeString(), err)}
|
||||
return
|
||||
}
|
||||
for _, ip := range ips {
|
||||
prefix, err := util.GetPrefixFromIP(ip)
|
||||
if err != nil {
|
||||
results <- resolveResult{domain: domain, err: fmt.Errorf("get prefix from IP %s: %w", ip.String(), err)}
|
||||
return
|
||||
}
|
||||
results <- resolveResult{domain: domain, prefix: prefix}
|
||||
}
|
||||
}(d)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(results)
|
||||
}
|
||||
|
||||
func (r *Route) updateDynamicRoutes(ctx context.Context, newDomains domainMap) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
for domain, newPrefixes := range newDomains {
|
||||
oldPrefixes := r.dynamicDomains[domain]
|
||||
toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes)
|
||||
|
||||
addedPrefixes, err := r.addRoutes(domain, toAdd)
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
} else if len(addedPrefixes) > 0 {
|
||||
log.Debugf("Added dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", addedPrefixes), " ", ", "))
|
||||
}
|
||||
|
||||
removedPrefixes, err := r.removeRoutes(toRemove)
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
} else if len(removedPrefixes) > 0 {
|
||||
log.Debugf("Removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", removedPrefixes), " ", ", "))
|
||||
}
|
||||
|
||||
updatedPrefixes := combinePrefixes(oldPrefixes, removedPrefixes, addedPrefixes)
|
||||
r.dynamicDomains[domain] = updatedPrefixes
|
||||
|
||||
r.statusRecorder.UpdateResolvedDomainsStates(domain, updatedPrefixes)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *Route) addRoutes(domain domain.Domain, prefixes []netip.Prefix) ([]netip.Prefix, error) {
|
||||
var addedPrefixes []netip.Prefix
|
||||
var merr *multierror.Error
|
||||
|
||||
for _, prefix := range prefixes {
|
||||
if _, err := r.routeRefCounter.Increment(prefix, nil); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add dynamic route for IP %s: %w", prefix, err))
|
||||
continue
|
||||
}
|
||||
if r.currentPeerKey != "" {
|
||||
if err := r.incrementAllowedIP(domain, prefix, r.currentPeerKey); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf(addAllowedIP, prefix, err))
|
||||
}
|
||||
}
|
||||
addedPrefixes = append(addedPrefixes, prefix)
|
||||
}
|
||||
|
||||
return addedPrefixes, merr.ErrorOrNil()
|
||||
}
|
||||
|
||||
func (r *Route) removeRoutes(prefixes []netip.Prefix) ([]netip.Prefix, error) {
|
||||
if r.route.KeepRoute {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var removedPrefixes []netip.Prefix
|
||||
var merr *multierror.Error
|
||||
|
||||
for _, prefix := range prefixes {
|
||||
if _, err := r.routeRefCounter.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %w", prefix, err))
|
||||
}
|
||||
if r.currentPeerKey != "" {
|
||||
if _, err := r.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %w", prefix, err))
|
||||
}
|
||||
}
|
||||
removedPrefixes = append(removedPrefixes, prefix)
|
||||
}
|
||||
|
||||
return removedPrefixes, merr.ErrorOrNil()
|
||||
}
|
||||
|
||||
func (r *Route) incrementAllowedIP(domain domain.Domain, prefix netip.Prefix, peerKey string) error {
|
||||
if ref, err := r.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil {
|
||||
return fmt.Errorf(addAllowedIP, prefix, err)
|
||||
} else if ref.Count > 1 && ref.Out != peerKey {
|
||||
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
|
||||
prefix.Addr(),
|
||||
domain.SafeString(),
|
||||
ref.Out,
|
||||
)
|
||||
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toRemove []netip.Prefix) {
|
||||
prefixSet := make(map[netip.Prefix]bool)
|
||||
for _, prefix := range oldPrefixes {
|
||||
prefixSet[prefix] = false
|
||||
}
|
||||
for _, prefix := range newPrefixes {
|
||||
if _, exists := prefixSet[prefix]; exists {
|
||||
prefixSet[prefix] = true
|
||||
} else {
|
||||
toAdd = append(toAdd, prefix)
|
||||
}
|
||||
}
|
||||
for prefix, inUse := range prefixSet {
|
||||
if !inUse {
|
||||
toRemove = append(toRemove, prefix)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func combinePrefixes(oldPrefixes, removedPrefixes, addedPrefixes []netip.Prefix) []netip.Prefix {
|
||||
prefixSet := make(map[netip.Prefix]struct{})
|
||||
for _, prefix := range oldPrefixes {
|
||||
prefixSet[prefix] = struct{}{}
|
||||
}
|
||||
for _, prefix := range removedPrefixes {
|
||||
delete(prefixSet, prefix)
|
||||
}
|
||||
for _, prefix := range addedPrefixes {
|
||||
prefixSet[prefix] = struct{}{}
|
||||
}
|
||||
|
||||
var combinedPrefixes []netip.Prefix
|
||||
for prefix := range prefixSet {
|
||||
combinedPrefixes = append(combinedPrefixes, prefix)
|
||||
}
|
||||
|
||||
return combinedPrefixes
|
||||
}
|
@ -2,18 +2,23 @@ package routemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
@ -21,14 +26,9 @@ import (
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
var defaultv4 = netip.PrefixFrom(netip.IPv4Unspecified(), 0)
|
||||
|
||||
// nolint:unused
|
||||
var defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
|
||||
|
||||
// Manager is a route manager interface
|
||||
type Manager interface {
|
||||
Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error)
|
||||
Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
|
||||
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
|
||||
TriggerSelection(route.HAMap)
|
||||
GetRouteSelector() *routeselector.RouteSelector
|
||||
@ -40,31 +40,71 @@ type Manager interface {
|
||||
|
||||
// DefaultManager is the default instance of a route manager
|
||||
type DefaultManager struct {
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
mux sync.Mutex
|
||||
clientNetworks map[route.HAUniqueID]*clientNetwork
|
||||
routeSelector *routeselector.RouteSelector
|
||||
serverRouter serverRouter
|
||||
statusRecorder *peer.Status
|
||||
wgInterface *iface.WGIface
|
||||
pubKey string
|
||||
notifier *notifier
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
mux sync.Mutex
|
||||
clientNetworks map[route.HAUniqueID]*clientNetwork
|
||||
routeSelector *routeselector.RouteSelector
|
||||
serverRouter serverRouter
|
||||
sysOps *systemops.SysOps
|
||||
statusRecorder *peer.Status
|
||||
wgInterface *iface.WGIface
|
||||
pubKey string
|
||||
notifier *notifier
|
||||
routeRefCounter *refcounter.RouteRefCounter
|
||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter
|
||||
dnsRouteInterval time.Duration
|
||||
}
|
||||
|
||||
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status, initialRoutes []*route.Route) *DefaultManager {
|
||||
func NewManager(
|
||||
ctx context.Context,
|
||||
pubKey string,
|
||||
dnsRouteInterval time.Duration,
|
||||
wgInterface *iface.WGIface,
|
||||
statusRecorder *peer.Status,
|
||||
initialRoutes []*route.Route,
|
||||
) *DefaultManager {
|
||||
mCTX, cancel := context.WithCancel(ctx)
|
||||
sysOps := systemops.NewSysOps(wgInterface)
|
||||
|
||||
dm := &DefaultManager{
|
||||
ctx: mCTX,
|
||||
stop: cancel,
|
||||
clientNetworks: make(map[route.HAUniqueID]*clientNetwork),
|
||||
routeSelector: routeselector.NewRouteSelector(),
|
||||
statusRecorder: statusRecorder,
|
||||
wgInterface: wgInterface,
|
||||
pubKey: pubKey,
|
||||
notifier: newNotifier(),
|
||||
ctx: mCTX,
|
||||
stop: cancel,
|
||||
dnsRouteInterval: dnsRouteInterval,
|
||||
clientNetworks: make(map[route.HAUniqueID]*clientNetwork),
|
||||
routeSelector: routeselector.NewRouteSelector(),
|
||||
sysOps: sysOps,
|
||||
statusRecorder: statusRecorder,
|
||||
wgInterface: wgInterface,
|
||||
pubKey: pubKey,
|
||||
notifier: newNotifier(),
|
||||
}
|
||||
|
||||
dm.routeRefCounter = refcounter.New(
|
||||
func(prefix netip.Prefix, _ any) (any, error) {
|
||||
return nil, sysOps.AddVPNRoute(prefix, wgInterface.ToInterface())
|
||||
},
|
||||
func(prefix netip.Prefix, _ any) error {
|
||||
return sysOps.RemoveVPNRoute(prefix, wgInterface.ToInterface())
|
||||
},
|
||||
)
|
||||
|
||||
dm.allowedIPsRefCounter = refcounter.New(
|
||||
func(prefix netip.Prefix, peerKey string) (string, error) {
|
||||
// save peerKey to use it in the remove function
|
||||
return peerKey, wgInterface.AddAllowedIP(peerKey, prefix.String())
|
||||
},
|
||||
func(prefix netip.Prefix, peerKey string) error {
|
||||
if err := wgInterface.RemoveAllowedIP(peerKey, prefix.String()); err != nil {
|
||||
if !errors.Is(err, iface.ErrPeerNotFound) && !errors.Is(err, iface.ErrAllowedIPNotFound) {
|
||||
return err
|
||||
}
|
||||
log.Tracef("Remove allowed IPs %s for %s: %v", prefix, peerKey, err)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
)
|
||||
|
||||
if runtime.GOOS == "android" {
|
||||
cr := dm.clientRoutes(initialRoutes)
|
||||
dm.notifier.setInitialClientRoutes(cr)
|
||||
@ -73,12 +113,12 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface,
|
||||
}
|
||||
|
||||
// Init sets up the routing
|
||||
func (m *DefaultManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||
func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
if nbnet.CustomRoutingDisabled() {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
if err := cleanupRouting(); err != nil {
|
||||
if err := m.sysOps.CleanupRouting(); err != nil {
|
||||
log.Warnf("Failed cleaning up routing: %v", err)
|
||||
}
|
||||
|
||||
@ -86,7 +126,7 @@ func (m *DefaultManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePee
|
||||
signalAddress := m.statusRecorder.GetSignalState().URL
|
||||
ips := resolveURLsToIPs([]string{mgmtAddress, signalAddress})
|
||||
|
||||
beforePeerHook, afterPeerHook, err := setupRouting(ips, m.wgInterface)
|
||||
beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("setup routing: %w", err)
|
||||
}
|
||||
@ -110,8 +150,19 @@ func (m *DefaultManager) Stop() {
|
||||
m.serverRouter.cleanUp()
|
||||
}
|
||||
|
||||
if m.routeRefCounter != nil {
|
||||
if err := m.routeRefCounter.Flush(); err != nil {
|
||||
log.Errorf("Error flushing route ref counter: %v", err)
|
||||
}
|
||||
}
|
||||
if m.allowedIPsRefCounter != nil {
|
||||
if err := m.allowedIPsRefCounter.Flush(); err != nil {
|
||||
log.Errorf("Error flushing allowed IPs ref counter: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if !nbnet.CustomRoutingDisabled() {
|
||||
if err := cleanupRouting(); err != nil {
|
||||
if err := m.sysOps.CleanupRouting(); err != nil {
|
||||
log.Errorf("Error cleaning up routing: %v", err)
|
||||
} else {
|
||||
log.Info("Routing cleanup complete")
|
||||
@ -185,7 +236,7 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
|
||||
continue
|
||||
}
|
||||
|
||||
clientNetworkWatcher := newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network)
|
||||
clientNetworkWatcher := newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter)
|
||||
m.clientNetworks[id] = clientNetworkWatcher
|
||||
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
||||
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes})
|
||||
@ -197,7 +248,7 @@ func (m *DefaultManager) stopObsoleteClients(networks route.HAMap) {
|
||||
for id, client := range m.clientNetworks {
|
||||
if _, ok := networks[id]; !ok {
|
||||
log.Debugf("Stopping client network watcher, %s", id)
|
||||
client.stop()
|
||||
client.cancel()
|
||||
delete(m.clientNetworks, id)
|
||||
}
|
||||
}
|
||||
@ -210,7 +261,7 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout
|
||||
for id, routes := range networks {
|
||||
clientNetworkWatcher, found := m.clientNetworks[id]
|
||||
if !found {
|
||||
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network)
|
||||
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter)
|
||||
m.clientNetworks[id] = clientNetworkWatcher
|
||||
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
||||
}
|
||||
@ -228,7 +279,7 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]
|
||||
ownNetworkIDs := make(map[route.HAUniqueID]bool)
|
||||
|
||||
for _, newRoute := range newRoutes {
|
||||
haID := route.GetHAUniqueID(newRoute)
|
||||
haID := newRoute.GetHAUniqueID()
|
||||
if newRoute.Peer == m.pubKey {
|
||||
ownNetworkIDs[haID] = true
|
||||
// only linux is supported for now
|
||||
@ -241,9 +292,9 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]
|
||||
}
|
||||
|
||||
for _, newRoute := range newRoutes {
|
||||
haID := route.GetHAUniqueID(newRoute)
|
||||
haID := newRoute.GetHAUniqueID()
|
||||
if !ownNetworkIDs[haID] {
|
||||
if !isPrefixSupported(newRoute.Network) {
|
||||
if !isRouteSupported(newRoute) {
|
||||
continue
|
||||
}
|
||||
newClientRoutesIDMap[haID] = append(newClientRoutesIDMap[haID], newRoute)
|
||||
@ -255,23 +306,23 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]
|
||||
|
||||
func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Route {
|
||||
_, crMap := m.classifyRoutes(initialRoutes)
|
||||
rs := make([]*route.Route, 0)
|
||||
rs := make([]*route.Route, 0, len(crMap))
|
||||
for _, routes := range crMap {
|
||||
rs = append(rs, routes...)
|
||||
}
|
||||
return rs
|
||||
}
|
||||
|
||||
func isPrefixSupported(prefix netip.Prefix) bool {
|
||||
if !nbnet.CustomRoutingDisabled() {
|
||||
func isRouteSupported(route *route.Route) bool {
|
||||
if !nbnet.CustomRoutingDisabled() || route.IsDynamic() {
|
||||
return true
|
||||
}
|
||||
|
||||
// If prefix is too small, lets assume it is a possible default prefix which is not yet supported
|
||||
// we skip this prefix management
|
||||
if prefix.Bits() <= minRangeBits {
|
||||
if route.Network.Bits() <= vars.MinRangeBits {
|
||||
log.Warnf("This agent version: %s, doesn't support default routes, received %s, skipping this prefix",
|
||||
version.NetbirdVersion(), prefix)
|
||||
version.NetbirdVersion(), route.Network)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
|
@ -407,7 +407,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
|
||||
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil)
|
||||
require.NoError(t, err, "should create testing WGIface interface")
|
||||
defer wgInterface.Close()
|
||||
|
||||
@ -416,7 +416,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
|
||||
statusRecorder := peer.NewRecorder("https://mgm")
|
||||
ctx := context.TODO()
|
||||
routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil)
|
||||
routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil)
|
||||
|
||||
_, _, err = routeManager.Init()
|
||||
|
||||
@ -436,7 +436,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
require.NoError(t, err, "should update routes")
|
||||
|
||||
expectedWatchers := testCase.clientNetworkWatchersExpected
|
||||
if (runtime.GOOS == "linux" || runtime.GOOS == "windows" || runtime.GOOS == "darwin") && testCase.clientNetworkWatchersExpectedAllowed != 0 {
|
||||
if testCase.clientNetworkWatchersExpectedAllowed != 0 {
|
||||
expectedWatchers = testCase.clientNetworkWatchersExpectedAllowed
|
||||
}
|
||||
require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match")
|
||||
|
@ -6,10 +6,10 @@ import (
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
// MockManager is the mock instance of a route manager
|
||||
@ -20,7 +20,7 @@ type MockManager struct {
|
||||
StopFunc func()
|
||||
}
|
||||
|
||||
func (m *MockManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||
func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
|
155
client/internal/routemanager/refcounter/refcounter.go
Normal file
155
client/internal/routemanager/refcounter/refcounter.go
Normal file
@ -0,0 +1,155 @@
|
||||
package refcounter
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
)
|
||||
|
||||
// ErrIgnore can be returned by AddFunc to indicate that the counter not be incremented for the given prefix.
|
||||
var ErrIgnore = errors.New("ignore")
|
||||
|
||||
type Ref[O any] struct {
|
||||
Count int
|
||||
Out O
|
||||
}
|
||||
|
||||
type AddFunc[I, O any] func(prefix netip.Prefix, in I) (out O, err error)
|
||||
type RemoveFunc[I, O any] func(prefix netip.Prefix, out O) error
|
||||
|
||||
type Counter[I, O any] struct {
|
||||
// refCountMap keeps track of the reference Ref for prefixes
|
||||
refCountMap map[netip.Prefix]Ref[O]
|
||||
refCountMu sync.Mutex
|
||||
// idMap keeps track of the prefixes associated with an ID for removal
|
||||
idMap map[string][]netip.Prefix
|
||||
idMu sync.Mutex
|
||||
add AddFunc[I, O]
|
||||
remove RemoveFunc[I, O]
|
||||
}
|
||||
|
||||
// New creates a new Counter instance
|
||||
func New[I, O any](add AddFunc[I, O], remove RemoveFunc[I, O]) *Counter[I, O] {
|
||||
return &Counter[I, O]{
|
||||
refCountMap: map[netip.Prefix]Ref[O]{},
|
||||
idMap: map[string][]netip.Prefix{},
|
||||
add: add,
|
||||
remove: remove,
|
||||
}
|
||||
}
|
||||
|
||||
// Increment increments the reference count for the given prefix.
|
||||
// If this is the first reference to the prefix, the AddFunc is called.
|
||||
func (rm *Counter[I, O]) Increment(prefix netip.Prefix, in I) (Ref[O], error) {
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
|
||||
ref := rm.refCountMap[prefix]
|
||||
log.Tracef("Increasing ref count %d for prefix %s with [%v]", ref.Count, prefix, ref.Out)
|
||||
|
||||
// Call AddFunc only if it's a new prefix
|
||||
if ref.Count == 0 {
|
||||
log.Tracef("Adding for prefix %s with [%v]", prefix, ref.Out)
|
||||
out, err := rm.add(prefix, in)
|
||||
|
||||
if errors.Is(err, ErrIgnore) {
|
||||
return ref, nil
|
||||
}
|
||||
if err != nil {
|
||||
return ref, fmt.Errorf("failed to add for prefix %s: %w", prefix, err)
|
||||
}
|
||||
ref.Out = out
|
||||
}
|
||||
|
||||
ref.Count++
|
||||
rm.refCountMap[prefix] = ref
|
||||
|
||||
return ref, nil
|
||||
}
|
||||
|
||||
// IncrementWithID increments the reference count for the given prefix and groups it under the given ID.
|
||||
// If this is the first reference to the prefix, the AddFunc is called.
|
||||
func (rm *Counter[I, O]) IncrementWithID(id string, prefix netip.Prefix, in I) (Ref[O], error) {
|
||||
rm.idMu.Lock()
|
||||
defer rm.idMu.Unlock()
|
||||
|
||||
ref, err := rm.Increment(prefix, in)
|
||||
if err != nil {
|
||||
return ref, fmt.Errorf("with ID: %w", err)
|
||||
}
|
||||
rm.idMap[id] = append(rm.idMap[id], prefix)
|
||||
|
||||
return ref, nil
|
||||
}
|
||||
|
||||
// Decrement decrements the reference count for the given prefix.
|
||||
// If the reference count reaches 0, the RemoveFunc is called.
|
||||
func (rm *Counter[I, O]) Decrement(prefix netip.Prefix) (Ref[O], error) {
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
|
||||
ref, ok := rm.refCountMap[prefix]
|
||||
if !ok {
|
||||
log.Tracef("No reference found for prefix %s", prefix)
|
||||
return ref, nil
|
||||
}
|
||||
|
||||
log.Tracef("Decreasing ref count %d for prefix %s with [%v]", ref.Count, prefix, ref.Out)
|
||||
if ref.Count == 1 {
|
||||
log.Tracef("Removing for prefix %s with [%v]", prefix, ref.Out)
|
||||
if err := rm.remove(prefix, ref.Out); err != nil {
|
||||
return ref, fmt.Errorf("remove for prefix %s: %w", prefix, err)
|
||||
}
|
||||
delete(rm.refCountMap, prefix)
|
||||
} else {
|
||||
ref.Count--
|
||||
rm.refCountMap[prefix] = ref
|
||||
}
|
||||
|
||||
return ref, nil
|
||||
}
|
||||
|
||||
// DecrementWithID decrements the reference count for all prefixes associated with the given ID.
|
||||
// If the reference count reaches 0, the RemoveFunc is called.
|
||||
func (rm *Counter[I, O]) DecrementWithID(id string) error {
|
||||
rm.idMu.Lock()
|
||||
defer rm.idMu.Unlock()
|
||||
|
||||
var merr *multierror.Error
|
||||
for _, prefix := range rm.idMap[id] {
|
||||
if _, err := rm.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
}
|
||||
delete(rm.idMap, id)
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// Flush removes all references and calls RemoveFunc for each prefix.
|
||||
func (rm *Counter[I, O]) Flush() error {
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
rm.idMu.Lock()
|
||||
defer rm.idMu.Unlock()
|
||||
|
||||
var merr *multierror.Error
|
||||
for prefix := range rm.refCountMap {
|
||||
log.Tracef("Removing for prefix %s", prefix)
|
||||
ref := rm.refCountMap[prefix]
|
||||
if err := rm.remove(prefix, ref.Out); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove for prefix %s: %w", prefix, err))
|
||||
}
|
||||
}
|
||||
rm.refCountMap = map[netip.Prefix]Ref[O]{}
|
||||
|
||||
rm.idMap = map[string][]netip.Prefix{}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
7
client/internal/routemanager/refcounter/types.go
Normal file
7
client/internal/routemanager/refcounter/types.go
Normal file
@ -0,0 +1,7 @@
|
||||
package refcounter
|
||||
|
||||
// RouteRefCounter is a Counter for Route, it doesn't take any input on Increment and doesn't use any output on Decrement
|
||||
type RouteRefCounter = Counter[any, any]
|
||||
|
||||
// AllowedIPsRefCounter is a Counter for AllowedIPs, it takes a peer key on Increment and passes it back to Decrement
|
||||
type AllowedIPsRefCounter = Counter[string, string]
|
@ -1,127 +0,0 @@
|
||||
//go:build !android && !ios
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
type ref struct {
|
||||
count int
|
||||
nexthop netip.Addr
|
||||
intf *net.Interface
|
||||
}
|
||||
|
||||
type RouteManager struct {
|
||||
// refCountMap keeps track of the reference ref for prefixes
|
||||
refCountMap map[netip.Prefix]ref
|
||||
// prefixMap keeps track of the prefixes associated with a connection ID for removal
|
||||
prefixMap map[nbnet.ConnectionID][]netip.Prefix
|
||||
addRoute AddRouteFunc
|
||||
removeRoute RemoveRouteFunc
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
type AddRouteFunc func(prefix netip.Prefix) (nexthop netip.Addr, intf *net.Interface, err error)
|
||||
type RemoveRouteFunc func(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error
|
||||
|
||||
func NewRouteManager(addRoute AddRouteFunc, removeRoute RemoveRouteFunc) *RouteManager {
|
||||
// TODO: read initial routing table into refCountMap
|
||||
return &RouteManager{
|
||||
refCountMap: map[netip.Prefix]ref{},
|
||||
prefixMap: map[nbnet.ConnectionID][]netip.Prefix{},
|
||||
addRoute: addRoute,
|
||||
removeRoute: removeRoute,
|
||||
}
|
||||
}
|
||||
|
||||
func (rm *RouteManager) AddRouteRef(connID nbnet.ConnectionID, prefix netip.Prefix) error {
|
||||
rm.mutex.Lock()
|
||||
defer rm.mutex.Unlock()
|
||||
|
||||
ref := rm.refCountMap[prefix]
|
||||
log.Debugf("Increasing route ref count %d for prefix %s", ref.count, prefix)
|
||||
|
||||
// Add route to the system, only if it's a new prefix
|
||||
if ref.count == 0 {
|
||||
log.Debugf("Adding route for prefix %s", prefix)
|
||||
nexthop, intf, err := rm.addRoute(prefix)
|
||||
if errors.Is(err, ErrRouteNotFound) {
|
||||
return nil
|
||||
}
|
||||
if errors.Is(err, ErrRouteNotAllowed) {
|
||||
log.Debugf("Adding route for prefix %s: %s", prefix, err)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add route for prefix %s: %w", prefix, err)
|
||||
}
|
||||
ref.nexthop = nexthop
|
||||
ref.intf = intf
|
||||
}
|
||||
|
||||
ref.count++
|
||||
rm.refCountMap[prefix] = ref
|
||||
rm.prefixMap[connID] = append(rm.prefixMap[connID], prefix)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rm *RouteManager) RemoveRouteRef(connID nbnet.ConnectionID) error {
|
||||
rm.mutex.Lock()
|
||||
defer rm.mutex.Unlock()
|
||||
|
||||
prefixes, ok := rm.prefixMap[connID]
|
||||
if !ok {
|
||||
log.Debugf("No prefixes found for connection ID %s", connID)
|
||||
return nil
|
||||
}
|
||||
|
||||
var result *multierror.Error
|
||||
for _, prefix := range prefixes {
|
||||
ref := rm.refCountMap[prefix]
|
||||
log.Debugf("Decreasing route ref count %d for prefix %s", ref.count, prefix)
|
||||
if ref.count == 1 {
|
||||
log.Debugf("Removing route for prefix %s", prefix)
|
||||
// TODO: don't fail if the route is not found
|
||||
if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err))
|
||||
continue
|
||||
}
|
||||
delete(rm.refCountMap, prefix)
|
||||
} else {
|
||||
ref.count--
|
||||
rm.refCountMap[prefix] = ref
|
||||
}
|
||||
}
|
||||
delete(rm.prefixMap, connID)
|
||||
|
||||
return result.ErrorOrNil()
|
||||
}
|
||||
|
||||
// Flush removes all references and routes from the system
|
||||
func (rm *RouteManager) Flush() error {
|
||||
rm.mutex.Lock()
|
||||
defer rm.mutex.Unlock()
|
||||
|
||||
var result *multierror.Error
|
||||
for prefix := range rm.refCountMap {
|
||||
log.Debugf("Removing route for prefix %s", prefix)
|
||||
ref := rm.refCountMap[prefix]
|
||||
if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err))
|
||||
}
|
||||
}
|
||||
rm.refCountMap = map[netip.Prefix]ref{}
|
||||
rm.prefixMap = map[nbnet.ConnectionID][]netip.Prefix{}
|
||||
|
||||
return result.ErrorOrNil()
|
||||
}
|
@ -12,6 +12,7 @@ import (
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
@ -70,7 +71,7 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[route.ID]*route.Route)
|
||||
}
|
||||
|
||||
if len(m.routes) > 0 {
|
||||
err := enableIPForwarding()
|
||||
err := systemops.EnableIPForwarding()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -88,7 +89,7 @@ func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), route)
|
||||
routerPair, err := routeToRouterPair(route)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse prefix: %w", err)
|
||||
}
|
||||
@ -117,7 +118,7 @@ func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), route)
|
||||
routerPair, err := routeToRouterPair(route)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse prefix: %w", err)
|
||||
}
|
||||
@ -133,7 +134,13 @@ func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error {
|
||||
if state.Routes == nil {
|
||||
state.Routes = map[string]struct{}{}
|
||||
}
|
||||
state.Routes[route.Network.String()] = struct{}{}
|
||||
|
||||
routeStr := route.Network.String()
|
||||
if route.IsDynamic() {
|
||||
routeStr = route.Domains.SafeString()
|
||||
}
|
||||
state.Routes[routeStr] = struct{}{}
|
||||
|
||||
m.statusRecorder.UpdateLocalPeerState(state)
|
||||
|
||||
return nil
|
||||
@ -144,7 +151,7 @@ func (m *defaultServerRouter) cleanUp() {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
for _, r := range m.routes {
|
||||
routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), r)
|
||||
routerPair, err := routeToRouterPair(r)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to convert route to router pair: %v", err)
|
||||
continue
|
||||
@ -162,15 +169,27 @@ func (m *defaultServerRouter) cleanUp() {
|
||||
m.statusRecorder.UpdateLocalPeerState(state)
|
||||
}
|
||||
|
||||
func routeToRouterPair(source string, route *route.Route) (firewall.RouterPair, error) {
|
||||
parsed, err := netip.ParsePrefix(source)
|
||||
if err != nil {
|
||||
return firewall.RouterPair{}, err
|
||||
func routeToRouterPair(route *route.Route) (firewall.RouterPair, error) {
|
||||
// TODO: add ipv6
|
||||
source := getDefaultPrefix(route.Network)
|
||||
|
||||
destination := route.Network.Masked().String()
|
||||
if route.IsDynamic() {
|
||||
// TODO: add ipv6
|
||||
destination = "0.0.0.0/0"
|
||||
}
|
||||
|
||||
return firewall.RouterPair{
|
||||
ID: string(route.ID),
|
||||
Source: parsed.String(),
|
||||
Destination: route.Network.Masked().String(),
|
||||
Source: source.String(),
|
||||
Destination: destination,
|
||||
Masquerade: route.Masquerade,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func getDefaultPrefix(prefix netip.Prefix) netip.Prefix {
|
||||
if prefix.Addr().Is6() {
|
||||
return netip.PrefixFrom(netip.IPv6Unspecified(), 0)
|
||||
}
|
||||
return netip.PrefixFrom(netip.IPv4Unspecified(), 0)
|
||||
}
|
||||
|
57
client/internal/routemanager/static/route.go
Normal file
57
client/internal/routemanager/static/route.go
Normal file
@ -0,0 +1,57 @@
|
||||
package static
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type Route struct {
|
||||
route *route.Route
|
||||
routeRefCounter *refcounter.RouteRefCounter
|
||||
allowedIPsRefcounter *refcounter.AllowedIPsRefCounter
|
||||
}
|
||||
|
||||
func NewRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *Route {
|
||||
return &Route{
|
||||
route: rt,
|
||||
routeRefCounter: routeRefCounter,
|
||||
allowedIPsRefcounter: allowedIPsRefCounter,
|
||||
}
|
||||
}
|
||||
|
||||
// Route route methods
|
||||
func (r *Route) String() string {
|
||||
return r.route.Network.String()
|
||||
}
|
||||
|
||||
func (r *Route) AddRoute(context.Context) error {
|
||||
_, err := r.routeRefCounter.Increment(r.route.Network, nil)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Route) RemoveRoute() error {
|
||||
_, err := r.routeRefCounter.Decrement(r.route.Network)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Route) AddAllowedIPs(peerKey string) error {
|
||||
if ref, err := r.allowedIPsRefcounter.Increment(r.route.Network, peerKey); err != nil {
|
||||
return fmt.Errorf("add allowed IP %s: %w", r.route.Network, err)
|
||||
} else if ref.Count > 1 && ref.Out != peerKey {
|
||||
log.Warnf("Prefix [%s] is already routed by peer [%s]. HA routing disabled",
|
||||
r.route.Network,
|
||||
ref.Out,
|
||||
)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Route) RemoveAllowedIPs() error {
|
||||
_, err := r.allowedIPsRefcounter.Decrement(r.route.Network)
|
||||
return err
|
||||
}
|
103
client/internal/routemanager/sysctl/sysctl_linux.go
Normal file
103
client/internal/routemanager/sysctl/sysctl_linux.go
Normal file
@ -0,0 +1,103 @@
|
||||
// go:build !android
|
||||
package sysctl
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
const (
|
||||
rpFilterPath = "net.ipv4.conf.all.rp_filter"
|
||||
rpFilterInterfacePath = "net.ipv4.conf.%s.rp_filter"
|
||||
srcValidMarkPath = "net.ipv4.conf.all.src_valid_mark"
|
||||
)
|
||||
|
||||
// Setup configures sysctl settings for RP filtering and source validation.
|
||||
func Setup(wgIface *iface.WGIface) (map[string]int, error) {
|
||||
keys := map[string]int{}
|
||||
var result *multierror.Error
|
||||
|
||||
oldVal, err := Set(srcValidMarkPath, 1, false)
|
||||
if err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
} else {
|
||||
keys[srcValidMarkPath] = oldVal
|
||||
}
|
||||
|
||||
oldVal, err = Set(rpFilterPath, 2, true)
|
||||
if err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
} else {
|
||||
keys[rpFilterPath] = oldVal
|
||||
}
|
||||
|
||||
interfaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("list interfaces: %w", err))
|
||||
}
|
||||
|
||||
for _, intf := range interfaces {
|
||||
if intf.Name == "lo" || wgIface != nil && intf.Name == wgIface.Name() {
|
||||
continue
|
||||
}
|
||||
|
||||
i := fmt.Sprintf(rpFilterInterfacePath, intf.Name)
|
||||
oldVal, err := Set(i, 2, true)
|
||||
if err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
} else {
|
||||
keys[i] = oldVal
|
||||
}
|
||||
}
|
||||
|
||||
return keys, nberrors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
// Set sets a sysctl configuration, if onlyIfOne is true it will only set the new value if it's set to 1
|
||||
func Set(key string, desiredValue int, onlyIfOne bool) (int, error) {
|
||||
path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/"))
|
||||
currentValue, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return -1, fmt.Errorf("read sysctl %s: %w", key, err)
|
||||
}
|
||||
|
||||
currentV, err := strconv.Atoi(strings.TrimSpace(string(currentValue)))
|
||||
if err != nil && len(currentValue) > 0 {
|
||||
return -1, fmt.Errorf("convert current desiredValue to int: %w", err)
|
||||
}
|
||||
|
||||
if currentV == desiredValue || onlyIfOne && currentV != 1 {
|
||||
return currentV, nil
|
||||
}
|
||||
|
||||
//nolint:gosec
|
||||
if err := os.WriteFile(path, []byte(strconv.Itoa(desiredValue)), 0644); err != nil {
|
||||
return currentV, fmt.Errorf("write sysctl %s: %w", key, err)
|
||||
}
|
||||
log.Debugf("Set sysctl %s from %d to %d", key, currentV, desiredValue)
|
||||
|
||||
return currentV, nil
|
||||
}
|
||||
|
||||
// Cleanup resets sysctl settings to their original values.
|
||||
func Cleanup(originalSettings map[string]int) error {
|
||||
var result *multierror.Error
|
||||
|
||||
for key, value := range originalSettings {
|
||||
_, err := Set(key, value, false)
|
||||
if err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
}
|
@ -1,414 +0,0 @@
|
||||
//go:build !android && !ios
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strconv"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/libp2p/go-netroute"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1)
|
||||
var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
|
||||
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
|
||||
var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
|
||||
|
||||
var ErrRouteNotFound = errors.New("route not found")
|
||||
var ErrRouteNotAllowed = errors.New("route not allowed")
|
||||
|
||||
// TODO: fix: for default our wg address now appears as the default gw
|
||||
func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
|
||||
addr := netip.IPv4Unspecified()
|
||||
if prefix.Addr().Is6() {
|
||||
addr = netip.IPv6Unspecified()
|
||||
}
|
||||
|
||||
defaultGateway, _, err := GetNextHop(addr)
|
||||
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
||||
return fmt.Errorf("get existing route gateway: %s", err)
|
||||
}
|
||||
|
||||
if !prefix.Contains(defaultGateway) {
|
||||
log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", defaultGateway, prefix)
|
||||
return nil
|
||||
}
|
||||
|
||||
gatewayPrefix := netip.PrefixFrom(defaultGateway, 32)
|
||||
if defaultGateway.Is6() {
|
||||
gatewayPrefix = netip.PrefixFrom(defaultGateway, 128)
|
||||
}
|
||||
|
||||
ok, err := existsInRouteTable(gatewayPrefix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err)
|
||||
}
|
||||
|
||||
if ok {
|
||||
log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix)
|
||||
return nil
|
||||
}
|
||||
|
||||
gatewayHop, intf, err := GetNextHop(defaultGateway)
|
||||
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
||||
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
|
||||
}
|
||||
|
||||
log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop)
|
||||
return addToRouteTable(gatewayPrefix, gatewayHop, intf)
|
||||
}
|
||||
|
||||
func GetNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) {
|
||||
r, err := netroute.New()
|
||||
if err != nil {
|
||||
return netip.Addr{}, nil, fmt.Errorf("new netroute: %w", err)
|
||||
}
|
||||
intf, gateway, preferredSrc, err := r.Route(ip.AsSlice())
|
||||
if err != nil {
|
||||
log.Debugf("Failed to get route for %s: %v", ip, err)
|
||||
return netip.Addr{}, nil, ErrRouteNotFound
|
||||
}
|
||||
|
||||
log.Debugf("Route for %s: interface %v nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc)
|
||||
if gateway == nil {
|
||||
if preferredSrc == nil {
|
||||
return netip.Addr{}, nil, ErrRouteNotFound
|
||||
}
|
||||
log.Debugf("No next hop found for ip %s, using preferred source %s", ip, preferredSrc)
|
||||
|
||||
addr, err := ipToAddr(preferredSrc, intf)
|
||||
if err != nil {
|
||||
return netip.Addr{}, nil, fmt.Errorf("convert preferred source to address: %w", err)
|
||||
}
|
||||
return addr.Unmap(), intf, nil
|
||||
}
|
||||
|
||||
addr, err := ipToAddr(gateway, intf)
|
||||
if err != nil {
|
||||
return netip.Addr{}, nil, fmt.Errorf("convert gateway to address: %w", err)
|
||||
}
|
||||
|
||||
return addr, intf, nil
|
||||
}
|
||||
|
||||
// converts a net.IP to a netip.Addr including the zone based on the passed interface
|
||||
func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
|
||||
addr, ok := netip.AddrFromSlice(ip)
|
||||
if !ok {
|
||||
return netip.Addr{}, fmt.Errorf("failed to convert IP address to netip.Addr: %s", ip)
|
||||
}
|
||||
|
||||
if intf != nil && (addr.IsLinkLocalMulticast() || addr.IsLinkLocalUnicast()) {
|
||||
log.Tracef("Adding zone %s to address %s", intf.Name, addr)
|
||||
if runtime.GOOS == "windows" {
|
||||
addr = addr.WithZone(strconv.Itoa(intf.Index))
|
||||
} else {
|
||||
addr = addr.WithZone(intf.Name)
|
||||
}
|
||||
}
|
||||
|
||||
return addr.Unmap(), nil
|
||||
}
|
||||
|
||||
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
|
||||
routes, err := getRoutesFromTable()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("get routes from table: %w", err)
|
||||
}
|
||||
for _, tableRoute := range routes {
|
||||
if tableRoute == prefix {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func isSubRange(prefix netip.Prefix) (bool, error) {
|
||||
routes, err := getRoutesFromTable()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("get routes from table: %w", err)
|
||||
}
|
||||
for _, tableRoute := range routes {
|
||||
if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
|
||||
// If the next hop or interface is pointing to the VPN interface, it will return the initial values.
|
||||
func addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIface, initialNextHop netip.Addr, initialIntf *net.Interface) (netip.Addr, *net.Interface, error) {
|
||||
addr := prefix.Addr()
|
||||
switch {
|
||||
case addr.IsLoopback(),
|
||||
addr.IsLinkLocalUnicast(),
|
||||
addr.IsLinkLocalMulticast(),
|
||||
addr.IsInterfaceLocalMulticast(),
|
||||
addr.IsUnspecified(),
|
||||
addr.IsMulticast():
|
||||
|
||||
return netip.Addr{}, nil, ErrRouteNotAllowed
|
||||
}
|
||||
|
||||
// Determine the exit interface and next hop for the prefix, so we can add a specific route
|
||||
nexthop, intf, err := GetNextHop(addr)
|
||||
if err != nil {
|
||||
return netip.Addr{}, nil, fmt.Errorf("get next hop: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop, prefix, intf)
|
||||
exitNextHop := nexthop
|
||||
exitIntf := intf
|
||||
|
||||
vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP)
|
||||
if !ok {
|
||||
return netip.Addr{}, nil, fmt.Errorf("failed to convert vpn address to netip.Addr")
|
||||
}
|
||||
|
||||
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
|
||||
if exitNextHop == vpnAddr || exitIntf != nil && exitIntf.Name == vpnIntf.Name() {
|
||||
log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix)
|
||||
exitNextHop = initialNextHop
|
||||
exitIntf = initialIntf
|
||||
}
|
||||
|
||||
log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop)
|
||||
if err := addToRouteTable(prefix, exitNextHop, exitIntf); err != nil {
|
||||
return netip.Addr{}, nil, fmt.Errorf("add route to table: %w", err)
|
||||
}
|
||||
|
||||
return exitNextHop, exitIntf, nil
|
||||
}
|
||||
|
||||
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
|
||||
// in two /1 prefixes to avoid replacing the existing default route
|
||||
func genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
if prefix == defaultv4 {
|
||||
if err := addToRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addToRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil {
|
||||
if err2 := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err2 != nil {
|
||||
log.Warnf("Failed to rollback route addition: %s", err2)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: remove once IPv6 is supported on the interface
|
||||
if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
|
||||
return fmt.Errorf("add unreachable route split 1: %w", err)
|
||||
}
|
||||
if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
|
||||
if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil {
|
||||
log.Warnf("Failed to rollback route addition: %s", err2)
|
||||
}
|
||||
return fmt.Errorf("add unreachable route split 2: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
} else if prefix == defaultv6 {
|
||||
if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
|
||||
return fmt.Errorf("add unreachable route split 1: %w", err)
|
||||
}
|
||||
if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
|
||||
if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil {
|
||||
log.Warnf("Failed to rollback route addition: %s", err2)
|
||||
}
|
||||
return fmt.Errorf("add unreachable route split 2: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return addNonExistingRoute(prefix, intf)
|
||||
}
|
||||
|
||||
// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table
|
||||
func addNonExistingRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
ok, err := existsInRouteTable(prefix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("exists in route table: %w", err)
|
||||
}
|
||||
if ok {
|
||||
log.Warnf("Skipping adding a new route for network %s because it already exists", prefix)
|
||||
return nil
|
||||
}
|
||||
|
||||
ok, err = isSubRange(prefix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sub range: %w", err)
|
||||
}
|
||||
|
||||
if ok {
|
||||
err := addRouteForCurrentDefaultGateway(prefix)
|
||||
if err != nil {
|
||||
log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return addToRouteTable(prefix, netip.Addr{}, intf)
|
||||
}
|
||||
|
||||
// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given,
|
||||
// it will remove the split /1 prefixes
|
||||
func genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
if prefix == defaultv4 {
|
||||
var result *multierror.Error
|
||||
if err := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
if err := removeFromRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
|
||||
// TODO: remove once IPv6 is supported on the interface
|
||||
if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
|
||||
return result.ErrorOrNil()
|
||||
} else if prefix == defaultv6 {
|
||||
var result *multierror.Error
|
||||
if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
|
||||
return result.ErrorOrNil()
|
||||
}
|
||||
|
||||
return removeFromRouteTable(prefix, netip.Addr{}, intf)
|
||||
}
|
||||
|
||||
func getPrefixFromIP(ip net.IP) (*netip.Prefix, error) {
|
||||
addr, ok := netip.AddrFromSlice(ip)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("parse IP address: %s", ip)
|
||||
}
|
||||
addr = addr.Unmap()
|
||||
|
||||
var prefixLength int
|
||||
switch {
|
||||
case addr.Is4():
|
||||
prefixLength = 32
|
||||
case addr.Is6():
|
||||
prefixLength = 128
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid IP address: %s", addr)
|
||||
}
|
||||
|
||||
prefix := netip.PrefixFrom(addr, prefixLength)
|
||||
return &prefix, nil
|
||||
}
|
||||
|
||||
func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||
initialNextHopV4, initialIntfV4, err := GetNextHop(netip.IPv4Unspecified())
|
||||
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
||||
log.Errorf("Unable to get initial v4 default next hop: %v", err)
|
||||
}
|
||||
initialNextHopV6, initialIntfV6, err := GetNextHop(netip.IPv6Unspecified())
|
||||
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
||||
log.Errorf("Unable to get initial v6 default next hop: %v", err)
|
||||
}
|
||||
|
||||
*routeManager = NewRouteManager(
|
||||
func(prefix netip.Prefix) (netip.Addr, *net.Interface, error) {
|
||||
addr := prefix.Addr()
|
||||
nexthop, intf := initialNextHopV4, initialIntfV4
|
||||
if addr.Is6() {
|
||||
nexthop, intf = initialNextHopV6, initialIntfV6
|
||||
}
|
||||
return addRouteToNonVPNIntf(prefix, wgIface, nexthop, intf)
|
||||
},
|
||||
removeFromRouteTable,
|
||||
)
|
||||
|
||||
return setupHooks(*routeManager, initAddresses)
|
||||
}
|
||||
|
||||
func cleanupRoutingWithRouteManager(routeManager *RouteManager) error {
|
||||
if routeManager == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: Remove hooks selectively
|
||||
nbnet.RemoveDialerHooks()
|
||||
nbnet.RemoveListenerHooks()
|
||||
|
||||
if err := routeManager.Flush(); err != nil {
|
||||
return fmt.Errorf("flush route manager: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupHooks(routeManager *RouteManager, initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
|
||||
prefix, err := getPrefixFromIP(ip)
|
||||
if err != nil {
|
||||
return fmt.Errorf("convert ip to prefix: %w", err)
|
||||
}
|
||||
|
||||
if err := routeManager.AddRouteRef(connID, *prefix); err != nil {
|
||||
return fmt.Errorf("adding route reference: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
afterHook := func(connID nbnet.ConnectionID) error {
|
||||
if err := routeManager.RemoveRouteRef(connID); err != nil {
|
||||
return fmt.Errorf("remove route reference: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, ip := range initAddresses {
|
||||
if err := beforeHook("init", ip); err != nil {
|
||||
log.Errorf("Failed to add route reference: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
var result *multierror.Error
|
||||
for _, ip := range resolvedIPs {
|
||||
result = multierror.Append(result, beforeHook(connID, ip.IP))
|
||||
}
|
||||
return result.ErrorOrNil()
|
||||
})
|
||||
|
||||
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
|
||||
return afterHook(connID)
|
||||
})
|
||||
|
||||
nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error {
|
||||
return beforeHook(connID, ip.IP)
|
||||
})
|
||||
|
||||
nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error {
|
||||
return afterHook(connID)
|
||||
})
|
||||
|
||||
return beforeHook, afterHook, nil
|
||||
}
|
18
client/internal/routemanager/systemops/routeflags_bsd.go
Normal file
18
client/internal/routemanager/systemops/routeflags_bsd.go
Normal file
@ -0,0 +1,18 @@
|
||||
//go:build darwin || dragonfly || netbsd || openbsd
|
||||
|
||||
package systemops
|
||||
|
||||
import "syscall"
|
||||
|
||||
// filterRoutesByFlags - return true if need to ignore such route message because it consists specific flags.
|
||||
func filterRoutesByFlags(routeMessageFlags int) bool {
|
||||
if routeMessageFlags&syscall.RTF_UP == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE|syscall.RTF_WASCLONED) != 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
19
client/internal/routemanager/systemops/routeflags_freebsd.go
Normal file
19
client/internal/routemanager/systemops/routeflags_freebsd.go
Normal file
@ -0,0 +1,19 @@
|
||||
//go:build: freebsd
|
||||
package systemops
|
||||
|
||||
import "syscall"
|
||||
|
||||
// filterRoutesByFlags - return true if need to ignore such route message because it consists specific flags.
|
||||
func filterRoutesByFlags(routeMessageFlags int) bool {
|
||||
if routeMessageFlags&syscall.RTF_UP == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// NOTE: syscall.RTF_WASCLONED deprecated in FreeBSD 8.0 (https://www.freebsd.org/releases/8.0R/relnotes-detailed/)
|
||||
// a concept of cloned route (a route generated by an entry with RTF_CLONING flag) is deprecated.
|
||||
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE) != 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
27
client/internal/routemanager/systemops/systemops.go
Normal file
27
client/internal/routemanager/systemops/systemops.go
Normal file
@ -0,0 +1,27 @@
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
type Nexthop struct {
|
||||
IP netip.Addr
|
||||
Intf *net.Interface
|
||||
}
|
||||
|
||||
type ExclusionCounter = refcounter.Counter[any, Nexthop]
|
||||
|
||||
type SysOps struct {
|
||||
refCounter *ExclusionCounter
|
||||
wgInterface *iface.WGIface
|
||||
}
|
||||
|
||||
func NewSysOps(wgInterface *iface.WGIface) *SysOps {
|
||||
return &SysOps{
|
||||
wgInterface: wgInterface,
|
||||
}
|
||||
}
|
@ -1,6 +1,6 @@
|
||||
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
|
||||
|
||||
package routemanager
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"errors"
|
||||
@ -43,8 +43,7 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
|
||||
return nil, fmt.Errorf("unexpected RIB message type: %d", m.Type)
|
||||
}
|
||||
|
||||
if m.Flags&syscall.RTF_UP == 0 ||
|
||||
m.Flags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE|syscall.RTF_WASCLONED) != 0 {
|
||||
if filterRoutesByFlags(m.Flags) {
|
||||
continue
|
||||
}
|
||||
|
||||
@ -93,7 +92,7 @@ func toNetIP(a route.Addr) netip.Addr {
|
||||
case *route.Inet6Addr:
|
||||
ip := netip.AddrFrom16(t.IP)
|
||||
if t.ZoneID != 0 {
|
||||
ip.WithZone(strconv.Itoa(t.ZoneID))
|
||||
ip = ip.WithZone(strconv.Itoa(t.ZoneID))
|
||||
}
|
||||
return ip
|
||||
default:
|
||||
@ -101,6 +100,7 @@ func toNetIP(a route.Addr) netip.Addr {
|
||||
}
|
||||
}
|
||||
|
||||
// ones returns the number of leading ones in the mask.
|
||||
func ones(a route.Addr) (int, error) {
|
||||
switch t := a.(type) {
|
||||
case *route.Inet4Addr:
|
||||
@ -114,6 +114,7 @@ func ones(a route.Addr) (int, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// MsgToRoute converts a route message to a Route.
|
||||
func MsgToRoute(msg *route.RouteMessage) (*Route, error) {
|
||||
dstIP, nexthop, dstMask := msg.Addrs[0], msg.Addrs[1], msg.Addrs[2]
|
||||
|
@ -1,6 +1,6 @@
|
||||
//go:build !ios
|
||||
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
|
||||
|
||||
package routemanager
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@ -13,6 +13,7 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/net/route"
|
||||
)
|
||||
|
||||
var expectedVPNint = "utun100"
|
||||
@ -35,13 +36,15 @@ func TestConcurrentRoutes(t *testing.T) {
|
||||
baseIP := netip.MustParseAddr("192.0.2.0")
|
||||
intf := &net.Interface{Name: "lo0"}
|
||||
|
||||
r := NewSysOps(nil)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 1024; i++ {
|
||||
wg.Add(1)
|
||||
go func(ip netip.Addr) {
|
||||
defer wg.Done()
|
||||
prefix := netip.PrefixFrom(ip, 32)
|
||||
if err := addToRouteTable(prefix, netip.Addr{}, intf); err != nil {
|
||||
if err := r.addToRouteTable(prefix, Nexthop{netip.Addr{}, intf}); err != nil {
|
||||
t.Errorf("Failed to add route for %s: %v", prefix, err)
|
||||
}
|
||||
}(baseIP)
|
||||
@ -57,7 +60,7 @@ func TestConcurrentRoutes(t *testing.T) {
|
||||
go func(ip netip.Addr) {
|
||||
defer wg.Done()
|
||||
prefix := netip.PrefixFrom(ip, 32)
|
||||
if err := removeFromRouteTable(prefix, netip.Addr{}, intf); err != nil {
|
||||
if err := r.removeFromRouteTable(prefix, Nexthop{netip.Addr{}, intf}); err != nil {
|
||||
t.Errorf("Failed to remove route for %s: %v", prefix, err)
|
||||
}
|
||||
}(baseIP)
|
||||
@ -67,6 +70,53 @@ func TestConcurrentRoutes(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestBits(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr route.Addr
|
||||
want int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "IPv4 all ones",
|
||||
addr: &route.Inet4Addr{IP: [4]byte{255, 255, 255, 255}},
|
||||
want: 32,
|
||||
},
|
||||
{
|
||||
name: "IPv4 normal mask",
|
||||
addr: &route.Inet4Addr{IP: [4]byte{255, 255, 255, 0}},
|
||||
want: 24,
|
||||
},
|
||||
{
|
||||
name: "IPv6 all ones",
|
||||
addr: &route.Inet6Addr{IP: [16]byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}},
|
||||
want: 128,
|
||||
},
|
||||
{
|
||||
name: "IPv6 normal mask",
|
||||
addr: &route.Inet6Addr{IP: [16]byte{255, 255, 255, 255, 255, 255, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0}},
|
||||
want: 64,
|
||||
},
|
||||
{
|
||||
name: "Unsupported type",
|
||||
addr: &route.LinkAddr{},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := ones(tt.addr)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string {
|
||||
t.Helper()
|
||||
|
473
client/internal/routemanager/systemops/systemops_generic.go
Normal file
473
client/internal/routemanager/systemops/systemops_generic.go
Normal file
@ -0,0 +1,473 @@
|
||||
//go:build !android && !ios
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strconv"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/libp2p/go-netroute"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1)
|
||||
var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
|
||||
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
|
||||
var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
|
||||
|
||||
var ErrRoutingIsSeparate = errors.New("routing is separate")
|
||||
|
||||
func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
|
||||
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
|
||||
log.Errorf("Unable to get initial v4 default next hop: %v", err)
|
||||
}
|
||||
initialNextHopV6, err := GetNextHop(netip.IPv6Unspecified())
|
||||
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
|
||||
log.Errorf("Unable to get initial v6 default next hop: %v", err)
|
||||
}
|
||||
|
||||
refCounter := refcounter.New(
|
||||
func(prefix netip.Prefix, _ any) (Nexthop, error) {
|
||||
initialNexthop := initialNextHopV4
|
||||
if prefix.Addr().Is6() {
|
||||
initialNexthop = initialNextHopV6
|
||||
}
|
||||
|
||||
nexthop, err := r.addRouteToNonVPNIntf(prefix, r.wgInterface, initialNexthop)
|
||||
if errors.Is(err, vars.ErrRouteNotAllowed) || errors.Is(err, vars.ErrRouteNotFound) {
|
||||
log.Tracef("Adding for prefix %s: %v", prefix, err)
|
||||
// These errors are not critical but also we should not track and try to remove the routes either.
|
||||
return nexthop, refcounter.ErrIgnore
|
||||
}
|
||||
return nexthop, err
|
||||
},
|
||||
r.removeFromRouteTable,
|
||||
)
|
||||
|
||||
r.refCounter = refCounter
|
||||
|
||||
return r.setupHooks(initAddresses)
|
||||
}
|
||||
|
||||
func (r *SysOps) cleanupRefCounter() error {
|
||||
if r.refCounter == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: Remove hooks selectively
|
||||
nbnet.RemoveDialerHooks()
|
||||
nbnet.RemoveListenerHooks()
|
||||
|
||||
if err := r.refCounter.Flush(); err != nil {
|
||||
return fmt.Errorf("flush route manager: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: fix: for default our wg address now appears as the default gw
|
||||
func (r *SysOps) addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
|
||||
addr := netip.IPv4Unspecified()
|
||||
if prefix.Addr().Is6() {
|
||||
addr = netip.IPv6Unspecified()
|
||||
}
|
||||
|
||||
nexthop, err := GetNextHop(addr)
|
||||
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
|
||||
return fmt.Errorf("get existing route gateway: %s", err)
|
||||
}
|
||||
|
||||
if !prefix.Contains(nexthop.IP) {
|
||||
log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", nexthop.IP, prefix)
|
||||
return nil
|
||||
}
|
||||
|
||||
gatewayPrefix := netip.PrefixFrom(nexthop.IP, 32)
|
||||
if nexthop.IP.Is6() {
|
||||
gatewayPrefix = netip.PrefixFrom(nexthop.IP, 128)
|
||||
}
|
||||
|
||||
ok, err := existsInRouteTable(gatewayPrefix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err)
|
||||
}
|
||||
|
||||
if ok {
|
||||
log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix)
|
||||
return nil
|
||||
}
|
||||
|
||||
nexthop, err = GetNextHop(nexthop.IP)
|
||||
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
|
||||
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
|
||||
}
|
||||
|
||||
log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, nexthop.IP)
|
||||
return r.addToRouteTable(gatewayPrefix, nexthop)
|
||||
}
|
||||
|
||||
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
|
||||
// If the next hop or interface is pointing to the VPN interface, it will return the initial values.
|
||||
func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIface, initialNextHop Nexthop) (Nexthop, error) {
|
||||
addr := prefix.Addr()
|
||||
switch {
|
||||
case addr.IsLoopback(),
|
||||
addr.IsLinkLocalUnicast(),
|
||||
addr.IsLinkLocalMulticast(),
|
||||
addr.IsInterfaceLocalMulticast(),
|
||||
addr.IsUnspecified(),
|
||||
addr.IsMulticast():
|
||||
|
||||
return Nexthop{}, vars.ErrRouteNotAllowed
|
||||
}
|
||||
|
||||
// Determine the exit interface and next hop for the prefix, so we can add a specific route
|
||||
nexthop, err := GetNextHop(addr)
|
||||
if err != nil {
|
||||
return Nexthop{}, fmt.Errorf("get next hop: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop.IP, prefix, nexthop.IP)
|
||||
exitNextHop := Nexthop{
|
||||
IP: nexthop.IP,
|
||||
Intf: nexthop.Intf,
|
||||
}
|
||||
|
||||
vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP)
|
||||
if !ok {
|
||||
return Nexthop{}, fmt.Errorf("failed to convert vpn address to netip.Addr")
|
||||
}
|
||||
|
||||
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
|
||||
if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() {
|
||||
log.Debugf("Route for prefix %s is pointing to the VPN interface, using initial next hop %v", prefix, initialNextHop)
|
||||
|
||||
exitNextHop = initialNextHop
|
||||
}
|
||||
|
||||
log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop.IP)
|
||||
if err := r.addToRouteTable(prefix, exitNextHop); err != nil {
|
||||
return Nexthop{}, fmt.Errorf("add route to table: %w", err)
|
||||
}
|
||||
|
||||
return exitNextHop, nil
|
||||
}
|
||||
|
||||
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
|
||||
// in two /1 prefixes to avoid replacing the existing default route
|
||||
func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
nextHop := Nexthop{netip.Addr{}, intf}
|
||||
|
||||
if prefix == vars.Defaultv4 {
|
||||
if err := r.addToRouteTable(splitDefaultv4_1, nextHop); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := r.addToRouteTable(splitDefaultv4_2, nextHop); err != nil {
|
||||
if err2 := r.removeFromRouteTable(splitDefaultv4_1, nextHop); err2 != nil {
|
||||
log.Warnf("Failed to rollback route addition: %s", err2)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: remove once IPv6 is supported on the interface
|
||||
if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil {
|
||||
return fmt.Errorf("add unreachable route split 1: %w", err)
|
||||
}
|
||||
if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil {
|
||||
if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil {
|
||||
log.Warnf("Failed to rollback route addition: %s", err2)
|
||||
}
|
||||
return fmt.Errorf("add unreachable route split 2: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
} else if prefix == vars.Defaultv6 {
|
||||
if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil {
|
||||
return fmt.Errorf("add unreachable route split 1: %w", err)
|
||||
}
|
||||
if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil {
|
||||
if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil {
|
||||
log.Warnf("Failed to rollback route addition: %s", err2)
|
||||
}
|
||||
return fmt.Errorf("add unreachable route split 2: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return r.addNonExistingRoute(prefix, intf)
|
||||
}
|
||||
|
||||
// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table
|
||||
func (r *SysOps) addNonExistingRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
ok, err := existsInRouteTable(prefix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("exists in route table: %w", err)
|
||||
}
|
||||
if ok {
|
||||
log.Warnf("Skipping adding a new route for network %s because it already exists", prefix)
|
||||
return nil
|
||||
}
|
||||
|
||||
ok, err = isSubRange(prefix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sub range: %w", err)
|
||||
}
|
||||
|
||||
if ok {
|
||||
if err := r.addRouteForCurrentDefaultGateway(prefix); err != nil {
|
||||
log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return r.addToRouteTable(prefix, Nexthop{netip.Addr{}, intf})
|
||||
}
|
||||
|
||||
// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given,
|
||||
// it will remove the split /1 prefixes
|
||||
func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
nextHop := Nexthop{netip.Addr{}, intf}
|
||||
|
||||
if prefix == vars.Defaultv4 {
|
||||
var result *multierror.Error
|
||||
if err := r.removeFromRouteTable(splitDefaultv4_1, nextHop); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
if err := r.removeFromRouteTable(splitDefaultv4_2, nextHop); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
|
||||
// TODO: remove once IPv6 is supported on the interface
|
||||
if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
} else if prefix == vars.Defaultv6 {
|
||||
var result *multierror.Error
|
||||
if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
return r.removeFromRouteTable(prefix, nextHop)
|
||||
}
|
||||
|
||||
func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
|
||||
prefix, err := util.GetPrefixFromIP(ip)
|
||||
if err != nil {
|
||||
return fmt.Errorf("convert ip to prefix: %w", err)
|
||||
}
|
||||
|
||||
if _, err := r.refCounter.IncrementWithID(string(connID), prefix, nil); err != nil {
|
||||
return fmt.Errorf("adding route reference: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
afterHook := func(connID nbnet.ConnectionID) error {
|
||||
if err := r.refCounter.DecrementWithID(string(connID)); err != nil {
|
||||
return fmt.Errorf("remove route reference: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, ip := range initAddresses {
|
||||
if err := beforeHook("init", ip); err != nil {
|
||||
log.Errorf("Failed to add route reference: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
var result *multierror.Error
|
||||
for _, ip := range resolvedIPs {
|
||||
result = multierror.Append(result, beforeHook(connID, ip.IP))
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
})
|
||||
|
||||
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
|
||||
return afterHook(connID)
|
||||
})
|
||||
|
||||
nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error {
|
||||
return beforeHook(connID, ip.IP)
|
||||
})
|
||||
|
||||
nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error {
|
||||
return afterHook(connID)
|
||||
})
|
||||
|
||||
return beforeHook, afterHook, nil
|
||||
}
|
||||
|
||||
func GetNextHop(ip netip.Addr) (Nexthop, error) {
|
||||
r, err := netroute.New()
|
||||
if err != nil {
|
||||
return Nexthop{}, fmt.Errorf("new netroute: %w", err)
|
||||
}
|
||||
intf, gateway, preferredSrc, err := r.Route(ip.AsSlice())
|
||||
if err != nil {
|
||||
log.Debugf("Failed to get route for %s: %v", ip, err)
|
||||
return Nexthop{}, vars.ErrRouteNotFound
|
||||
}
|
||||
|
||||
log.Debugf("Route for %s: interface %v nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc)
|
||||
if gateway == nil {
|
||||
if runtime.GOOS == "freebsd" {
|
||||
return Nexthop{Intf: intf}, nil
|
||||
}
|
||||
|
||||
if preferredSrc == nil {
|
||||
return Nexthop{}, vars.ErrRouteNotFound
|
||||
}
|
||||
log.Debugf("No next hop found for IP %s, using preferred source %s", ip, preferredSrc)
|
||||
|
||||
addr, err := ipToAddr(preferredSrc, intf)
|
||||
if err != nil {
|
||||
return Nexthop{}, fmt.Errorf("convert preferred source to address: %w", err)
|
||||
}
|
||||
return Nexthop{
|
||||
IP: addr,
|
||||
Intf: intf,
|
||||
}, nil
|
||||
}
|
||||
|
||||
addr, err := ipToAddr(gateway, intf)
|
||||
if err != nil {
|
||||
return Nexthop{}, fmt.Errorf("convert gateway to address: %w", err)
|
||||
}
|
||||
|
||||
return Nexthop{
|
||||
IP: addr,
|
||||
Intf: intf,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// converts a net.IP to a netip.Addr including the zone based on the passed interface
|
||||
func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
|
||||
addr, ok := netip.AddrFromSlice(ip)
|
||||
if !ok {
|
||||
return netip.Addr{}, fmt.Errorf("failed to convert IP address to netip.Addr: %s", ip)
|
||||
}
|
||||
|
||||
if intf != nil && (addr.IsLinkLocalMulticast() || addr.IsLinkLocalUnicast()) {
|
||||
zone := intf.Name
|
||||
if runtime.GOOS == "windows" {
|
||||
zone = strconv.Itoa(intf.Index)
|
||||
}
|
||||
log.Tracef("Adding zone %s to address %s", zone, addr)
|
||||
addr = addr.WithZone(zone)
|
||||
}
|
||||
|
||||
return addr.Unmap(), nil
|
||||
}
|
||||
|
||||
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
|
||||
routes, err := getRoutesFromTable()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("get routes from table: %w", err)
|
||||
}
|
||||
for _, tableRoute := range routes {
|
||||
if tableRoute == prefix {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func isSubRange(prefix netip.Prefix) (bool, error) {
|
||||
routes, err := getRoutesFromTable()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("get routes from table: %w", err)
|
||||
}
|
||||
for _, tableRoute := range routes {
|
||||
if tableRoute.Bits() > vars.MinRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// IsAddrRouted checks if the candidate address would route to the vpn, in which case it returns true and the matched prefix.
|
||||
func IsAddrRouted(addr netip.Addr, vpnRoutes []netip.Prefix) (bool, netip.Prefix) {
|
||||
localRoutes, err := hasSeparateRouting()
|
||||
if err != nil {
|
||||
if !errors.Is(err, ErrRoutingIsSeparate) {
|
||||
log.Errorf("Failed to get routes: %v", err)
|
||||
}
|
||||
return false, netip.Prefix{}
|
||||
}
|
||||
|
||||
return isVpnRoute(addr, vpnRoutes, localRoutes)
|
||||
}
|
||||
|
||||
func isVpnRoute(addr netip.Addr, vpnRoutes []netip.Prefix, localRoutes []netip.Prefix) (bool, netip.Prefix) {
|
||||
vpnPrefixMap := map[netip.Prefix]struct{}{}
|
||||
for _, prefix := range vpnRoutes {
|
||||
vpnPrefixMap[prefix] = struct{}{}
|
||||
}
|
||||
|
||||
// remove vpnRoute duplicates
|
||||
for _, prefix := range localRoutes {
|
||||
delete(vpnPrefixMap, prefix)
|
||||
}
|
||||
|
||||
var longestPrefix netip.Prefix
|
||||
var isVpn bool
|
||||
|
||||
combinedRoutes := make([]netip.Prefix, len(vpnRoutes)+len(localRoutes))
|
||||
copy(combinedRoutes, vpnRoutes)
|
||||
copy(combinedRoutes[len(vpnRoutes):], localRoutes)
|
||||
|
||||
for _, prefix := range combinedRoutes {
|
||||
// Ignore the default route, it has special handling
|
||||
if prefix.Bits() == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if prefix.Contains(addr) {
|
||||
// Longest prefix match
|
||||
if !longestPrefix.IsValid() || prefix.Bits() > longestPrefix.Bits() {
|
||||
longestPrefix = prefix
|
||||
_, isVpn = vpnPrefixMap[prefix]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !longestPrefix.IsValid() {
|
||||
// No route matched
|
||||
return false, netip.Prefix{}
|
||||
}
|
||||
|
||||
// Return true if the longest matching prefix is from vpnRoutes
|
||||
return isVpn, longestPrefix
|
||||
}
|
@ -1,6 +1,6 @@
|
||||
//go:build !android && !ios
|
||||
|
||||
package routemanager
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@ -49,6 +49,10 @@ func TestAddRemoveRoutes(t *testing.T) {
|
||||
}
|
||||
|
||||
for n, testCase := range testCases {
|
||||
// todo resolve test execution on freebsd
|
||||
if runtime.GOOS == "freebsd" {
|
||||
t.Skip("skipping ", testCase.name, " on freebsd")
|
||||
}
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
|
||||
|
||||
@ -57,23 +61,26 @@ func TestAddRemoveRoutes(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
|
||||
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil)
|
||||
require.NoError(t, err, "should create testing WGIface interface")
|
||||
defer wgInterface.Close()
|
||||
|
||||
err = wgInterface.Create()
|
||||
require.NoError(t, err, "should create testing wireguard interface")
|
||||
_, _, err = setupRouting(nil, wgInterface)
|
||||
|
||||
r := NewSysOps(wgInterface)
|
||||
|
||||
_, _, err = r.SetupRouting(nil)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, cleanupRouting())
|
||||
assert.NoError(t, r.CleanupRouting())
|
||||
})
|
||||
|
||||
index, err := net.InterfaceByName(wgInterface.Name())
|
||||
require.NoError(t, err, "InterfaceByName should not return err")
|
||||
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
|
||||
|
||||
err = addVPNRoute(testCase.prefix, intf)
|
||||
err = r.AddVPNRoute(testCase.prefix, intf)
|
||||
require.NoError(t, err, "genericAddVPNRoute should not return err")
|
||||
|
||||
if testCase.shouldRouteToWireguard {
|
||||
@ -84,19 +91,19 @@ func TestAddRemoveRoutes(t *testing.T) {
|
||||
exists, err := existsInRouteTable(testCase.prefix)
|
||||
require.NoError(t, err, "existsInRouteTable should not return err")
|
||||
if exists && testCase.shouldRouteToWireguard {
|
||||
err = removeVPNRoute(testCase.prefix, intf)
|
||||
err = r.RemoveVPNRoute(testCase.prefix, intf)
|
||||
require.NoError(t, err, "genericRemoveVPNRoute should not return err")
|
||||
|
||||
prefixGateway, _, err := GetNextHop(testCase.prefix.Addr())
|
||||
prefixNexthop, err := GetNextHop(testCase.prefix.Addr())
|
||||
require.NoError(t, err, "GetNextHop should not return err")
|
||||
|
||||
internetGateway, _, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
|
||||
internetNexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
|
||||
require.NoError(t, err)
|
||||
|
||||
if testCase.shouldBeRemoved {
|
||||
require.Equal(t, internetGateway, prefixGateway, "route should be pointing to default internet gateway")
|
||||
require.Equal(t, internetNexthop.IP, prefixNexthop.IP, "route should be pointing to default internet gateway")
|
||||
} else {
|
||||
require.NotEqual(t, internetGateway, prefixGateway, "route should be pointing to a different gateway than the internet gateway")
|
||||
require.NotEqual(t, internetNexthop.IP, prefixNexthop.IP, "route should be pointing to a different gateway than the internet gateway")
|
||||
}
|
||||
}
|
||||
})
|
||||
@ -104,11 +111,14 @@ func TestAddRemoveRoutes(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGetNextHop(t *testing.T) {
|
||||
gateway, _, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
|
||||
if runtime.GOOS == "freebsd" {
|
||||
t.Skip("skipping on freebsd")
|
||||
}
|
||||
nexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when fetching the gateway: ", err)
|
||||
}
|
||||
if !gateway.IsValid() {
|
||||
if !nexthop.IP.IsValid() {
|
||||
t.Fatal("should return a gateway")
|
||||
}
|
||||
addresses, err := net.InterfaceAddrs()
|
||||
@ -130,24 +140,24 @@ func TestGetNextHop(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
localIP, _, err := GetNextHop(testingPrefix.Addr())
|
||||
localIP, err := GetNextHop(testingPrefix.Addr())
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error: ", err)
|
||||
}
|
||||
if !localIP.IsValid() {
|
||||
if !localIP.IP.IsValid() {
|
||||
t.Fatal("should return a gateway for local network")
|
||||
}
|
||||
if localIP.String() == gateway.String() {
|
||||
t.Fatal("local ip should not match with gateway IP")
|
||||
if localIP.IP.String() == nexthop.IP.String() {
|
||||
t.Fatal("local IP should not match with gateway IP")
|
||||
}
|
||||
if localIP.String() != testingIP {
|
||||
t.Fatalf("local ip should match with testing IP: want %s got %s", testingIP, localIP.String())
|
||||
if localIP.IP.String() != testingIP {
|
||||
t.Fatalf("local IP should match with testing IP: want %s got %s", testingIP, localIP.IP.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddExistAndRemoveRoute(t *testing.T) {
|
||||
defaultGateway, _, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
|
||||
t.Log("defaultGateway: ", defaultGateway)
|
||||
defaultNexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
|
||||
t.Log("defaultNexthop: ", defaultNexthop)
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when fetching the gateway: ", err)
|
||||
}
|
||||
@ -164,7 +174,7 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "Should Not Add Route if overlaps with default gateway",
|
||||
prefix: netip.MustParsePrefix(defaultGateway.String() + "/31"),
|
||||
prefix: netip.MustParsePrefix(defaultNexthop.IP.String() + "/31"),
|
||||
shouldAddRoute: false,
|
||||
},
|
||||
{
|
||||
@ -203,7 +213,7 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
|
||||
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil)
|
||||
require.NoError(t, err, "should create testing WGIface interface")
|
||||
defer wgInterface.Close()
|
||||
|
||||
@ -214,14 +224,16 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
|
||||
require.NoError(t, err, "InterfaceByName should not return err")
|
||||
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
|
||||
|
||||
r := NewSysOps(wgInterface)
|
||||
|
||||
// Prepare the environment
|
||||
if testCase.preExistingPrefix.IsValid() {
|
||||
err := addVPNRoute(testCase.preExistingPrefix, intf)
|
||||
err := r.AddVPNRoute(testCase.preExistingPrefix, intf)
|
||||
require.NoError(t, err, "should not return err when adding pre-existing route")
|
||||
}
|
||||
|
||||
// Add the route
|
||||
err = addVPNRoute(testCase.prefix, intf)
|
||||
err = r.AddVPNRoute(testCase.prefix, intf)
|
||||
require.NoError(t, err, "should not return err when adding route")
|
||||
|
||||
if testCase.shouldAddRoute {
|
||||
@ -231,7 +243,7 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
|
||||
require.True(t, ok, "route should exist")
|
||||
|
||||
// remove route again if added
|
||||
err = removeVPNRoute(testCase.prefix, intf)
|
||||
err = r.RemoveVPNRoute(testCase.prefix, intf)
|
||||
require.NoError(t, err, "should not return err")
|
||||
}
|
||||
|
||||
@ -295,19 +307,22 @@ func TestExistsInRouteTable(t *testing.T) {
|
||||
var addressPrefixes []netip.Prefix
|
||||
for _, address := range addresses {
|
||||
p := netip.MustParsePrefix(address.String())
|
||||
if p.Addr().Is6() {
|
||||
continue
|
||||
}
|
||||
// Windows sometimes has hidden interface link local addrs that don't turn up on any interface
|
||||
if runtime.GOOS == "windows" && p.Addr().IsLinkLocalUnicast() {
|
||||
continue
|
||||
}
|
||||
// Linux loopback 127/8 is in the local table, not in the main table and always takes precedence
|
||||
if runtime.GOOS == "linux" && p.Addr().IsLoopback() {
|
||||
continue
|
||||
}
|
||||
|
||||
addressPrefixes = append(addressPrefixes, p.Masked())
|
||||
switch {
|
||||
case p.Addr().Is6():
|
||||
continue
|
||||
// Windows sometimes has hidden interface link local addrs that don't turn up on any interface
|
||||
case runtime.GOOS == "windows" && p.Addr().IsLinkLocalUnicast():
|
||||
continue
|
||||
// Linux loopback 127/8 is in the local table, not in the main table and always takes precedence
|
||||
case runtime.GOOS == "linux" && p.Addr().IsLoopback():
|
||||
continue
|
||||
// FreeBSD loopback 127/8 is not added to the routing table
|
||||
case runtime.GOOS == "freebsd" && p.Addr().IsLoopback():
|
||||
continue
|
||||
default:
|
||||
addressPrefixes = append(addressPrefixes, p.Masked())
|
||||
}
|
||||
}
|
||||
|
||||
for _, prefix := range addressPrefixes {
|
||||
@ -330,7 +345,7 @@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen
|
||||
newNet, err := stdnet.NewNet()
|
||||
require.NoError(t, err)
|
||||
|
||||
wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
|
||||
wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil)
|
||||
require.NoError(t, err, "should create testing WireGuard interface")
|
||||
|
||||
err = wgInterface.Create()
|
||||
@ -343,65 +358,52 @@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen
|
||||
return wgInterface
|
||||
}
|
||||
|
||||
func setupRouteAndCleanup(t *testing.T, r *SysOps, prefix netip.Prefix, intf *net.Interface) {
|
||||
t.Helper()
|
||||
|
||||
err := r.AddVPNRoute(prefix, intf)
|
||||
require.NoError(t, err, "addVPNRoute should not return err")
|
||||
t.Cleanup(func() {
|
||||
err = r.RemoveVPNRoute(prefix, intf)
|
||||
assert.NoError(t, err, "removeVPNRoute should not return err")
|
||||
})
|
||||
}
|
||||
|
||||
func setupTestEnv(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
setupDummyInterfacesAndRoutes(t)
|
||||
|
||||
wgIface := createWGInterface(t, expectedVPNint, "100.64.0.1/24", 51820)
|
||||
wgInterface := createWGInterface(t, expectedVPNint, "100.64.0.1/24", 51820)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, wgIface.Close())
|
||||
assert.NoError(t, wgInterface.Close())
|
||||
})
|
||||
|
||||
_, _, err := setupRouting(nil, wgIface)
|
||||
r := NewSysOps(wgInterface)
|
||||
_, _, err := r.SetupRouting(nil)
|
||||
require.NoError(t, err, "setupRouting should not return err")
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, cleanupRouting())
|
||||
assert.NoError(t, r.CleanupRouting())
|
||||
})
|
||||
|
||||
index, err := net.InterfaceByName(wgIface.Name())
|
||||
index, err := net.InterfaceByName(wgInterface.Name())
|
||||
require.NoError(t, err, "InterfaceByName should not return err")
|
||||
intf := &net.Interface{Index: index.Index, Name: wgIface.Name()}
|
||||
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
|
||||
|
||||
// default route exists in main table and vpn table
|
||||
err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), intf)
|
||||
require.NoError(t, err, "addVPNRoute should not return err")
|
||||
t.Cleanup(func() {
|
||||
err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), intf)
|
||||
assert.NoError(t, err, "removeVPNRoute should not return err")
|
||||
})
|
||||
setupRouteAndCleanup(t, r, netip.MustParsePrefix("0.0.0.0/0"), intf)
|
||||
|
||||
// 10.0.0.0/8 route exists in main table and vpn table
|
||||
err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), intf)
|
||||
require.NoError(t, err, "addVPNRoute should not return err")
|
||||
t.Cleanup(func() {
|
||||
err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), intf)
|
||||
assert.NoError(t, err, "removeVPNRoute should not return err")
|
||||
})
|
||||
setupRouteAndCleanup(t, r, netip.MustParsePrefix("10.0.0.0/8"), intf)
|
||||
|
||||
// 10.10.0.0/24 more specific route exists in vpn table
|
||||
err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), intf)
|
||||
require.NoError(t, err, "addVPNRoute should not return err")
|
||||
t.Cleanup(func() {
|
||||
err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), intf)
|
||||
assert.NoError(t, err, "removeVPNRoute should not return err")
|
||||
})
|
||||
setupRouteAndCleanup(t, r, netip.MustParsePrefix("10.10.0.0/24"), intf)
|
||||
|
||||
// 127.0.10.0/24 more specific route exists in vpn table
|
||||
err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), intf)
|
||||
require.NoError(t, err, "addVPNRoute should not return err")
|
||||
t.Cleanup(func() {
|
||||
err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), intf)
|
||||
assert.NoError(t, err, "removeVPNRoute should not return err")
|
||||
})
|
||||
setupRouteAndCleanup(t, r, netip.MustParsePrefix("127.0.10.0/24"), intf)
|
||||
|
||||
// unique route in vpn table
|
||||
err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), intf)
|
||||
require.NoError(t, err, "addVPNRoute should not return err")
|
||||
t.Cleanup(func() {
|
||||
err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), intf)
|
||||
assert.NoError(t, err, "removeVPNRoute should not return err")
|
||||
})
|
||||
setupRouteAndCleanup(t, r, netip.MustParsePrefix("172.16.0.0/12"), intf)
|
||||
}
|
||||
|
||||
func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) {
|
||||
@ -410,11 +412,133 @@ func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIf
|
||||
return
|
||||
}
|
||||
|
||||
prefixGateway, _, err := GetNextHop(prefix.Addr())
|
||||
prefixNexthop, err := GetNextHop(prefix.Addr())
|
||||
require.NoError(t, err, "GetNextHop should not return err")
|
||||
if invert {
|
||||
assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP")
|
||||
assert.NotEqual(t, wgIface.Address().IP.String(), prefixNexthop.IP.String(), "route should not point to wireguard interface IP")
|
||||
} else {
|
||||
assert.Equal(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP")
|
||||
assert.Equal(t, wgIface.Address().IP.String(), prefixNexthop.IP.String(), "route should point to wireguard interface IP")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsVpnRoute(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr string
|
||||
vpnRoutes []string
|
||||
localRoutes []string
|
||||
expectedVpn bool
|
||||
expectedPrefix netip.Prefix
|
||||
}{
|
||||
{
|
||||
name: "Match in VPN routes",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: true,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
{
|
||||
name: "Match in local routes",
|
||||
addr: "10.1.1.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.MustParsePrefix("10.0.0.0/8"),
|
||||
},
|
||||
{
|
||||
name: "No match",
|
||||
addr: "172.16.0.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.Prefix{},
|
||||
},
|
||||
{
|
||||
name: "Default route ignored",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: true,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
{
|
||||
name: "Default route matches but ignored",
|
||||
addr: "172.16.1.1",
|
||||
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
|
||||
localRoutes: []string{"10.0.0.0/8"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.Prefix{},
|
||||
},
|
||||
{
|
||||
name: "Longest prefix match local",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"192.168.0.0/16"},
|
||||
localRoutes: []string{"192.168.1.0/24"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
{
|
||||
name: "Longest prefix match local multiple",
|
||||
addr: "192.168.0.1",
|
||||
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
|
||||
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26", "192.168.0.0/28"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.0.0/28"),
|
||||
},
|
||||
{
|
||||
name: "Longest prefix match vpn",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"192.168.0.0/16"},
|
||||
expectedVpn: true,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
{
|
||||
name: "Longest prefix match vpn multiple",
|
||||
addr: "192.168.0.1",
|
||||
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
|
||||
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26"},
|
||||
expectedVpn: true,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.0.0/27"),
|
||||
},
|
||||
{
|
||||
name: "Duplicate prefix in both",
|
||||
addr: "192.168.1.1",
|
||||
vpnRoutes: []string{"192.168.1.0/24"},
|
||||
localRoutes: []string{"192.168.1.0/24"},
|
||||
expectedVpn: false,
|
||||
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
addr, err := netip.ParseAddr(tt.addr)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse address %s: %v", tt.addr, err)
|
||||
}
|
||||
|
||||
var vpnRoutes, localRoutes []netip.Prefix
|
||||
for _, route := range tt.vpnRoutes {
|
||||
prefix, err := netip.ParsePrefix(route)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse VPN route %s: %v", route, err)
|
||||
}
|
||||
vpnRoutes = append(vpnRoutes, prefix)
|
||||
}
|
||||
|
||||
for _, route := range tt.localRoutes {
|
||||
prefix, err := netip.ParsePrefix(route)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse local route %s: %v", route, err)
|
||||
}
|
||||
localRoutes = append(localRoutes, prefix)
|
||||
}
|
||||
|
||||
isVpn, matchedPrefix := isVpnRoute(addr, vpnRoutes, localRoutes)
|
||||
assert.Equal(t, tt.expectedVpn, isVpn, "isVpnRoute should return expectedVpn value")
|
||||
assert.Equal(t, tt.expectedPrefix, matchedPrefix, "isVpnRoute should return expectedVpn prefix")
|
||||
})
|
||||
}
|
||||
}
|
@ -1,6 +1,6 @@
|
||||
//go:build !android
|
||||
|
||||
package routemanager
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
@ -9,16 +9,15 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/vishvananda/netlink"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/sysctl"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
@ -33,16 +32,10 @@ const (
|
||||
|
||||
// ipv4ForwardingPath is the path to the file containing the IP forwarding setting.
|
||||
ipv4ForwardingPath = "net.ipv4.ip_forward"
|
||||
|
||||
rpFilterPath = "net.ipv4.conf.all.rp_filter"
|
||||
rpFilterInterfacePath = "net.ipv4.conf.%s.rp_filter"
|
||||
srcValidMarkPath = "net.ipv4.conf.all.src_valid_mark"
|
||||
)
|
||||
|
||||
var ErrTableIDExists = errors.New("ID exists with different name")
|
||||
|
||||
var routeManager = &RouteManager{}
|
||||
|
||||
// originalSysctl stores the original sysctl values before they are modified
|
||||
var originalSysctl map[string]int
|
||||
|
||||
@ -82,7 +75,7 @@ func getSetupRules() []ruleParams {
|
||||
}
|
||||
}
|
||||
|
||||
// setupRouting establishes the routing configuration for the VPN, including essential rules
|
||||
// SetupRouting establishes the routing configuration for the VPN, including essential rules
|
||||
// to ensure proper traffic flow for management, locally configured routes, and VPN traffic.
|
||||
//
|
||||
// Rule 1 (Main Route Precedence): Safeguards locally installed routes by giving them precedence over
|
||||
@ -92,17 +85,17 @@ func getSetupRules() []ruleParams {
|
||||
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
|
||||
// This table is where a default route or other specific routes received from the management server are configured,
|
||||
// enabling VPN connectivity.
|
||||
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) {
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ nbnet.AddHookFunc, _ nbnet.RemoveHookFunc, err error) {
|
||||
if isLegacy() {
|
||||
log.Infof("Using legacy routing setup")
|
||||
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
|
||||
return r.setupRefCounter(initAddresses)
|
||||
}
|
||||
|
||||
if err = addRoutingTableName(); err != nil {
|
||||
log.Errorf("Error adding routing table name: %v", err)
|
||||
}
|
||||
|
||||
originalValues, err := setupSysctl(wgIface)
|
||||
originalValues, err := sysctl.Setup(r.wgInterface)
|
||||
if err != nil {
|
||||
log.Errorf("Error setting up sysctl: %v", err)
|
||||
sysctlFailed = true
|
||||
@ -111,7 +104,7 @@ func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.Before
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
if cleanErr := cleanupRouting(); cleanErr != nil {
|
||||
if cleanErr := r.CleanupRouting(); cleanErr != nil {
|
||||
log.Errorf("Error cleaning up routing: %v", cleanErr)
|
||||
}
|
||||
}
|
||||
@ -123,7 +116,7 @@ func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.Before
|
||||
if errors.Is(err, syscall.EOPNOTSUPP) {
|
||||
log.Warnf("Rule operations are not supported, falling back to the legacy routing setup")
|
||||
setIsLegacy(true)
|
||||
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
|
||||
return r.setupRefCounter(initAddresses)
|
||||
}
|
||||
return nil, nil, fmt.Errorf("%s: %w", rule.description, err)
|
||||
}
|
||||
@ -132,12 +125,12 @@ func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.Before
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
// cleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.
|
||||
// CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.
|
||||
// It systematically removes the three rules and any associated routing table entries to ensure a clean state.
|
||||
// The function uses error aggregation to report any errors encountered during the cleanup process.
|
||||
func cleanupRouting() error {
|
||||
func (r *SysOps) CleanupRouting() error {
|
||||
if isLegacy() {
|
||||
return cleanupRoutingWithRouteManager(routeManager)
|
||||
return r.cleanupRefCounter()
|
||||
}
|
||||
|
||||
var result *multierror.Error
|
||||
@ -156,58 +149,58 @@ func cleanupRouting() error {
|
||||
}
|
||||
}
|
||||
|
||||
if err := cleanupSysctl(originalSysctl); err != nil {
|
||||
if err := sysctl.Cleanup(originalSysctl); err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("cleanup sysctl: %w", err))
|
||||
}
|
||||
originalSysctl = nil
|
||||
sysctlFailed = false
|
||||
|
||||
return result.ErrorOrNil()
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
|
||||
return addRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN)
|
||||
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
return addRoute(prefix, nexthop, syscall.RT_TABLE_MAIN)
|
||||
}
|
||||
|
||||
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
|
||||
return removeRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN)
|
||||
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
return removeRoute(prefix, nexthop, syscall.RT_TABLE_MAIN)
|
||||
}
|
||||
|
||||
func addVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
if isLegacy() {
|
||||
return genericAddVPNRoute(prefix, intf)
|
||||
return r.genericAddVPNRoute(prefix, intf)
|
||||
}
|
||||
|
||||
if sysctlFailed && (prefix == defaultv4 || prefix == defaultv6) {
|
||||
if sysctlFailed && (prefix == vars.Defaultv4 || prefix == vars.Defaultv6) {
|
||||
log.Warnf("Default route is configured but sysctl operations failed, VPN traffic may not be routed correctly, consider using NB_USE_LEGACY_ROUTING=true or setting net.ipv4.conf.*.rp_filter to 2 (loose) or 0 (off)")
|
||||
}
|
||||
|
||||
// No need to check if routes exist as main table takes precedence over the VPN table via Rule 1
|
||||
|
||||
// TODO remove this once we have ipv6 support
|
||||
if prefix == defaultv4 {
|
||||
if err := addUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil {
|
||||
if prefix == vars.Defaultv4 {
|
||||
if err := addUnreachableRoute(vars.Defaultv6, NetbirdVPNTableID); err != nil {
|
||||
return fmt.Errorf("add blackhole: %w", err)
|
||||
}
|
||||
}
|
||||
if err := addRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil {
|
||||
if err := addRoute(prefix, Nexthop{netip.Addr{}, intf}, NetbirdVPNTableID); err != nil {
|
||||
return fmt.Errorf("add route: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
if isLegacy() {
|
||||
return genericRemoveVPNRoute(prefix, intf)
|
||||
return r.genericRemoveVPNRoute(prefix, intf)
|
||||
}
|
||||
|
||||
// TODO remove this once we have ipv6 support
|
||||
if prefix == defaultv4 {
|
||||
if err := removeUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil {
|
||||
if prefix == vars.Defaultv4 {
|
||||
if err := removeUnreachableRoute(vars.Defaultv6, NetbirdVPNTableID); err != nil {
|
||||
return fmt.Errorf("remove unreachable route: %w", err)
|
||||
}
|
||||
}
|
||||
if err := removeRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil {
|
||||
if err := removeRoute(prefix, Nexthop{netip.Addr{}, intf}, NetbirdVPNTableID); err != nil {
|
||||
return fmt.Errorf("remove route: %w", err)
|
||||
}
|
||||
return nil
|
||||
@ -255,7 +248,7 @@ func getRoutes(tableID, family int) ([]netip.Prefix, error) {
|
||||
}
|
||||
|
||||
// addRoute adds a route to a specific routing table identified by tableID.
|
||||
func addRoute(prefix netip.Prefix, addr netip.Addr, intf *net.Interface, tableID int) error {
|
||||
func addRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error {
|
||||
route := &netlink.Route{
|
||||
Scope: netlink.SCOPE_UNIVERSE,
|
||||
Table: tableID,
|
||||
@ -268,7 +261,7 @@ func addRoute(prefix netip.Prefix, addr netip.Addr, intf *net.Interface, tableID
|
||||
}
|
||||
route.Dst = ipNet
|
||||
|
||||
if err := addNextHop(addr, intf, route); err != nil {
|
||||
if err := addNextHop(nexthop, route); err != nil {
|
||||
return fmt.Errorf("add gateway and device: %w", err)
|
||||
}
|
||||
|
||||
@ -327,7 +320,7 @@ func removeUnreachableRoute(prefix netip.Prefix, tableID int) error {
|
||||
}
|
||||
|
||||
// removeRoute removes a route from a specific routing table identified by tableID.
|
||||
func removeRoute(prefix netip.Prefix, addr netip.Addr, intf *net.Interface, tableID int) error {
|
||||
func removeRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error {
|
||||
_, ipNet, err := net.ParseCIDR(prefix.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse prefix %s: %w", prefix, err)
|
||||
@ -340,7 +333,7 @@ func removeRoute(prefix netip.Prefix, addr netip.Addr, intf *net.Interface, tabl
|
||||
Dst: ipNet,
|
||||
}
|
||||
|
||||
if err := addNextHop(addr, intf, route); err != nil {
|
||||
if err := addNextHop(nexthop, route); err != nil {
|
||||
return fmt.Errorf("add gateway and device: %w", err)
|
||||
}
|
||||
|
||||
@ -373,11 +366,11 @@ func flushRoutes(tableID, family int) error {
|
||||
}
|
||||
}
|
||||
|
||||
return result.ErrorOrNil()
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
func enableIPForwarding() error {
|
||||
_, err := setSysctl(ipv4ForwardingPath, 1, false)
|
||||
func EnableIPForwarding() error {
|
||||
_, err := sysctl.Set(ipv4ForwardingPath, 1, false)
|
||||
return err
|
||||
}
|
||||
|
||||
@ -481,19 +474,19 @@ func removeRule(params ruleParams) error {
|
||||
}
|
||||
|
||||
// addNextHop adds the gateway and device to the route.
|
||||
func addNextHop(addr netip.Addr, intf *net.Interface, route *netlink.Route) error {
|
||||
if intf != nil {
|
||||
route.LinkIndex = intf.Index
|
||||
func addNextHop(nexthop Nexthop, route *netlink.Route) error {
|
||||
if nexthop.Intf != nil {
|
||||
route.LinkIndex = nexthop.Intf.Index
|
||||
}
|
||||
|
||||
if addr.IsValid() {
|
||||
route.Gw = addr.AsSlice()
|
||||
if nexthop.IP.IsValid() {
|
||||
route.Gw = nexthop.IP.AsSlice()
|
||||
|
||||
// if zone is set, it means the gateway is a link-local address, so we set the link index
|
||||
if addr.Zone() != "" && intf == nil {
|
||||
link, err := netlink.LinkByName(addr.Zone())
|
||||
if nexthop.IP.Zone() != "" && nexthop.Intf == nil {
|
||||
link, err := netlink.LinkByName(nexthop.IP.Zone())
|
||||
if err != nil {
|
||||
return fmt.Errorf("get link by name for zone %s: %w", addr.Zone(), err)
|
||||
return fmt.Errorf("get link by name for zone %s: %w", nexthop.IP.Zone(), err)
|
||||
}
|
||||
route.LinkIndex = link.Attrs().Index
|
||||
}
|
||||
@ -509,82 +502,9 @@ func getAddressFamily(prefix netip.Prefix) int {
|
||||
return netlink.FAMILY_V6
|
||||
}
|
||||
|
||||
// setupSysctl configures sysctl settings for RP filtering and source validation.
|
||||
func setupSysctl(wgIface *iface.WGIface) (map[string]int, error) {
|
||||
keys := map[string]int{}
|
||||
var result *multierror.Error
|
||||
|
||||
oldVal, err := setSysctl(srcValidMarkPath, 1, false)
|
||||
if err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
} else {
|
||||
keys[srcValidMarkPath] = oldVal
|
||||
func hasSeparateRouting() ([]netip.Prefix, error) {
|
||||
if isLegacy() {
|
||||
return getRoutesFromTable()
|
||||
}
|
||||
|
||||
oldVal, err = setSysctl(rpFilterPath, 2, true)
|
||||
if err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
} else {
|
||||
keys[rpFilterPath] = oldVal
|
||||
}
|
||||
|
||||
interfaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("list interfaces: %w", err))
|
||||
}
|
||||
|
||||
for _, intf := range interfaces {
|
||||
if intf.Name == "lo" || wgIface != nil && intf.Name == wgIface.Name() {
|
||||
continue
|
||||
}
|
||||
|
||||
i := fmt.Sprintf(rpFilterInterfacePath, intf.Name)
|
||||
oldVal, err := setSysctl(i, 2, true)
|
||||
if err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
} else {
|
||||
keys[i] = oldVal
|
||||
}
|
||||
}
|
||||
|
||||
return keys, result.ErrorOrNil()
|
||||
}
|
||||
|
||||
// setSysctl sets a sysctl configuration, if onlyIfOne is true it will only set the new value if it's set to 1
|
||||
func setSysctl(key string, desiredValue int, onlyIfOne bool) (int, error) {
|
||||
path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/"))
|
||||
currentValue, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return -1, fmt.Errorf("read sysctl %s: %w", key, err)
|
||||
}
|
||||
|
||||
currentV, err := strconv.Atoi(strings.TrimSpace(string(currentValue)))
|
||||
if err != nil && len(currentValue) > 0 {
|
||||
return -1, fmt.Errorf("convert current desiredValue to int: %w", err)
|
||||
}
|
||||
|
||||
if currentV == desiredValue || onlyIfOne && currentV != 1 {
|
||||
return currentV, nil
|
||||
}
|
||||
|
||||
//nolint:gosec
|
||||
if err := os.WriteFile(path, []byte(strconv.Itoa(desiredValue)), 0644); err != nil {
|
||||
return currentV, fmt.Errorf("write sysctl %s: %w", key, err)
|
||||
}
|
||||
log.Debugf("Set sysctl %s from %d to %d", key, currentV, desiredValue)
|
||||
|
||||
return currentV, nil
|
||||
}
|
||||
|
||||
func cleanupSysctl(originalSettings map[string]int) error {
|
||||
var result *multierror.Error
|
||||
|
||||
for key, value := range originalSettings {
|
||||
_, err := setSysctl(key, value, false)
|
||||
if err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
|
||||
return result.ErrorOrNil()
|
||||
return nil, ErrRoutingIsSeparate
|
||||
}
|
@ -1,6 +1,6 @@
|
||||
//go:build !android
|
||||
|
||||
package routemanager
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"errors"
|
||||
@ -14,6 +14,8 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/vishvananda/netlink"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
)
|
||||
|
||||
var expectedVPNint = "wgtest0"
|
||||
@ -138,7 +140,7 @@ func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, intf string) {
|
||||
if dstIPNet.String() == "0.0.0.0/0" {
|
||||
var err error
|
||||
originalNexthop, originalLinkIndex, err = fetchOriginalGateway(netlink.FAMILY_V4)
|
||||
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
||||
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
|
||||
t.Logf("Failed to fetch original gateway: %v", err)
|
||||
}
|
||||
|
||||
@ -193,7 +195,7 @@ func fetchOriginalGateway(family int) (net.IP, int, error) {
|
||||
}
|
||||
}
|
||||
|
||||
return nil, 0, ErrRouteNotFound
|
||||
return nil, 0, vars.ErrRouteNotFound
|
||||
}
|
||||
|
||||
func setupDummyInterfacesAndRoutes(t *testing.T) {
|
38
client/internal/routemanager/systemops/systemops_mobile.go
Normal file
38
client/internal/routemanager/systemops/systemops_mobile.go
Normal file
@ -0,0 +1,38 @@
|
||||
//go:build ios || android
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
func (r *SysOps) SetupRouting([]net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
func (r *SysOps) CleanupRouting() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SysOps) AddVPNRoute(netip.Prefix, *net.Interface) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SysOps) RemoveVPNRoute(netip.Prefix, *net.Interface) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func EnableIPForwarding() error {
|
||||
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
|
||||
return nil
|
||||
}
|
||||
|
||||
func IsAddrRouted(netip.Addr, []netip.Prefix) (bool, netip.Prefix) {
|
||||
return false, netip.Prefix{}
|
||||
}
|
28
client/internal/routemanager/systemops/systemops_nonlinux.go
Normal file
28
client/internal/routemanager/systemops/systemops_nonlinux.go
Normal file
@ -0,0 +1,28 @@
|
||||
//go:build !linux && !ios
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
return r.genericAddVPNRoute(prefix, intf)
|
||||
}
|
||||
|
||||
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
return r.genericRemoveVPNRoute(prefix, intf)
|
||||
}
|
||||
|
||||
func EnableIPForwarding() error {
|
||||
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
|
||||
return nil
|
||||
}
|
||||
|
||||
func hasSeparateRouting() ([]netip.Prefix, error) {
|
||||
return getRoutesFromTable()
|
||||
}
|
@ -1,6 +1,6 @@
|
||||
//go:build darwin && !ios
|
||||
//go:build (darwin && !ios) || dragonfly || freebsd || netbsd || openbsd
|
||||
|
||||
package routemanager
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@ -13,43 +13,41 @@ import (
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
var routeManager *RouteManager
|
||||
|
||||
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
return r.setupRefCounter(initAddresses)
|
||||
}
|
||||
|
||||
func cleanupRouting() error {
|
||||
return cleanupRoutingWithRouteManager(routeManager)
|
||||
func (r *SysOps) CleanupRouting() error {
|
||||
return r.cleanupRefCounter()
|
||||
}
|
||||
|
||||
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
|
||||
return routeCmd("add", prefix, nexthop, intf)
|
||||
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
return r.routeCmd("add", prefix, nexthop)
|
||||
}
|
||||
|
||||
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
|
||||
return routeCmd("delete", prefix, nexthop, intf)
|
||||
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
return r.routeCmd("delete", prefix, nexthop)
|
||||
}
|
||||
|
||||
func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
|
||||
func (r *SysOps) routeCmd(action string, prefix netip.Prefix, nexthop Nexthop) error {
|
||||
inet := "-inet"
|
||||
network := prefix.String()
|
||||
if prefix.IsSingleIP() {
|
||||
network = prefix.Addr().String()
|
||||
}
|
||||
if prefix.Addr().Is6() {
|
||||
inet = "-inet6"
|
||||
}
|
||||
|
||||
network := prefix.String()
|
||||
if prefix.IsSingleIP() {
|
||||
network = prefix.Addr().String()
|
||||
}
|
||||
|
||||
args := []string{"-n", action, inet, network}
|
||||
if nexthop.IsValid() {
|
||||
args = append(args, nexthop.Unmap().String())
|
||||
} else if intf != nil {
|
||||
args = append(args, "-interface", intf.Name)
|
||||
if nexthop.IP.IsValid() {
|
||||
args = append(args, nexthop.IP.Unmap().String())
|
||||
} else if nexthop.Intf != nil {
|
||||
args = append(args, "-interface", nexthop.Intf.Name)
|
||||
}
|
||||
|
||||
if err := retryRouteCmd(args); err != nil {
|
@ -1,10 +1,11 @@
|
||||
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly
|
||||
|
||||
package routemanager
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@ -85,6 +86,10 @@ var testCases = []testCase{
|
||||
|
||||
func TestRouting(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
// todo resolve test execution on freebsd
|
||||
if runtime.GOOS == "freebsd" {
|
||||
t.Skip("skipping ", tc.name, " on freebsd")
|
||||
}
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
setupTestEnv(t)
|
||||
|
@ -1,6 +1,6 @@
|
||||
//go:build windows
|
||||
|
||||
package routemanager
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@ -17,8 +17,7 @@ import (
|
||||
"github.com/yusufpapurcu/wmi"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
type MSFT_NetRoute struct {
|
||||
@ -57,14 +56,42 @@ var prefixList []netip.Prefix
|
||||
var lastUpdate time.Time
|
||||
var mux = sync.Mutex{}
|
||||
|
||||
var routeManager *RouteManager
|
||||
|
||||
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
return r.setupRefCounter(initAddresses)
|
||||
}
|
||||
|
||||
func cleanupRouting() error {
|
||||
return cleanupRoutingWithRouteManager(routeManager)
|
||||
func (r *SysOps) CleanupRouting() error {
|
||||
return r.cleanupRefCounter()
|
||||
}
|
||||
|
||||
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
if nexthop.IP.Zone() != "" && nexthop.Intf == nil {
|
||||
zone, err := strconv.Atoi(nexthop.IP.Zone())
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid zone: %w", err)
|
||||
}
|
||||
nexthop.Intf = &net.Interface{Index: zone}
|
||||
}
|
||||
|
||||
return addRouteCmd(prefix, nexthop)
|
||||
}
|
||||
|
||||
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
args := []string{"delete", prefix.String()}
|
||||
if nexthop.IP.IsValid() {
|
||||
ip := nexthop.IP.WithZone("")
|
||||
args = append(args, ip.Unmap().String())
|
||||
}
|
||||
|
||||
routeCmd := uspfilter.GetSystem32Command("route")
|
||||
|
||||
out, err := exec.Command(routeCmd, args...).CombinedOutput()
|
||||
log.Tracef("route %s: %s", strings.Join(args, " "), out)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("remove route: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getRoutesFromTable() ([]netip.Prefix, error) {
|
||||
@ -93,7 +120,7 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
|
||||
func GetRoutes() ([]Route, error) {
|
||||
var entries []MSFT_NetRoute
|
||||
|
||||
query := `SELECT DestinationPrefix, NextHop, InterfaceIndex, InterfaceAlias, AddressFamily FROM MSFT_NetRoute`
|
||||
query := `SELECT DestinationPrefix, Nexthop, InterfaceIndex, InterfaceAlias, AddressFamily FROM MSFT_NetRoute`
|
||||
if err := wmi.QueryNamespace(query, &entries, `ROOT\StandardCimv2`); err != nil {
|
||||
return nil, fmt.Errorf("get routes: %w", err)
|
||||
}
|
||||
@ -118,6 +145,10 @@ func GetRoutes() ([]Route, error) {
|
||||
Index: int(entry.InterfaceIndex),
|
||||
Name: entry.InterfaceAlias,
|
||||
}
|
||||
|
||||
if nexthop.Is6() && (nexthop.IsLinkLocalUnicast() || nexthop.IsLinkLocalMulticast()) {
|
||||
nexthop = nexthop.WithZone(strconv.Itoa(int(entry.InterfaceIndex)))
|
||||
}
|
||||
}
|
||||
|
||||
routes = append(routes, Route{
|
||||
@ -157,11 +188,12 @@ func GetNeighbors() ([]Neighbor, error) {
|
||||
return neighbors, nil
|
||||
}
|
||||
|
||||
func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
|
||||
func addRouteCmd(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
args := []string{"add", prefix.String()}
|
||||
|
||||
if nexthop.IsValid() {
|
||||
args = append(args, nexthop.Unmap().String())
|
||||
if nexthop.IP.IsValid() {
|
||||
ip := nexthop.IP.WithZone("")
|
||||
args = append(args, ip.Unmap().String())
|
||||
} else {
|
||||
addr := "0.0.0.0"
|
||||
if prefix.Addr().Is6() {
|
||||
@ -170,8 +202,8 @@ func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) e
|
||||
args = append(args, addr)
|
||||
}
|
||||
|
||||
if intf != nil {
|
||||
args = append(args, "if", strconv.Itoa(intf.Index))
|
||||
if nexthop.Intf != nil {
|
||||
args = append(args, "if", strconv.Itoa(nexthop.Intf.Index))
|
||||
}
|
||||
|
||||
routeCmd := uspfilter.GetSystem32Command("route")
|
||||
@ -185,37 +217,6 @@ func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) e
|
||||
return nil
|
||||
}
|
||||
|
||||
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
|
||||
if nexthop.Zone() != "" && intf == nil {
|
||||
zone, err := strconv.Atoi(nexthop.Zone())
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid zone: %w", err)
|
||||
}
|
||||
intf = &net.Interface{Index: zone}
|
||||
nexthop.WithZone("")
|
||||
}
|
||||
|
||||
return addRouteCmd(prefix, nexthop, intf)
|
||||
}
|
||||
|
||||
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ *net.Interface) error {
|
||||
args := []string{"delete", prefix.String()}
|
||||
if nexthop.IsValid() {
|
||||
nexthop.WithZone("")
|
||||
args = append(args, nexthop.Unmap().String())
|
||||
}
|
||||
|
||||
routeCmd := uspfilter.GetSystem32Command("route")
|
||||
|
||||
out, err := exec.Command(routeCmd, args...).CombinedOutput()
|
||||
log.Tracef("route %s: %s", strings.Join(args, " "), out)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("remove route: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isCacheDisabled() bool {
|
||||
return os.Getenv("NB_DISABLE_ROUTE_CACHE") == "true"
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package routemanager
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"context"
|
||||
@ -29,7 +29,7 @@ type FindNetRouteOutput struct {
|
||||
InterfaceIndex int `json:"InterfaceIndex"`
|
||||
InterfaceAlias string `json:"InterfaceAlias"`
|
||||
AddressFamily int `json:"AddressFamily"`
|
||||
NextHop string `json:"NextHop"`
|
||||
NextHop string `json:"Nexthop"`
|
||||
DestinationPrefix string `json:"DestinationPrefix"`
|
||||
}
|
||||
|
||||
@ -166,7 +166,7 @@ func testRoute(t *testing.T, destination string, dialer dialer) *FindNetRouteOut
|
||||
host, _, err := net.SplitHostPort(destination)
|
||||
require.NoError(t, err)
|
||||
|
||||
script := fmt.Sprintf(`Find-NetRoute -RemoteIPAddress "%s" | Select-Object -Property IPAddress, InterfaceIndex, InterfaceAlias, AddressFamily, NextHop, DestinationPrefix | ConvertTo-Json`, host)
|
||||
script := fmt.Sprintf(`Find-NetRoute -RemoteIPAddress "%s" | Select-Object -Property IPAddress, InterfaceIndex, InterfaceAlias, AddressFamily, Nexthop, DestinationPrefix | ConvertTo-Json`, host)
|
||||
|
||||
out, err := exec.Command("powershell", "-Command", script).Output()
|
||||
require.NoError(t, err, "Failed to execute Find-NetRoute")
|
||||
@ -207,7 +207,7 @@ func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR str
|
||||
}
|
||||
|
||||
func fetchOriginalGateway() (*RouteInfo, error) {
|
||||
cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object NextHop, RouteMetric, InterfaceAlias | ConvertTo-Json")
|
||||
cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object Nexthop, RouteMetric, InterfaceAlias | ConvertTo-Json")
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to execute Get-NetRoute: %w", err)
|
@ -1,33 +0,0 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
func cleanupRouting() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func enableIPForwarding() error {
|
||||
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
|
||||
return nil
|
||||
}
|
||||
|
||||
func addVPNRoute(netip.Prefix, *net.Interface) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeVPNRoute(netip.Prefix, *net.Interface) error {
|
||||
return nil
|
||||
}
|
@ -1,57 +0,0 @@
|
||||
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/net/route"
|
||||
)
|
||||
|
||||
func TestBits(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr route.Addr
|
||||
want int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "IPv4 all ones",
|
||||
addr: &route.Inet4Addr{IP: [4]byte{255, 255, 255, 255}},
|
||||
want: 32,
|
||||
},
|
||||
{
|
||||
name: "IPv4 normal mask",
|
||||
addr: &route.Inet4Addr{IP: [4]byte{255, 255, 255, 0}},
|
||||
want: 24,
|
||||
},
|
||||
{
|
||||
name: "IPv6 all ones",
|
||||
addr: &route.Inet6Addr{IP: [16]byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}},
|
||||
want: 128,
|
||||
},
|
||||
{
|
||||
name: "IPv6 normal mask",
|
||||
addr: &route.Inet6Addr{IP: [16]byte{255, 255, 255, 255, 255, 255, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0}},
|
||||
want: 64,
|
||||
},
|
||||
{
|
||||
name: "Unsupported type",
|
||||
addr: &route.LinkAddr{},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := ones(tt.addr)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -1,33 +0,0 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
func cleanupRouting() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func enableIPForwarding() error {
|
||||
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
|
||||
return nil
|
||||
}
|
||||
|
||||
func addVPNRoute(netip.Prefix, *net.Interface) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeVPNRoute(netip.Prefix, *net.Interface) error {
|
||||
return nil
|
||||
}
|
@ -1,24 +0,0 @@
|
||||
//go:build !linux && !ios
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func enableIPForwarding() error {
|
||||
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
|
||||
return nil
|
||||
}
|
||||
|
||||
func addVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
return genericAddVPNRoute(prefix, intf)
|
||||
}
|
||||
|
||||
func removeVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
return genericRemoveVPNRoute(prefix, intf)
|
||||
}
|
29
client/internal/routemanager/util/ip.go
Normal file
29
client/internal/routemanager/util/ip.go
Normal file
@ -0,0 +1,29 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
// GetPrefixFromIP returns a netip.Prefix from a net.IP address.
|
||||
func GetPrefixFromIP(ip net.IP) (netip.Prefix, error) {
|
||||
addr, ok := netip.AddrFromSlice(ip)
|
||||
if !ok {
|
||||
return netip.Prefix{}, fmt.Errorf("parse IP address: %s", ip)
|
||||
}
|
||||
addr = addr.Unmap()
|
||||
|
||||
var prefixLength int
|
||||
switch {
|
||||
case addr.Is4():
|
||||
prefixLength = 32
|
||||
case addr.Is6():
|
||||
prefixLength = 128
|
||||
default:
|
||||
return netip.Prefix{}, fmt.Errorf("invalid IP address: %s", addr)
|
||||
}
|
||||
|
||||
prefix := netip.PrefixFrom(addr, prefixLength)
|
||||
return prefix, nil
|
||||
}
|
16
client/internal/routemanager/vars/vars.go
Normal file
16
client/internal/routemanager/vars/vars.go
Normal file
@ -0,0 +1,16 @@
|
||||
package vars
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
const MinRangeBits = 7
|
||||
|
||||
var (
|
||||
ErrRouteNotFound = errors.New("route not found")
|
||||
ErrRouteNotAllowed = errors.New("route not allowed")
|
||||
|
||||
Defaultv4 = netip.PrefixFrom(netip.IPv4Unspecified(), 0)
|
||||
Defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
|
||||
)
|
@ -3,11 +3,11 @@ package routeselector
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/client/errors"
|
||||
route "github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
@ -30,10 +30,10 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al
|
||||
rs.selectedRoutes = map[route.NetID]struct{}{}
|
||||
}
|
||||
|
||||
var multiErr *multierror.Error
|
||||
var err *multierror.Error
|
||||
for _, route := range routes {
|
||||
if !slices.Contains(allRoutes, route) {
|
||||
multiErr = multierror.Append(multiErr, fmt.Errorf("route '%s' is not available", route))
|
||||
err = multierror.Append(err, fmt.Errorf("route '%s' is not available", route))
|
||||
continue
|
||||
}
|
||||
|
||||
@ -41,11 +41,7 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al
|
||||
}
|
||||
rs.selectAll = false
|
||||
|
||||
if multiErr != nil {
|
||||
multiErr.ErrorFormat = formatError
|
||||
}
|
||||
|
||||
return multiErr.ErrorOrNil()
|
||||
return errors.FormatErrorOrNil(err)
|
||||
}
|
||||
|
||||
// SelectAllRoutes sets the selector to select all routes.
|
||||
@ -65,21 +61,17 @@ func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.
|
||||
}
|
||||
}
|
||||
|
||||
var multiErr *multierror.Error
|
||||
var err *multierror.Error
|
||||
|
||||
for _, route := range routes {
|
||||
if !slices.Contains(allRoutes, route) {
|
||||
multiErr = multierror.Append(multiErr, fmt.Errorf("route '%s' is not available", route))
|
||||
err = multierror.Append(err, fmt.Errorf("route '%s' is not available", route))
|
||||
continue
|
||||
}
|
||||
delete(rs.selectedRoutes, route)
|
||||
}
|
||||
|
||||
if multiErr != nil {
|
||||
multiErr.ErrorFormat = formatError
|
||||
}
|
||||
|
||||
return multiErr.ErrorOrNil()
|
||||
return errors.FormatErrorOrNil(err)
|
||||
}
|
||||
|
||||
// DeselectAllRoutes deselects all routes, effectively disabling route selection.
|
||||
@ -111,18 +103,3 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func formatError(es []error) string {
|
||||
if len(es) == 1 {
|
||||
return fmt.Sprintf("1 error occurred:\n\t* %s", es[0])
|
||||
}
|
||||
|
||||
points := make([]string, len(es))
|
||||
for i, err := range es {
|
||||
points[i] = fmt.Sprintf("* %s", err)
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"%d errors occurred:\n\t%s",
|
||||
len(es), strings.Join(points, "\n\t"))
|
||||
}
|
||||
|
@ -261,15 +261,15 @@ func TestRouteSelector_FilterSelected(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
routes := route.HAMap{
|
||||
"route1-10.0.0.0/8": {},
|
||||
"route2-192.168.0.0/16": {},
|
||||
"route3-172.16.0.0/12": {},
|
||||
"route1|10.0.0.0/8": {},
|
||||
"route2|192.168.0.0/16": {},
|
||||
"route3|172.16.0.0/12": {},
|
||||
}
|
||||
|
||||
filtered := rs.FilterSelected(routes)
|
||||
|
||||
assert.Equal(t, route.HAMap{
|
||||
"route1-10.0.0.0/8": {},
|
||||
"route2-192.168.0.0/16": {},
|
||||
"route1|10.0.0.0/8": {},
|
||||
"route2|192.168.0.0/16": {},
|
||||
}, filtered)
|
||||
}
|
||||
|
@ -4,21 +4,24 @@ package wgproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func NewFactory(ctx context.Context, wgPort int) *Factory {
|
||||
func NewFactory(ctx context.Context, userspace bool, wgPort int) *Factory {
|
||||
f := &Factory{wgPort: wgPort}
|
||||
// todo: put it back
|
||||
/*
|
||||
ebpfProxy := NewWGEBPFProxy(ctx, wgPort)
|
||||
err := ebpfProxy.listen()
|
||||
if err != nil {
|
||||
log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err)
|
||||
return f
|
||||
}
|
||||
|
||||
f.ebpfProxy = ebpfProxy
|
||||
if userspace {
|
||||
return f
|
||||
}
|
||||
|
||||
*/
|
||||
ebpfProxy := NewWGEBPFProxy(ctx, wgPort)
|
||||
err := ebpfProxy.listen()
|
||||
if err != nil {
|
||||
log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err)
|
||||
return f
|
||||
}
|
||||
|
||||
f.ebpfProxy = ebpfProxy
|
||||
return f
|
||||
}
|
||||
|
@ -4,6 +4,6 @@ package wgproxy
|
||||
|
||||
import "context"
|
||||
|
||||
func NewFactory(ctx context.Context, wgPort int) *Factory {
|
||||
func NewFactory(ctx context.Context, _ bool, wgPort int) *Factory {
|
||||
return &Factory{wgPort: wgPort}
|
||||
}
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -92,6 +92,8 @@ message LoginRequest {
|
||||
repeated string extraIFaceBlacklist = 17;
|
||||
|
||||
optional bool networkMonitor = 18;
|
||||
|
||||
optional google.protobuf.Duration dnsRouteInterval = 19;
|
||||
}
|
||||
|
||||
message LoginResponse {
|
||||
@ -145,6 +147,18 @@ message GetConfigResponse {
|
||||
|
||||
// adminURL settings value.
|
||||
string adminURL = 5;
|
||||
|
||||
string interfaceName = 6;
|
||||
|
||||
int64 wireguardPort = 7;
|
||||
|
||||
bool disableAutoConnect = 9;
|
||||
|
||||
bool serverSSHAllowed = 10;
|
||||
|
||||
bool rosenpassEnabled = 11;
|
||||
|
||||
bool rosenpassPermissive = 12;
|
||||
}
|
||||
|
||||
// PeerState contains the latest state of a peer
|
||||
@ -233,10 +247,17 @@ message SelectRoutesRequest {
|
||||
message SelectRoutesResponse {
|
||||
}
|
||||
|
||||
message IPList {
|
||||
repeated string ips = 1;
|
||||
}
|
||||
|
||||
|
||||
message Route {
|
||||
string ID = 1;
|
||||
string network = 2;
|
||||
bool selected = 3;
|
||||
repeated string domains = 4;
|
||||
map<string, IPList> resolvedIPs = 5;
|
||||
}
|
||||
|
||||
message DebugBundleRequest {
|
||||
|
@ -9,17 +9,19 @@ import (
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type selectRoute struct {
|
||||
NetID route.NetID
|
||||
Network netip.Prefix
|
||||
Domains domain.List
|
||||
Selected bool
|
||||
}
|
||||
|
||||
// ListRoutes returns a list of all available routes.
|
||||
func (s *Server) ListRoutes(ctx context.Context, req *proto.ListRoutesRequest) (*proto.ListRoutesResponse, error) {
|
||||
func (s *Server) ListRoutes(context.Context, *proto.ListRoutesRequest) (*proto.ListRoutesResponse, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
@ -43,6 +45,7 @@ func (s *Server) ListRoutes(ctx context.Context, req *proto.ListRoutesRequest) (
|
||||
route := &selectRoute{
|
||||
NetID: id,
|
||||
Network: rt[0].Network,
|
||||
Domains: rt[0].Domains,
|
||||
Selected: routeSelector.IsSelected(id),
|
||||
}
|
||||
routes = append(routes, route)
|
||||
@ -63,13 +66,29 @@ func (s *Server) ListRoutes(ctx context.Context, req *proto.ListRoutesRequest) (
|
||||
return iPrefix < jPrefix
|
||||
})
|
||||
|
||||
resolvedDomains := s.statusRecorder.GetResolvedDomainsStates()
|
||||
var pbRoutes []*proto.Route
|
||||
for _, route := range routes {
|
||||
pbRoutes = append(pbRoutes, &proto.Route{
|
||||
ID: string(route.NetID),
|
||||
Network: route.Network.String(),
|
||||
Selected: route.Selected,
|
||||
})
|
||||
pbRoute := &proto.Route{
|
||||
ID: string(route.NetID),
|
||||
Network: route.Network.String(),
|
||||
Domains: route.Domains.ToSafeStringList(),
|
||||
ResolvedIPs: map[string]*proto.IPList{},
|
||||
Selected: route.Selected,
|
||||
}
|
||||
|
||||
for _, domain := range route.Domains {
|
||||
if prefixes, exists := resolvedDomains[domain]; exists {
|
||||
var ipStrings []string
|
||||
for _, prefix := range prefixes {
|
||||
ipStrings = append(ipStrings, prefix.Addr().String())
|
||||
}
|
||||
pbRoute.ResolvedIPs[string(domain)] = &proto.IPList{
|
||||
Ips: ipStrings,
|
||||
}
|
||||
}
|
||||
}
|
||||
pbRoutes = append(pbRoutes, pbRoute)
|
||||
}
|
||||
|
||||
return &proto.ListRoutesResponse{
|
||||
|
@ -365,6 +365,12 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
||||
s.latestConfigInput.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist
|
||||
}
|
||||
|
||||
if msg.DnsRouteInterval != nil {
|
||||
duration := msg.DnsRouteInterval.AsDuration()
|
||||
inputConfig.DNSRouteInterval = &duration
|
||||
s.latestConfigInput.DNSRouteInterval = &duration
|
||||
}
|
||||
|
||||
s.mutex.Unlock()
|
||||
|
||||
if msg.OptionalPreSharedKey != nil {
|
||||
@ -662,11 +668,17 @@ func (s *Server) GetConfig(_ context.Context, _ *proto.GetConfigRequest) (*proto
|
||||
}
|
||||
|
||||
return &proto.GetConfigResponse{
|
||||
ManagementUrl: managementURL,
|
||||
AdminURL: adminURL,
|
||||
ConfigFile: s.latestConfigInput.ConfigPath,
|
||||
LogFile: s.logFile,
|
||||
PreSharedKey: preSharedKey,
|
||||
ManagementUrl: managementURL,
|
||||
ConfigFile: s.latestConfigInput.ConfigPath,
|
||||
LogFile: s.logFile,
|
||||
PreSharedKey: preSharedKey,
|
||||
AdminURL: adminURL,
|
||||
InterfaceName: s.config.WgIface,
|
||||
WireguardPort: int64(s.config.WgPort),
|
||||
DisableAutoConnect: s.config.DisableAutoConnect,
|
||||
ServerSSHAllowed: *s.config.ServerSSHAllowed,
|
||||
RosenpassEnabled: s.config.RosenpassEnabled,
|
||||
RosenpassPermissive: s.config.RosenpassPermissive,
|
||||
}, nil
|
||||
}
|
||||
func (s *Server) onSessionExpire() {
|
||||
|
@ -7,6 +7,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
@ -39,7 +41,7 @@ var (
|
||||
// we will use a management server started via to simulate the server and capture the number of retries
|
||||
func TestConnectWithRetryRuns(t *testing.T) {
|
||||
// start the signal server
|
||||
_, signalAddr, err := startSignal()
|
||||
_, signalAddr, err := startSignal(t)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start signal server: %v", err)
|
||||
}
|
||||
@ -106,7 +108,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
return nil, "", err
|
||||
}
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
store, cleanUp, err := server.NewTestStoreFromJson(config.Datadir)
|
||||
store, cleanUp, err := server.NewTestStoreFromJson(context.Background(), config.Datadir)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
@ -117,13 +119,13 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
ia, _ := integrations.NewIntegratedValidator(eventStore)
|
||||
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
|
||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, "")
|
||||
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
@ -141,7 +143,9 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
return s, lis.Addr().String(), nil
|
||||
}
|
||||
|
||||
func startSignal() (*grpc.Server, string, error) {
|
||||
func startSignal(t *testing.T) (*grpc.Server, string, error) {
|
||||
t.Helper()
|
||||
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
@ -149,7 +153,9 @@ func startSignal() (*grpc.Server, string, error) {
|
||||
log.Fatalf("failed to listen: %v", err)
|
||||
}
|
||||
|
||||
proto.RegisterSignalExchangeServer(s, signalServer.NewServer())
|
||||
srv, err := signalServer.NewServer(otel.Meter(""))
|
||||
require.NoError(t, err)
|
||||
proto.RegisterSignalExchangeServer(s, srv)
|
||||
|
||||
go func() {
|
||||
if err = s.Serve(lis); err != nil {
|
||||
|
10
client/ssh/window_freebsd.go
Normal file
10
client/ssh/window_freebsd.go
Normal file
@ -0,0 +1,10 @@
|
||||
//go:build freebsd
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"os"
|
||||
)
|
||||
|
||||
func setWinSize(file *os.File, width, height int) {
|
||||
}
|
@ -8,6 +8,7 @@ import (
|
||||
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
@ -33,6 +34,12 @@ type Environment struct {
|
||||
Platform string
|
||||
}
|
||||
|
||||
type File struct {
|
||||
Path string
|
||||
Exist bool
|
||||
ProcessIsRunning bool
|
||||
}
|
||||
|
||||
// Info is an object that contains machine information
|
||||
// Most of the code is taken from https://github.com/matishsiao/goInfo
|
||||
type Info struct {
|
||||
@ -51,6 +58,7 @@ type Info struct {
|
||||
SystemProductName string
|
||||
SystemManufacturer string
|
||||
Environment Environment
|
||||
Files []File // for posture checks
|
||||
}
|
||||
|
||||
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
|
||||
@ -132,3 +140,21 @@ func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetInfoWithChecks retrieves and parses the system information with applied checks.
|
||||
func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) {
|
||||
processCheckPaths := make([]string, 0)
|
||||
for _, check := range checks {
|
||||
processCheckPaths = append(processCheckPaths, check.GetFiles()...)
|
||||
}
|
||||
|
||||
files, err := checkFileAndProcess(processCheckPaths)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
info := GetInfo(ctx)
|
||||
info.Files = files
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
@ -32,7 +32,7 @@ func GetInfo(ctx context.Context) *Info {
|
||||
GoOS: runtime.GOOS,
|
||||
Kernel: kernel,
|
||||
Platform: "unknown",
|
||||
OS: "android",
|
||||
OS: "Android",
|
||||
OSVersion: osVersion(),
|
||||
Hostname: extractDeviceName(ctx, "android"),
|
||||
CPUs: runtime.NumCPU(),
|
||||
@ -44,6 +44,11 @@ func GetInfo(ctx context.Context) *Info {
|
||||
return gio
|
||||
}
|
||||
|
||||
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
|
||||
func checkFileAndProcess(paths []string) ([]File, error) {
|
||||
return []File{}, nil
|
||||
}
|
||||
|
||||
func uname() []string {
|
||||
res := run("/system/bin/uname", "-a")
|
||||
return strings.Split(res, " ")
|
||||
@ -72,5 +77,6 @@ func run(name string, arg ...string) string {
|
||||
if err != nil {
|
||||
log.Errorf("getInfo: %s", err)
|
||||
}
|
||||
return out.String()
|
||||
|
||||
return strings.TrimSpace(out.String())
|
||||
}
|
||||
|
@ -1,15 +1,18 @@
|
||||
//go:build freebsd
|
||||
|
||||
package system
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/system/detect_cloud"
|
||||
"github.com/netbirdio/netbird/client/system/detect_platform"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
@ -22,8 +25,8 @@ func GetInfo(ctx context.Context) *Info {
|
||||
out = _getInfo()
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
osStr := strings.Replace(out, "\n", "", -1)
|
||||
osStr = strings.Replace(osStr, "\r\n", "", -1)
|
||||
osStr := strings.ReplaceAll(out, "\n", "")
|
||||
osStr = strings.ReplaceAll(osStr, "\r\n", "")
|
||||
osInfo := strings.Split(osStr, " ")
|
||||
|
||||
env := Environment{
|
||||
@ -31,14 +34,23 @@ func GetInfo(ctx context.Context) *Info {
|
||||
Platform: detect_platform.Detect(ctx),
|
||||
}
|
||||
|
||||
gio := &Info{Kernel: osInfo[0], Platform: runtime.GOARCH, OS: osInfo[2], GoOS: runtime.GOOS, CPUs: runtime.NumCPU(), KernelVersion: osInfo[1], Environment: env}
|
||||
osName, osVersion := readOsReleaseFile()
|
||||
|
||||
systemHostname, _ := os.Hostname()
|
||||
gio.Hostname = extractDeviceName(ctx, systemHostname)
|
||||
gio.WiretrusteeVersion = version.NetbirdVersion()
|
||||
gio.UIVersion = extractUserAgent(ctx)
|
||||
|
||||
return gio
|
||||
return &Info{
|
||||
GoOS: runtime.GOOS,
|
||||
Kernel: osInfo[0],
|
||||
Platform: runtime.GOARCH,
|
||||
OS: osName,
|
||||
OSVersion: osVersion,
|
||||
Hostname: extractDeviceName(ctx, systemHostname),
|
||||
CPUs: runtime.NumCPU(),
|
||||
WiretrusteeVersion: version.NetbirdVersion(),
|
||||
UIVersion: extractUserAgent(ctx),
|
||||
KernelVersion: osInfo[1],
|
||||
Environment: env,
|
||||
}
|
||||
}
|
||||
|
||||
func _getInfo() string {
|
||||
@ -50,7 +62,8 @@ func _getInfo() string {
|
||||
cmd.Stderr = &stderr
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
fmt.Println("getInfo:", err)
|
||||
log.Warnf("getInfo: %s", err)
|
||||
}
|
||||
|
||||
return out.String()
|
||||
}
|
||||
|
@ -25,6 +25,11 @@ func GetInfo(ctx context.Context) *Info {
|
||||
return gio
|
||||
}
|
||||
|
||||
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
|
||||
func checkFileAndProcess(paths []string) ([]File, error) {
|
||||
return []File{}, nil
|
||||
}
|
||||
|
||||
// extractOsVersion extracts operating system version from context or returns the default
|
||||
func extractOsVersion(ctx context.Context, defaultName string) string {
|
||||
v, ok := ctx.Value(OsVersionCtxKey).(string)
|
||||
|
@ -28,28 +28,11 @@ func GetInfo(ctx context.Context) *Info {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
|
||||
releaseInfo := _getReleaseInfo()
|
||||
for strings.Contains(info, "broken pipe") {
|
||||
releaseInfo = _getReleaseInfo()
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
|
||||
osRelease := strings.Split(releaseInfo, "\n")
|
||||
var osName string
|
||||
var osVer string
|
||||
for _, s := range osRelease {
|
||||
if strings.HasPrefix(s, "NAME=") {
|
||||
osName = strings.Split(s, "=")[1]
|
||||
osName = strings.ReplaceAll(osName, "\"", "")
|
||||
} else if strings.HasPrefix(s, "VERSION_ID=") {
|
||||
osVer = strings.Split(s, "=")[1]
|
||||
osVer = strings.ReplaceAll(osVer, "\"", "")
|
||||
}
|
||||
}
|
||||
|
||||
osStr := strings.ReplaceAll(info, "\n", "")
|
||||
osStr = strings.ReplaceAll(osStr, "\r\n", "")
|
||||
osInfo := strings.Split(osStr, " ")
|
||||
|
||||
osName, osVersion := readOsReleaseFile()
|
||||
if osName == "" {
|
||||
osName = osInfo[3]
|
||||
}
|
||||
@ -72,7 +55,7 @@ func GetInfo(ctx context.Context) *Info {
|
||||
Kernel: osInfo[0],
|
||||
Platform: osInfo[2],
|
||||
OS: osName,
|
||||
OSVersion: osVer,
|
||||
OSVersion: osVersion,
|
||||
Hostname: extractDeviceName(ctx, systemHostname),
|
||||
GoOS: runtime.GOOS,
|
||||
CPUs: runtime.NumCPU(),
|
||||
@ -103,22 +86,12 @@ func _getInfo() string {
|
||||
return out.String()
|
||||
}
|
||||
|
||||
func _getReleaseInfo() string {
|
||||
cmd := exec.Command("cat", "/etc/os-release")
|
||||
cmd.Stdin = strings.NewReader("some")
|
||||
var out bytes.Buffer
|
||||
var stderr bytes.Buffer
|
||||
cmd.Stdout = &out
|
||||
cmd.Stderr = &stderr
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
log.Warnf("geucwReleaseInfo: %s", err)
|
||||
}
|
||||
return out.String()
|
||||
}
|
||||
|
||||
func sysInfo() (serialNumber string, productName string, manufacturer string) {
|
||||
var si sysinfo.SysInfo
|
||||
si.GetSysInfo()
|
||||
return si.Chassis.Serial, si.Product.Name, si.Product.Vendor
|
||||
serial := si.Chassis.Serial
|
||||
if (serial == "Default string" || serial == "") && si.Product.Serial != "" {
|
||||
serial = si.Product.Serial
|
||||
}
|
||||
return serial, si.Product.Name, si.Product.Vendor
|
||||
}
|
||||
|
38
client/system/osrelease_unix.go
Normal file
38
client/system/osrelease_unix.go
Normal file
@ -0,0 +1,38 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package system
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func readOsReleaseFile() (osName string, osVer string) {
|
||||
file, err := os.Open("/etc/os-release")
|
||||
if err != nil {
|
||||
log.Warnf("failed to open file /etc/os-release: %s", err)
|
||||
return "", ""
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.HasPrefix(line, "NAME=") {
|
||||
osName = strings.ReplaceAll(strings.Split(line, "=")[1], "\"", "")
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(line, "VERSION_ID=") {
|
||||
osVer = strings.ReplaceAll(strings.Split(line, "=")[1], "\"", "")
|
||||
continue
|
||||
}
|
||||
|
||||
if osName != "" && osVer != "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
58
client/system/process.go
Normal file
58
client/system/process.go
Normal file
@ -0,0 +1,58 @@
|
||||
//go:build windows || (linux && !android) || (darwin && !ios) || freebsd
|
||||
|
||||
package system
|
||||
|
||||
import (
|
||||
"os"
|
||||
"slices"
|
||||
|
||||
"github.com/shirou/gopsutil/v3/process"
|
||||
)
|
||||
|
||||
// getRunningProcesses returns a list of running process paths.
|
||||
func getRunningProcesses() ([]string, error) {
|
||||
processes, err := process.Processes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
processMap := make(map[string]bool)
|
||||
for _, p := range processes {
|
||||
path, _ := p.Exe()
|
||||
if path != "" {
|
||||
processMap[path] = true
|
||||
}
|
||||
}
|
||||
|
||||
uniqueProcesses := make([]string, 0, len(processMap))
|
||||
for p := range processMap {
|
||||
uniqueProcesses = append(uniqueProcesses, p)
|
||||
}
|
||||
|
||||
return uniqueProcesses, nil
|
||||
}
|
||||
|
||||
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
|
||||
func checkFileAndProcess(paths []string) ([]File, error) {
|
||||
files := make([]File, len(paths))
|
||||
if len(paths) == 0 {
|
||||
return files, nil
|
||||
}
|
||||
|
||||
runningProcesses, err := getRunningProcesses()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i, path := range paths {
|
||||
file := File{Path: path}
|
||||
|
||||
_, err := os.Stat(path)
|
||||
file.Exist = !os.IsNotExist(err)
|
||||
|
||||
file.ProcessIsRunning = slices.Contains(runningProcesses, path)
|
||||
files[i] = file
|
||||
}
|
||||
|
||||
return files, nil
|
||||
}
|
@ -1,10 +1,11 @@
|
||||
//go:build !(linux && 386)
|
||||
//go:build !(linux && 386) && !freebsd
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
_ "embed"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
@ -79,6 +80,7 @@ func main() {
|
||||
log.Errorf("check PID file: %v", err)
|
||||
return
|
||||
}
|
||||
client.setDefaultFonts()
|
||||
systray.Run(client.onTrayReady, client.onTrayExit)
|
||||
}
|
||||
}
|
||||
@ -125,44 +127,55 @@ type serviceClient struct {
|
||||
icUpdateCloud []byte
|
||||
|
||||
// systray menu items
|
||||
mStatus *systray.MenuItem
|
||||
mUp *systray.MenuItem
|
||||
mDown *systray.MenuItem
|
||||
mAdminPanel *systray.MenuItem
|
||||
mSettings *systray.MenuItem
|
||||
mAbout *systray.MenuItem
|
||||
mVersionUI *systray.MenuItem
|
||||
mVersionDaemon *systray.MenuItem
|
||||
mUpdate *systray.MenuItem
|
||||
mQuit *systray.MenuItem
|
||||
mRoutes *systray.MenuItem
|
||||
mStatus *systray.MenuItem
|
||||
mUp *systray.MenuItem
|
||||
mDown *systray.MenuItem
|
||||
mAdminPanel *systray.MenuItem
|
||||
mSettings *systray.MenuItem
|
||||
mAbout *systray.MenuItem
|
||||
mVersionUI *systray.MenuItem
|
||||
mVersionDaemon *systray.MenuItem
|
||||
mUpdate *systray.MenuItem
|
||||
mQuit *systray.MenuItem
|
||||
mRoutes *systray.MenuItem
|
||||
mAllowSSH *systray.MenuItem
|
||||
mAutoConnect *systray.MenuItem
|
||||
mEnableRosenpass *systray.MenuItem
|
||||
mAdvancedSettings *systray.MenuItem
|
||||
|
||||
// application with main windows.
|
||||
app fyne.App
|
||||
wSettings fyne.Window
|
||||
showSettings bool
|
||||
sendNotification bool
|
||||
app fyne.App
|
||||
wSettings fyne.Window
|
||||
showAdvancedSettings bool
|
||||
sendNotification bool
|
||||
|
||||
// input elements for settings form
|
||||
iMngURL *widget.Entry
|
||||
iAdminURL *widget.Entry
|
||||
iConfigFile *widget.Entry
|
||||
iLogFile *widget.Entry
|
||||
iPreSharedKey *widget.Entry
|
||||
iMngURL *widget.Entry
|
||||
iAdminURL *widget.Entry
|
||||
iConfigFile *widget.Entry
|
||||
iLogFile *widget.Entry
|
||||
iPreSharedKey *widget.Entry
|
||||
iInterfaceName *widget.Entry
|
||||
iInterfacePort *widget.Entry
|
||||
|
||||
// switch elements for settings form
|
||||
sRosenpassPermissive *widget.Check
|
||||
|
||||
// observable settings over corresponding iMngURL and iPreSharedKey values.
|
||||
managementURL string
|
||||
preSharedKey string
|
||||
adminURL string
|
||||
managementURL string
|
||||
preSharedKey string
|
||||
adminURL string
|
||||
RosenpassPermissive bool
|
||||
interfaceName string
|
||||
interfacePort int
|
||||
|
||||
connected bool
|
||||
update *version.Update
|
||||
daemonVersion string
|
||||
updateIndicationLock sync.Mutex
|
||||
isUpdateIconActive bool
|
||||
|
||||
showRoutes bool
|
||||
wRoutes fyne.Window
|
||||
showRoutes bool
|
||||
wRoutes fyne.Window
|
||||
}
|
||||
|
||||
// newServiceClient instance constructor
|
||||
@ -175,9 +188,9 @@ func newServiceClient(addr string, a fyne.App, showSettings bool, showRoutes boo
|
||||
app: a,
|
||||
sendNotification: false,
|
||||
|
||||
showSettings: showSettings,
|
||||
showRoutes: showRoutes,
|
||||
update: version.NewUpdate(),
|
||||
showAdvancedSettings: showSettings,
|
||||
showRoutes: showRoutes,
|
||||
update: version.NewUpdate(),
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
@ -215,8 +228,13 @@ func (s *serviceClient) showSettingsUI() {
|
||||
s.iLogFile = widget.NewEntry()
|
||||
s.iLogFile.Disable()
|
||||
s.iPreSharedKey = widget.NewPasswordEntry()
|
||||
s.iInterfaceName = widget.NewEntry()
|
||||
s.iInterfacePort = widget.NewEntry()
|
||||
s.sRosenpassPermissive = widget.NewCheck("Enable Rosenpass permissive mode", nil)
|
||||
|
||||
s.wSettings.SetContent(s.getSettingsForm())
|
||||
s.wSettings.Resize(fyne.NewSize(600, 100))
|
||||
s.wSettings.Resize(fyne.NewSize(600, 400))
|
||||
s.wSettings.SetFixedSize(true)
|
||||
|
||||
s.getSrvConfig()
|
||||
|
||||
@ -239,6 +257,9 @@ func showErrorMSG(msg string) {
|
||||
func (s *serviceClient) getSettingsForm() *widget.Form {
|
||||
return &widget.Form{
|
||||
Items: []*widget.FormItem{
|
||||
{Text: "Quantum-Resistance", Widget: s.sRosenpassPermissive},
|
||||
{Text: "Interface Name", Widget: s.iInterfaceName},
|
||||
{Text: "Interface Port", Widget: s.iInterfacePort},
|
||||
{Text: "Management URL", Widget: s.iMngURL},
|
||||
{Text: "Admin URL", Widget: s.iAdminURL},
|
||||
{Text: "Pre-shared Key", Widget: s.iPreSharedKey},
|
||||
@ -255,45 +276,45 @@ func (s *serviceClient) getSettingsForm() *widget.Form {
|
||||
}
|
||||
}
|
||||
|
||||
port, err := strconv.ParseInt(s.iInterfacePort.Text, 10, 64)
|
||||
if err != nil {
|
||||
dialog.ShowError(errors.New("Invalid interface port"), s.wSettings)
|
||||
return
|
||||
}
|
||||
|
||||
iAdminURL := strings.TrimSpace(s.iAdminURL.Text)
|
||||
iMngURL := strings.TrimSpace(s.iMngURL.Text)
|
||||
|
||||
defer s.wSettings.Close()
|
||||
// if management URL or Pre-shared key changed, we try to re-login with new settings.
|
||||
if s.managementURL != s.iMngURL.Text || s.preSharedKey != s.iPreSharedKey.Text ||
|
||||
s.adminURL != s.iAdminURL.Text {
|
||||
|
||||
s.managementURL = s.iMngURL.Text
|
||||
// If the management URL, pre-shared key, admin URL, Rosenpass permissive mode,
|
||||
// interface name, or interface port have changed, we attempt to re-login with the new settings.
|
||||
if s.managementURL != iMngURL || s.preSharedKey != s.iPreSharedKey.Text ||
|
||||
s.adminURL != iAdminURL || s.RosenpassPermissive != s.sRosenpassPermissive.Checked ||
|
||||
s.interfaceName != s.iInterfaceName.Text || s.interfacePort != int(port) {
|
||||
|
||||
s.managementURL = iMngURL
|
||||
s.preSharedKey = s.iPreSharedKey.Text
|
||||
s.adminURL = s.iAdminURL.Text
|
||||
|
||||
client, err := s.getSrvClient(failFastTimeout)
|
||||
if err != nil {
|
||||
log.Errorf("get daemon client: %v", err)
|
||||
return
|
||||
}
|
||||
s.adminURL = iAdminURL
|
||||
|
||||
loginRequest := proto.LoginRequest{
|
||||
ManagementUrl: s.iMngURL.Text,
|
||||
AdminURL: s.iAdminURL.Text,
|
||||
ManagementUrl: iMngURL,
|
||||
AdminURL: iAdminURL,
|
||||
IsLinuxDesktopClient: runtime.GOOS == "linux",
|
||||
RosenpassPermissive: &s.sRosenpassPermissive.Checked,
|
||||
InterfaceName: &s.iInterfaceName.Text,
|
||||
WireguardPort: &port,
|
||||
}
|
||||
|
||||
if s.iPreSharedKey.Text != "**********" {
|
||||
loginRequest.OptionalPreSharedKey = &s.iPreSharedKey.Text
|
||||
}
|
||||
|
||||
_, err = client.Login(s.ctx, &loginRequest)
|
||||
if err != nil {
|
||||
log.Errorf("login to management URL: %v", err)
|
||||
if err := s.restartClient(&loginRequest); err != nil {
|
||||
log.Errorf("restarting client connection: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = client.Up(s.ctx, &proto.UpRequest{})
|
||||
if err != nil {
|
||||
log.Errorf("login to management URL: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
s.wSettings.Close()
|
||||
},
|
||||
OnCancel: func() {
|
||||
s.wSettings.Close()
|
||||
@ -499,7 +520,14 @@ func (s *serviceClient) onTrayReady() {
|
||||
s.mDown.Disable()
|
||||
s.mAdminPanel = systray.AddMenuItem("Admin Panel", "Netbird Admin Panel")
|
||||
systray.AddSeparator()
|
||||
|
||||
s.mSettings = systray.AddMenuItem("Settings", "Settings of the application")
|
||||
s.mAllowSSH = s.mSettings.AddSubMenuItemCheckbox("Allow SSH", "Allow SSH connections", false)
|
||||
s.mAutoConnect = s.mSettings.AddSubMenuItemCheckbox("Connect on Startup", "Connect automatically when the service starts", false)
|
||||
s.mEnableRosenpass = s.mSettings.AddSubMenuItemCheckbox("Enable Quantum-Resistance", "Enable post-quantum security via Rosenpass", false)
|
||||
s.mAdvancedSettings = s.mSettings.AddSubMenuItem("Advanced Settings", "Advanced settings of the application")
|
||||
s.loadSettings()
|
||||
|
||||
s.mRoutes = systray.AddMenuItem("Network Routes", "Open the routes management window")
|
||||
s.mRoutes.Disable()
|
||||
systray.AddSeparator()
|
||||
@ -539,7 +567,7 @@ func (s *serviceClient) onTrayReady() {
|
||||
case <-s.mAdminPanel.ClickedCh:
|
||||
err = open.Run(s.adminURL)
|
||||
case <-s.mUp.ClickedCh:
|
||||
s.mUp.Disabled()
|
||||
s.mUp.Disable()
|
||||
go func() {
|
||||
defer s.mUp.Enable()
|
||||
err := s.menuUpClick()
|
||||
@ -558,10 +586,40 @@ func (s *serviceClient) onTrayReady() {
|
||||
return
|
||||
}
|
||||
}()
|
||||
case <-s.mSettings.ClickedCh:
|
||||
s.mSettings.Disable()
|
||||
case <-s.mAllowSSH.ClickedCh:
|
||||
if s.mAllowSSH.Checked() {
|
||||
s.mAllowSSH.Uncheck()
|
||||
} else {
|
||||
s.mAllowSSH.Check()
|
||||
}
|
||||
if err := s.updateConfig(); err != nil {
|
||||
log.Errorf("failed to update config: %v", err)
|
||||
return
|
||||
}
|
||||
case <-s.mAutoConnect.ClickedCh:
|
||||
if s.mAutoConnect.Checked() {
|
||||
s.mAutoConnect.Uncheck()
|
||||
} else {
|
||||
s.mAutoConnect.Check()
|
||||
}
|
||||
if err := s.updateConfig(); err != nil {
|
||||
log.Errorf("failed to update config: %v", err)
|
||||
return
|
||||
}
|
||||
case <-s.mEnableRosenpass.ClickedCh:
|
||||
if s.mEnableRosenpass.Checked() {
|
||||
s.mEnableRosenpass.Uncheck()
|
||||
} else {
|
||||
s.mEnableRosenpass.Check()
|
||||
}
|
||||
if err := s.updateConfig(); err != nil {
|
||||
log.Errorf("failed to update config: %v", err)
|
||||
return
|
||||
}
|
||||
case <-s.mAdvancedSettings.ClickedCh:
|
||||
s.mAdvancedSettings.Disable()
|
||||
go func() {
|
||||
defer s.mSettings.Enable()
|
||||
defer s.mAdvancedSettings.Enable()
|
||||
defer s.getSrvConfig()
|
||||
s.runSelfCommand("settings", "true")
|
||||
}()
|
||||
@ -663,13 +721,23 @@ func (s *serviceClient) getSrvConfig() {
|
||||
s.adminURL = cfg.AdminURL
|
||||
}
|
||||
s.preSharedKey = cfg.PreSharedKey
|
||||
s.RosenpassPermissive = cfg.RosenpassPermissive
|
||||
s.interfaceName = cfg.InterfaceName
|
||||
s.interfacePort = int(cfg.WireguardPort)
|
||||
|
||||
if s.showSettings {
|
||||
if s.showAdvancedSettings {
|
||||
s.iMngURL.SetText(s.managementURL)
|
||||
s.iAdminURL.SetText(s.adminURL)
|
||||
s.iConfigFile.SetText(cfg.ConfigFile)
|
||||
s.iLogFile.SetText(cfg.LogFile)
|
||||
s.iPreSharedKey.SetText(cfg.PreSharedKey)
|
||||
s.iInterfaceName.SetText(cfg.InterfaceName)
|
||||
s.iInterfacePort.SetText(strconv.Itoa(int(cfg.WireguardPort)))
|
||||
s.sRosenpassPermissive.SetChecked(cfg.RosenpassPermissive)
|
||||
if !cfg.RosenpassEnabled {
|
||||
s.sRosenpassPermissive.Disable()
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@ -704,6 +772,81 @@ func (s *serviceClient) onSessionExpire() {
|
||||
}
|
||||
}
|
||||
|
||||
// loadSettings loads the settings from the config file and updates the UI elements accordingly.
|
||||
func (s *serviceClient) loadSettings() {
|
||||
conn, err := s.getSrvClient(failFastTimeout)
|
||||
if err != nil {
|
||||
log.Errorf("get client: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
cfg, err := conn.GetConfig(s.ctx, &proto.GetConfigRequest{})
|
||||
if err != nil {
|
||||
log.Errorf("get config settings from server: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if cfg.ServerSSHAllowed {
|
||||
s.mAllowSSH.Check()
|
||||
} else {
|
||||
s.mAllowSSH.Uncheck()
|
||||
}
|
||||
|
||||
if cfg.DisableAutoConnect {
|
||||
s.mAutoConnect.Uncheck()
|
||||
} else {
|
||||
s.mAutoConnect.Check()
|
||||
}
|
||||
|
||||
if cfg.RosenpassEnabled {
|
||||
s.mEnableRosenpass.Check()
|
||||
} else {
|
||||
s.mEnableRosenpass.Uncheck()
|
||||
}
|
||||
}
|
||||
|
||||
// updateConfig updates the configuration parameters
|
||||
// based on the values selected in the settings window.
|
||||
func (s *serviceClient) updateConfig() error {
|
||||
disableAutoStart := !s.mAutoConnect.Checked()
|
||||
sshAllowed := s.mAllowSSH.Checked()
|
||||
rosenpassEnabled := s.mEnableRosenpass.Checked()
|
||||
|
||||
loginRequest := proto.LoginRequest{
|
||||
IsLinuxDesktopClient: runtime.GOOS == "linux",
|
||||
ServerSSHAllowed: &sshAllowed,
|
||||
RosenpassEnabled: &rosenpassEnabled,
|
||||
DisableAutoConnect: &disableAutoStart,
|
||||
}
|
||||
|
||||
if err := s.restartClient(&loginRequest); err != nil {
|
||||
log.Errorf("restarting client connection: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// restartClient restarts the client connection.
|
||||
func (s *serviceClient) restartClient(loginRequest *proto.LoginRequest) error {
|
||||
client, err := s.getSrvClient(failFastTimeout)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = client.Login(s.ctx, loginRequest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = client.Up(s.ctx, &proto.UpRequest{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func openURL(url string) error {
|
||||
var err error
|
||||
switch runtime.GOOS {
|
||||
@ -734,3 +877,88 @@ func checkPIDFile() error {
|
||||
|
||||
return os.WriteFile(pidFile, []byte(fmt.Sprintf("%d", os.Getpid())), 0o664) //nolint:gosec
|
||||
}
|
||||
|
||||
func (s *serviceClient) setDefaultFonts() {
|
||||
var (
|
||||
defaultFontPath string
|
||||
)
|
||||
|
||||
//TODO: Linux Multiple Language Support
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
defaultFontPath = "/Library/Fonts/Arial Unicode.ttf"
|
||||
case "windows":
|
||||
fontPath := s.getWindowsFontFilePath()
|
||||
defaultFontPath = fontPath
|
||||
}
|
||||
|
||||
_, err := os.Stat(defaultFontPath)
|
||||
|
||||
if err == nil {
|
||||
os.Setenv("FYNE_FONT", defaultFontPath)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *serviceClient) getWindowsFontFilePath() (fontPath string) {
|
||||
/*
|
||||
https://learn.microsoft.com/en-us/windows/apps/design/globalizing/loc-international-fonts
|
||||
https://learn.microsoft.com/en-us/typography/fonts/windows_11_font_list
|
||||
*/
|
||||
|
||||
var (
|
||||
fontFolder string = "C:/Windows/Fonts"
|
||||
fontMapping = map[string]string{
|
||||
"default": "Segoeui.ttf",
|
||||
"zh-CN": "Msyh.ttc",
|
||||
"am-ET": "Ebrima.ttf",
|
||||
"nirmala": "Nirmala.ttf",
|
||||
"chr-CHER-US": "Gadugi.ttf",
|
||||
"zh-HK": "Msjh.ttc",
|
||||
"zh-TW": "Msjh.ttc",
|
||||
"ja-JP": "Yugothm.ttc",
|
||||
"km-KH": "Leelawui.ttf",
|
||||
"ko-KR": "Malgun.ttf",
|
||||
"th-TH": "Leelawui.ttf",
|
||||
"ti-ET": "Ebrima.ttf",
|
||||
}
|
||||
nirMalaLang = []string{
|
||||
"as-IN",
|
||||
"bn-BD",
|
||||
"bn-IN",
|
||||
"gu-IN",
|
||||
"hi-IN",
|
||||
"kn-IN",
|
||||
"kok-IN",
|
||||
"ml-IN",
|
||||
"mr-IN",
|
||||
"ne-NP",
|
||||
"or-IN",
|
||||
"pa-IN",
|
||||
"si-LK",
|
||||
"ta-IN",
|
||||
"te-IN",
|
||||
}
|
||||
)
|
||||
cmd := exec.Command("powershell", "-Command", "(Get-Culture).Name")
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get Windows default language setting: %v", err)
|
||||
fontPath = path.Join(fontFolder, fontMapping["default"])
|
||||
return
|
||||
}
|
||||
defaultLanguage := strings.TrimSpace(string(output))
|
||||
|
||||
for _, lang := range nirMalaLang {
|
||||
if defaultLanguage == lang {
|
||||
fontPath = path.Join(fontFolder, fontMapping["nirmala"])
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if font, ok := fontMapping[defaultLanguage]; ok {
|
||||
fontPath = path.Join(fontFolder, font)
|
||||
} else {
|
||||
fontPath = path.Join(fontFolder, fontMapping["default"])
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -1,9 +1,10 @@
|
||||
//go:build !(linux && 386)
|
||||
//go:build !(linux && 386) && !freebsd
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@ -17,28 +18,57 @@ import (
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
const (
|
||||
allRoutesText = "All routes"
|
||||
overlappingRoutesText = "Overlapping routes"
|
||||
exitNodeRoutesText = "Exit-node routes"
|
||||
allRoutes filter = "all"
|
||||
overlappingRoutes filter = "overlapping"
|
||||
exitNodeRoutes filter = "exit-node"
|
||||
getClientFMT = "get client: %v"
|
||||
)
|
||||
|
||||
type filter string
|
||||
|
||||
func (s *serviceClient) showRoutesUI() {
|
||||
s.wRoutes = s.app.NewWindow("NetBird Routes")
|
||||
|
||||
grid := container.New(layout.NewGridLayout(2))
|
||||
go s.updateRoutes(grid)
|
||||
allGrid := container.New(layout.NewGridLayout(3))
|
||||
go s.updateRoutes(allGrid, allRoutes)
|
||||
overlappingGrid := container.New(layout.NewGridLayout(3))
|
||||
exitNodeGrid := container.New(layout.NewGridLayout(3))
|
||||
routeCheckContainer := container.NewVBox()
|
||||
routeCheckContainer.Add(grid)
|
||||
tabs := container.NewAppTabs(
|
||||
container.NewTabItem(allRoutesText, allGrid),
|
||||
container.NewTabItem(overlappingRoutesText, overlappingGrid),
|
||||
container.NewTabItem(exitNodeRoutesText, exitNodeGrid),
|
||||
)
|
||||
tabs.OnSelected = func(item *container.TabItem) {
|
||||
s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid)
|
||||
}
|
||||
tabs.OnUnselected = func(item *container.TabItem) {
|
||||
grid, _ := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodeGrid)
|
||||
grid.Objects = nil
|
||||
}
|
||||
|
||||
routeCheckContainer.Add(tabs)
|
||||
scrollContainer := container.NewVScroll(routeCheckContainer)
|
||||
scrollContainer.SetMinSize(fyne.NewSize(200, 300))
|
||||
|
||||
buttonBox := container.NewHBox(
|
||||
layout.NewSpacer(),
|
||||
widget.NewButton("Refresh", func() {
|
||||
s.updateRoutes(grid)
|
||||
s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid)
|
||||
}),
|
||||
widget.NewButton("Select all", func() {
|
||||
s.selectAllRoutes()
|
||||
s.updateRoutes(grid)
|
||||
_, f := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodeGrid)
|
||||
s.selectAllFilteredRoutes(f)
|
||||
s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid)
|
||||
}),
|
||||
widget.NewButton("Deselect All", func() {
|
||||
s.deselectAllRoutes()
|
||||
s.updateRoutes(grid)
|
||||
_, f := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodeGrid)
|
||||
s.deselectAllFilteredRoutes(f)
|
||||
s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid)
|
||||
}),
|
||||
layout.NewSpacer(),
|
||||
)
|
||||
@ -48,27 +78,31 @@ func (s *serviceClient) showRoutesUI() {
|
||||
s.wRoutes.SetContent(content)
|
||||
s.wRoutes.Show()
|
||||
|
||||
s.startAutoRefresh(5*time.Second, grid)
|
||||
s.startAutoRefresh(10*time.Second, tabs, allGrid, overlappingGrid, exitNodeGrid)
|
||||
}
|
||||
|
||||
func (s *serviceClient) updateRoutes(grid *fyne.Container) {
|
||||
routes, err := s.fetchRoutes()
|
||||
if err != nil {
|
||||
log.Errorf("get client: %v", err)
|
||||
s.showError(fmt.Errorf("get client: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
func (s *serviceClient) updateRoutes(grid *fyne.Container, f filter) {
|
||||
grid.Objects = nil
|
||||
grid.Refresh()
|
||||
idHeader := widget.NewLabelWithStyle(" ID", fyne.TextAlignLeading, fyne.TextStyle{Bold: true})
|
||||
networkHeader := widget.NewLabelWithStyle("Network", fyne.TextAlignLeading, fyne.TextStyle{Bold: true})
|
||||
networkHeader := widget.NewLabelWithStyle("Network/Domains", fyne.TextAlignLeading, fyne.TextStyle{Bold: true})
|
||||
resolvedIPsHeader := widget.NewLabelWithStyle("Resolved IPs", fyne.TextAlignLeading, fyne.TextStyle{Bold: true})
|
||||
|
||||
grid.Add(idHeader)
|
||||
grid.Add(networkHeader)
|
||||
for _, route := range routes {
|
||||
grid.Add(resolvedIPsHeader)
|
||||
|
||||
filteredRoutes, err := s.getFilteredRoutes(f)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
sortRoutesByIDs(filteredRoutes)
|
||||
|
||||
for _, route := range filteredRoutes {
|
||||
r := route
|
||||
|
||||
checkBox := widget.NewCheck(r.ID, func(checked bool) {
|
||||
checkBox := widget.NewCheck(r.GetID(), func(checked bool) {
|
||||
s.selectRoute(r.ID, checked)
|
||||
})
|
||||
checkBox.Checked = route.Selected
|
||||
@ -76,16 +110,106 @@ func (s *serviceClient) updateRoutes(grid *fyne.Container) {
|
||||
checkBox.Refresh()
|
||||
|
||||
grid.Add(checkBox)
|
||||
grid.Add(widget.NewLabel(r.Network))
|
||||
network := r.GetNetwork()
|
||||
domains := r.GetDomains()
|
||||
|
||||
if len(domains) == 0 {
|
||||
grid.Add(widget.NewLabel(network))
|
||||
grid.Add(widget.NewLabel(""))
|
||||
continue
|
||||
}
|
||||
|
||||
// our selectors are only for display
|
||||
noopFunc := func(_ string) {
|
||||
// do nothing
|
||||
}
|
||||
|
||||
domainsSelector := widget.NewSelect(domains, noopFunc)
|
||||
domainsSelector.Selected = domains[0]
|
||||
grid.Add(domainsSelector)
|
||||
|
||||
var resolvedIPsList []string
|
||||
for _, domain := range domains {
|
||||
if ipList, exists := r.GetResolvedIPs()[domain]; exists {
|
||||
resolvedIPsList = append(resolvedIPsList, fmt.Sprintf("%s: %s", domain, strings.Join(ipList.GetIps(), ", ")))
|
||||
}
|
||||
}
|
||||
|
||||
if len(resolvedIPsList) == 0 {
|
||||
grid.Add(widget.NewLabel(""))
|
||||
continue
|
||||
}
|
||||
|
||||
// TODO: limit width within the selector display
|
||||
resolvedIPsSelector := widget.NewSelect(resolvedIPsList, noopFunc)
|
||||
resolvedIPsSelector.Selected = resolvedIPsList[0]
|
||||
resolvedIPsSelector.Resize(fyne.NewSize(100, 100))
|
||||
grid.Add(resolvedIPsSelector)
|
||||
}
|
||||
|
||||
s.wRoutes.Content().Refresh()
|
||||
grid.Refresh()
|
||||
}
|
||||
|
||||
func (s *serviceClient) getFilteredRoutes(f filter) ([]*proto.Route, error) {
|
||||
routes, err := s.fetchRoutes()
|
||||
if err != nil {
|
||||
log.Errorf(getClientFMT, err)
|
||||
s.showError(fmt.Errorf(getClientFMT, err))
|
||||
return nil, err
|
||||
}
|
||||
switch f {
|
||||
case overlappingRoutes:
|
||||
return getOverlappingRoutes(routes), nil
|
||||
case exitNodeRoutes:
|
||||
return getExitNodeRoutes(routes), nil
|
||||
default:
|
||||
}
|
||||
return routes, nil
|
||||
}
|
||||
|
||||
func getOverlappingRoutes(routes []*proto.Route) []*proto.Route {
|
||||
var filteredRoutes []*proto.Route
|
||||
existingRange := make(map[string][]*proto.Route)
|
||||
for _, route := range routes {
|
||||
if len(route.Domains) > 0 {
|
||||
continue
|
||||
}
|
||||
if r, exists := existingRange[route.GetNetwork()]; exists {
|
||||
r = append(r, route)
|
||||
existingRange[route.GetNetwork()] = r
|
||||
} else {
|
||||
existingRange[route.GetNetwork()] = []*proto.Route{route}
|
||||
}
|
||||
}
|
||||
for _, r := range existingRange {
|
||||
if len(r) > 1 {
|
||||
filteredRoutes = append(filteredRoutes, r...)
|
||||
}
|
||||
}
|
||||
return filteredRoutes
|
||||
}
|
||||
|
||||
func getExitNodeRoutes(routes []*proto.Route) []*proto.Route {
|
||||
var filteredRoutes []*proto.Route
|
||||
for _, route := range routes {
|
||||
if route.Network == "0.0.0.0/0" {
|
||||
filteredRoutes = append(filteredRoutes, route)
|
||||
}
|
||||
}
|
||||
return filteredRoutes
|
||||
}
|
||||
|
||||
func sortRoutesByIDs(routes []*proto.Route) {
|
||||
sort.Slice(routes, func(i, j int) bool {
|
||||
return strings.ToLower(routes[i].GetID()) < strings.ToLower(routes[j].GetID())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *serviceClient) fetchRoutes() ([]*proto.Route, error) {
|
||||
conn, err := s.getSrvClient(defaultFailTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get client: %v", err)
|
||||
return nil, fmt.Errorf(getClientFMT, err)
|
||||
}
|
||||
|
||||
resp, err := conn.ListRoutes(s.ctx, &proto.ListRoutesRequest{})
|
||||
@ -99,8 +223,8 @@ func (s *serviceClient) fetchRoutes() ([]*proto.Route, error) {
|
||||
func (s *serviceClient) selectRoute(id string, checked bool) {
|
||||
conn, err := s.getSrvClient(defaultFailTimeout)
|
||||
if err != nil {
|
||||
log.Errorf("get client: %v", err)
|
||||
s.showError(fmt.Errorf("get client: %v", err))
|
||||
log.Errorf(getClientFMT, err)
|
||||
s.showError(fmt.Errorf(getClientFMT, err))
|
||||
return
|
||||
}
|
||||
|
||||
@ -126,16 +250,14 @@ func (s *serviceClient) selectRoute(id string, checked bool) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *serviceClient) selectAllRoutes() {
|
||||
func (s *serviceClient) selectAllFilteredRoutes(f filter) {
|
||||
conn, err := s.getSrvClient(defaultFailTimeout)
|
||||
if err != nil {
|
||||
log.Errorf("get client: %v", err)
|
||||
log.Errorf(getClientFMT, err)
|
||||
return
|
||||
}
|
||||
|
||||
req := &proto.SelectRoutesRequest{
|
||||
All: true,
|
||||
}
|
||||
req := s.getRoutesRequest(f, true)
|
||||
if _, err := conn.SelectRoutes(s.ctx, req); err != nil {
|
||||
log.Errorf("failed to select all routes: %v", err)
|
||||
s.showError(fmt.Errorf("failed to select all routes: %v", err))
|
||||
@ -145,16 +267,14 @@ func (s *serviceClient) selectAllRoutes() {
|
||||
log.Debug("All routes selected")
|
||||
}
|
||||
|
||||
func (s *serviceClient) deselectAllRoutes() {
|
||||
func (s *serviceClient) deselectAllFilteredRoutes(f filter) {
|
||||
conn, err := s.getSrvClient(defaultFailTimeout)
|
||||
if err != nil {
|
||||
log.Errorf("get client: %v", err)
|
||||
log.Errorf(getClientFMT, err)
|
||||
return
|
||||
}
|
||||
|
||||
req := &proto.SelectRoutesRequest{
|
||||
All: true,
|
||||
}
|
||||
req := s.getRoutesRequest(f, false)
|
||||
if _, err := conn.DeselectRoutes(s.ctx, req); err != nil {
|
||||
log.Errorf("failed to deselect all routes: %v", err)
|
||||
s.showError(fmt.Errorf("failed to deselect all routes: %v", err))
|
||||
@ -164,17 +284,34 @@ func (s *serviceClient) deselectAllRoutes() {
|
||||
log.Debug("All routes deselected")
|
||||
}
|
||||
|
||||
func (s *serviceClient) getRoutesRequest(f filter, appendRoute bool) *proto.SelectRoutesRequest {
|
||||
req := &proto.SelectRoutesRequest{}
|
||||
if f == allRoutes {
|
||||
req.All = true
|
||||
} else {
|
||||
routes, err := s.getFilteredRoutes(f)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
for _, route := range routes {
|
||||
req.RouteIDs = append(req.RouteIDs, route.GetID())
|
||||
}
|
||||
req.Append = appendRoute
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
func (s *serviceClient) showError(err error) {
|
||||
wrappedMessage := wrapText(err.Error(), 50)
|
||||
|
||||
dialog.ShowError(fmt.Errorf("%s", wrappedMessage), s.wRoutes)
|
||||
}
|
||||
|
||||
func (s *serviceClient) startAutoRefresh(interval time.Duration, grid *fyne.Container) {
|
||||
func (s *serviceClient) startAutoRefresh(interval time.Duration, tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) {
|
||||
ticker := time.NewTicker(interval)
|
||||
go func() {
|
||||
for range ticker.C {
|
||||
s.updateRoutes(grid)
|
||||
s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodesGrid)
|
||||
}
|
||||
}()
|
||||
|
||||
@ -183,6 +320,23 @@ func (s *serviceClient) startAutoRefresh(interval time.Duration, grid *fyne.Cont
|
||||
})
|
||||
}
|
||||
|
||||
func (s *serviceClient) updateRoutesBasedOnDisplayTab(tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) {
|
||||
grid, f := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodesGrid)
|
||||
s.wRoutes.Content().Refresh()
|
||||
s.updateRoutes(grid, f)
|
||||
}
|
||||
|
||||
func getGridAndFilterFromTab(tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) (*fyne.Container, filter) {
|
||||
switch tabs.Selected().Text {
|
||||
case overlappingRoutesText:
|
||||
return overlappingGrid, overlappingRoutes
|
||||
case exitNodeRoutesText:
|
||||
return exitNodesGrid, exitNodeRoutes
|
||||
default:
|
||||
return allGrid, allRoutes
|
||||
}
|
||||
}
|
||||
|
||||
// wrapText inserts newlines into the text to ensure that each line is
|
||||
// no longer than 'lineLength' runes.
|
||||
func wrapText(text string, lineLength int) string {
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user