Release 0.28.0 (#2092)

* compile client under freebsd (#1620)

Compile netbird client under freebsd and now support netstack and userspace modes.
Refactoring linux specific code to share same code with FreeBSD, move to *_unix.go files.

Not implemented yet:

Kernel mode not supported
DNS probably does not work yet
Routing also probably does not work yet
SSH support did not tested yet
Lack of test environment for freebsd (dedicated VM for github runners under FreeBSD required)
Lack of tests for freebsd specific code
info reporting need to review and also implement, for example OS reported as GENERIC instead of FreeBSD (lack of FreeBSD icon in management interface)
Lack of proper client setup under FreeBSD
Lack of FreeBSD port/package

* Add DNS routes (#1943)

Given domains are resolved periodically and resolved IPs are replaced with the new ones. Unless the flag keep_route is set to true, then only new ones are added.
This option is helpful if there are long-running connections that might still point to old IP addresses from changed DNS records.

* Add process posture check (#1693)

Introduces a process posture check to validate the existence and active status of specific binaries on peer systems. The check ensures that files are present at specified paths, and that corresponding processes are running. This check supports Linux, Windows, and macOS systems.


Co-authored-by: Evgenii <mail@skillcoder.com>
Co-authored-by: Pascal Fischer <pascal@netbird.io>
Co-authored-by: Zoltan Papp <zoltan.pmail@gmail.com>
Co-authored-by: Viktor Liu <17948409+lixmal@users.noreply.github.com>
Co-authored-by: Bethuel Mmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
Maycon Santos 2024-06-13 13:24:24 +02:00 committed by GitHub
parent 95299be52d
commit 4fec709bb1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
149 changed files with 6509 additions and 2710 deletions

View File

@ -86,7 +86,10 @@ jobs:
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
- name: Generate RouteManager Test bin - name: Generate RouteManager Test bin
run: CGO_ENABLED=1 go test -c -o routemanager-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/... run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager
- name: Generate SystemOps Test bin
run: CGO_ENABLED=1 go test -c -o systemops-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/systemops
- name: Generate nftables Manager Test bin - name: Generate nftables Manager Test bin
run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/... run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/...
@ -108,6 +111,9 @@ jobs:
- name: Run RouteManager tests in docker - name: Run RouteManager tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1 run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1
- name: Run SystemOps tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager/systemops --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/systemops-testing.bin -test.timeout 5m -test.parallel 1
- name: Run nftables Manager tests in docker - name: Run nftables Manager tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/firewall --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/nftablesmanager-testing.bin -test.timeout 5m -test.parallel 1 run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/firewall --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/nftablesmanager-testing.bin -test.timeout 5m -test.parallel 1

View File

@ -36,6 +36,7 @@ const (
disableAutoConnectFlag = "disable-auto-connect" disableAutoConnectFlag = "disable-auto-connect"
serverSSHAllowedFlag = "allow-server-ssh" serverSSHAllowedFlag = "allow-server-ssh"
extraIFaceBlackListFlag = "extra-iface-blacklist" extraIFaceBlackListFlag = "extra-iface-blacklist"
dnsRouteIntervalFlag = "dns-router-interval"
) )
var ( var (
@ -68,6 +69,8 @@ var (
autoConnectDisabled bool autoConnectDisabled bool
extraIFaceBlackList []string extraIFaceBlackList []string
anonymizeFlag bool anonymizeFlag bool
dnsRouteInterval time.Duration
rootCmd = &cobra.Command{ rootCmd = &cobra.Command{
Use: "netbird", Use: "netbird",
Short: "", Short: "",

View File

@ -2,6 +2,7 @@ package cmd
import ( import (
"fmt" "fmt"
"strings"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
@ -66,18 +67,60 @@ func routesList(cmd *cobra.Command, _ []string) error {
return nil return nil
} }
cmd.Println("Available Routes:") printRoutes(cmd, resp)
for _, route := range resp.Routes {
selectedStatus := "Not Selected"
if route.GetSelected() {
selectedStatus = "Selected"
}
cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetNetwork(), selectedStatus)
}
return nil return nil
} }
func printRoutes(cmd *cobra.Command, resp *proto.ListRoutesResponse) {
cmd.Println("Available Routes:")
for _, route := range resp.Routes {
printRoute(cmd, route)
}
}
func printRoute(cmd *cobra.Command, route *proto.Route) {
selectedStatus := getSelectedStatus(route)
domains := route.GetDomains()
if len(domains) > 0 {
printDomainRoute(cmd, route, domains, selectedStatus)
} else {
printNetworkRoute(cmd, route, selectedStatus)
}
}
func getSelectedStatus(route *proto.Route) string {
if route.GetSelected() {
return "Selected"
}
return "Not Selected"
}
func printDomainRoute(cmd *cobra.Command, route *proto.Route, domains []string, selectedStatus string) {
cmd.Printf("\n - ID: %s\n Domains: %s\n Status: %s\n", route.GetID(), strings.Join(domains, ", "), selectedStatus)
resolvedIPs := route.GetResolvedIPs()
if len(resolvedIPs) > 0 {
printResolvedIPs(cmd, domains, resolvedIPs)
} else {
cmd.Printf(" Resolved IPs: -\n")
}
}
func printNetworkRoute(cmd *cobra.Command, route *proto.Route, selectedStatus string) {
cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetNetwork(), selectedStatus)
}
func printResolvedIPs(cmd *cobra.Command, domains []string, resolvedIPs map[string]*proto.IPList) {
cmd.Printf(" Resolved IPs:\n")
for _, domain := range domains {
if ipList, exists := resolvedIPs[domain]; exists {
cmd.Printf(" [%s]: %s\n", domain, strings.Join(ipList.GetIps(), ", "))
}
}
}
func routesSelect(cmd *cobra.Command, args []string) error { func routesSelect(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd) conn, err := getClient(cmd)
if err != nil { if err != nil {

View File

@ -807,11 +807,7 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
} }
for i, route := range peer.Routes { for i, route := range peer.Routes {
prefix, err := netip.ParsePrefix(route) peer.Routes[i] = anonymizeRoute(a, route)
if err == nil {
ip := a.AnonymizeIPString(prefix.Addr().String())
peer.Routes[i] = fmt.Sprintf("%s/%d", ip, prefix.Bits())
}
} }
} }
@ -847,12 +843,21 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview)
} }
for i, route := range overview.Routes { for i, route := range overview.Routes {
prefix, err := netip.ParsePrefix(route) overview.Routes[i] = anonymizeRoute(a, route)
if err == nil {
ip := a.AnonymizeIPString(prefix.Addr().String())
overview.Routes[i] = fmt.Sprintf("%s/%d", ip, prefix.Bits())
}
} }
overview.FQDN = a.AnonymizeDomain(overview.FQDN) overview.FQDN = a.AnonymizeDomain(overview.FQDN)
} }
func anonymizeRoute(a *anonymize.Anonymizer, route string) string {
prefix, err := netip.ParsePrefix(route)
if err == nil {
ip := a.AnonymizeIPString(prefix.Addr().String())
return fmt.Sprintf("%s/%d", ip, prefix.Bits())
}
domains := strings.Split(route, ", ")
for i, domain := range domains {
domains[i] = a.AnonymizeDomain(domain)
}
return strings.Join(domains, ", ")
}

View File

@ -7,11 +7,13 @@ import (
"net/netip" "net/netip"
"runtime" "runtime"
"strings" "strings"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status" gstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/durationpb"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
@ -42,6 +44,7 @@ func init() {
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port") upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", false, "Enable network monitoring") upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", false, "Enable network monitoring")
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening") upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
} }
func upFunc(cmd *cobra.Command, args []string) error { func upFunc(cmd *cobra.Command, args []string) error {
@ -137,6 +140,10 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
} }
} }
if cmd.Flag(dnsRouteIntervalFlag).Changed {
ic.DNSRouteInterval = &dnsRouteInterval
}
config, err := internal.UpdateOrCreateConfig(ic) config, err := internal.UpdateOrCreateConfig(ic)
if err != nil { if err != nil {
return fmt.Errorf("get config file: %v", err) return fmt.Errorf("get config file: %v", err)
@ -237,6 +244,10 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
loginRequest.NetworkMonitor = &networkMonitor loginRequest.NetworkMonitor = &networkMonitor
} }
if cmd.Flag(dnsRouteIntervalFlag).Changed {
loginRequest.DnsRouteInterval = durationpb.New(dnsRouteInterval)
}
var loginErr error var loginErr error
var loginResp *proto.LoginResponse var loginResp *proto.LoginResponse

30
client/errors/errors.go Normal file
View File

@ -0,0 +1,30 @@
package errors
import (
"fmt"
"strings"
"github.com/hashicorp/go-multierror"
)
func formatError(es []error) string {
if len(es) == 0 {
return fmt.Sprintf("0 error occurred:\n\t* %s", es[0])
}
points := make([]string, len(es))
for i, err := range es {
points[i] = fmt.Sprintf("* %s", err)
}
return fmt.Sprintf(
"%d errors occurred:\n\t%s",
len(es), strings.Join(points, "\n\t"))
}
func FormatErrorOrNil(err *multierror.Error) error {
if err != nil {
err.ErrorFormat = formatError
}
return err.ErrorOrNil()
}

View File

@ -7,12 +7,14 @@ import (
"os" "os"
"reflect" "reflect"
"strings" "strings"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
mgm "github.com/netbirdio/netbird/management/client" mgm "github.com/netbirdio/netbird/management/client"
@ -53,6 +55,7 @@ type ConfigInput struct {
NetworkMonitor *bool NetworkMonitor *bool
DisableAutoConnect *bool DisableAutoConnect *bool
ExtraIFaceBlackList []string ExtraIFaceBlackList []string
DNSRouteInterval *time.Duration
} }
// Config Configuration type // Config Configuration type
@ -95,6 +98,9 @@ type Config struct {
// DisableAutoConnect determines whether the client should not start with the service // DisableAutoConnect determines whether the client should not start with the service
// it's set to false by default due to backwards compatibility // it's set to false by default due to backwards compatibility
DisableAutoConnect bool DisableAutoConnect bool
// DNSRouteInterval is the interval in which the DNS routes are updated
DNSRouteInterval time.Duration
} }
// ReadConfig read config file and return with Config. If it is not exists create a new with default values // ReadConfig read config file and return with Config. If it is not exists create a new with default values
@ -357,6 +363,18 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true updated = true
} }
if input.DNSRouteInterval != nil && *input.DNSRouteInterval != config.DNSRouteInterval {
log.Infof("updating DNS route interval to %s (old value %s)",
input.DNSRouteInterval.String(), config.DNSRouteInterval.String())
config.DNSRouteInterval = *input.DNSRouteInterval
updated = true
} else if config.DNSRouteInterval == 0 {
config.DNSRouteInterval = dynamic.DefaultInterval
log.Infof("using default DNS route interval %s", config.DNSRouteInterval)
updated = true
}
return updated, nil return updated, nil
} }

View File

@ -252,8 +252,10 @@ func (c *ConnectClient) run(
return wrapErr(err) return wrapErr(err)
} }
checks := loginResp.GetChecks()
c.engineMutex.Lock() c.engineMutex.Lock()
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, c.statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe) c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, c.statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe, checks)
c.engineMutex.Unlock() c.engineMutex.Unlock()
err = c.engine.Start() err = c.engine.Start()
@ -321,6 +323,7 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
RosenpassEnabled: config.RosenpassEnabled, RosenpassEnabled: config.RosenpassEnabled,
RosenpassPermissive: config.RosenpassPermissive, RosenpassPermissive: config.RosenpassPermissive,
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed), ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
DNSRouteInterval: config.DNSRouteInterval,
} }
if config.PreSharedKey != "" { if config.PreSharedKey != "" {

View File

@ -0,0 +1,6 @@
package dns
const (
fileUncleanShutdownResolvConfLocation = "/var/db/netbird/resolv.conf"
fileUncleanShutdownManagerTypeLocation = "/var/db/netbird/manager"
)

View File

@ -0,0 +1,8 @@
//go:build !android
package dns
const (
fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf"
fileUncleanShutdownManagerTypeLocation = "/var/lib/netbird/manager"
)

View File

@ -1,4 +1,4 @@
//go:build !android //go:build (linux && !android) || freebsd
package dns package dns

View File

@ -1,4 +1,4 @@
//go:build !android //go:build (linux && !android) || freebsd
package dns package dns

View File

@ -1,4 +1,4 @@
//go:build !android //go:build (linux && !android) || freebsd
package dns package dns

View File

@ -1,4 +1,4 @@
//go:build !android //go:build (linux && !android) || freebsd
package dns package dns

View File

@ -1,4 +1,4 @@
//go:build !android //go:build (linux && !android) || freebsd
package dns package dns

View File

@ -1,4 +1,4 @@
//go:build !android //go:build (linux && !android) || freebsd
package dns package dns

View File

@ -1,4 +1,4 @@
//go:build !android //go:build (linux && !android) || freebsd
package dns package dns

View File

@ -1,4 +1,4 @@
//go:build !android //go:build (linux && !android) || freebsd
package dns package dns
@ -108,7 +108,7 @@ func getOSDNSManagerType() (osManagerType, error) {
if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() { if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
return networkManager, nil return networkManager, nil
} }
if strings.Contains(text, "systemd-resolved") && isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) { if strings.Contains(text, "systemd-resolved") && isSystemdResolvedRunning() {
if checkStub() { if checkStub() {
return systemdManager, nil return systemdManager, nil
} else { } else {
@ -116,16 +116,10 @@ func getOSDNSManagerType() (osManagerType, error) {
} }
} }
if strings.Contains(text, "resolvconf") { if strings.Contains(text, "resolvconf") {
if isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) { if isSystemdResolveConfMode() {
var value string
err = getSystemdDbusProperty(systemdDbusResolvConfModeProperty, &value)
if err == nil {
if value == systemdDbusResolvConfModeForeign {
return systemdManager, nil return systemdManager, nil
} }
}
log.Errorf("got an error while checking systemd resolv conf mode, error: %s", err)
}
return resolvConfManager, nil return resolvConfManager, nil
} }
} }

View File

@ -1,4 +1,4 @@
//go:build !android //go:build (linux && !android) || freebsd
package dns package dns

View File

@ -1,4 +1,4 @@
//go:build !android //go:build (linux && !android) || freebsd
package dns package dns

View File

@ -39,6 +39,10 @@ func (w *mocWGIface) Address() iface.WGAddress {
} }
} }
func (w *mocWGIface) ToInterface() *net.Interface {
panic("implement me")
}
func (w *mocWGIface) GetFilter() iface.PacketFilter { func (w *mocWGIface) GetFilter() iface.PacketFilter {
return w.filter return w.filter
} }

View File

@ -1,4 +1,4 @@
//go:build !android //go:build (linux && !android) || freebsd
package dns package dns

View File

@ -0,0 +1,20 @@
package dns
import (
"errors"
"fmt"
)
var errNotImplemented = errors.New("not implemented")
func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) {
return nil, fmt.Errorf("systemd dns management: %w on freebsd", errNotImplemented)
}
func isSystemdResolvedRunning() bool {
return false
}
func isSystemdResolveConfMode() bool {
return false
}

View File

@ -242,3 +242,25 @@ func getSystemdDbusProperty(property string, store any) error {
return v.Store(store) return v.Store(store)
} }
func isSystemdResolvedRunning() bool {
return isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode)
}
func isSystemdResolveConfMode() bool {
if !isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) {
return false
}
var value string
if err := getSystemdDbusProperty(systemdDbusResolvConfModeProperty, &value); err != nil {
log.Errorf("got an error while checking systemd resolv conf mode, error: %s", err)
return false
}
if value == systemdDbusResolvConfModeForeign {
return true
}
return false
}

View File

@ -1,4 +1,4 @@
//go:build !android //go:build (linux && !android) || freebsd
package dns package dns
@ -14,11 +14,6 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
const (
fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf"
fileUncleanShutdownManagerTypeLocation = "/var/lib/netbird/manager"
)
func CheckUncleanShutdown(wgIface string) error { func CheckUncleanShutdown(wgIface string) error {
if _, err := os.Stat(fileUncleanShutdownResolvConfLocation); err != nil { if _, err := os.Stat(fileUncleanShutdownResolvConfLocation); err != nil {
if errors.Is(err, fs.ErrNotExist) { if errors.Is(err, fs.ErrNotExist) {

View File

@ -2,12 +2,17 @@
package dns package dns
import "github.com/netbirdio/netbird/iface" import (
"net"
"github.com/netbirdio/netbird/iface"
)
// WGIface defines subset methods of interface required for manager // WGIface defines subset methods of interface required for manager
type WGIface interface { type WGIface interface {
Name() string Name() string
Address() iface.WGAddress Address() iface.WGAddress
ToInterface() *net.Interface
IsUserspaceBind() bool IsUserspaceBind() bool
GetFilter() iface.PacketFilter GetFilter() iface.PacketFilter
GetDevice() *iface.DeviceWrapper GetDevice() *iface.DeviceWrapper

View File

@ -10,6 +10,7 @@ import (
"net/netip" "net/netip"
"reflect" "reflect"
"runtime" "runtime"
"slices"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -30,10 +31,12 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/wgproxy" "github.com/netbirdio/netbird/client/internal/wgproxy"
nbssh "github.com/netbirdio/netbird/client/ssh" nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/iface/bind" "github.com/netbirdio/netbird/iface/bind"
mgm "github.com/netbirdio/netbird/management/client" mgm "github.com/netbirdio/netbird/management/client"
"github.com/netbirdio/netbird/management/domain"
mgmProto "github.com/netbirdio/netbird/management/proto" mgmProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
signal "github.com/netbirdio/netbird/signal/client" signal "github.com/netbirdio/netbird/signal/client"
@ -89,6 +92,8 @@ type EngineConfig struct {
RosenpassPermissive bool RosenpassPermissive bool
ServerSSHAllowed bool ServerSSHAllowed bool
DNSRouteInterval time.Duration
} }
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers. // Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
@ -154,6 +159,9 @@ type Engine struct {
wgProbe *Probe wgProbe *Probe
wgConnWorker sync.WaitGroup wgConnWorker sync.WaitGroup
// checks are the client-applied posture checks that need to be evaluated on the client
checks []*mgmProto.Checks
} }
// Peer is an instance of the Connection Peer // Peer is an instance of the Connection Peer
@ -171,6 +179,7 @@ func NewEngine(
config *EngineConfig, config *EngineConfig,
mobileDep MobileDependency, mobileDep MobileDependency,
statusRecorder *peer.Status, statusRecorder *peer.Status,
checks []*mgmProto.Checks,
) *Engine { ) *Engine {
return NewEngineWithProbes( return NewEngineWithProbes(
clientCtx, clientCtx,
@ -184,6 +193,7 @@ func NewEngine(
nil, nil,
nil, nil,
nil, nil,
checks,
) )
} }
@ -200,6 +210,7 @@ func NewEngineWithProbes(
signalProbe *Probe, signalProbe *Probe,
relayProbe *Probe, relayProbe *Probe,
wgProbe *Probe, wgProbe *Probe,
checks []*mgmProto.Checks,
) *Engine { ) *Engine {
return &Engine{ return &Engine{
@ -220,6 +231,7 @@ func NewEngineWithProbes(
signalProbe: signalProbe, signalProbe: signalProbe,
relayProbe: relayProbe, relayProbe: relayProbe,
wgProbe: wgProbe, wgProbe: wgProbe,
checks: checks,
} }
} }
@ -301,7 +313,7 @@ func (e *Engine) Start() error {
} }
e.dnsServer = dnsServer e.dnsServer = dnsServer
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, initialRoutes) e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, initialRoutes)
beforePeerHook, afterPeerHook, err := e.routeManager.Init() beforePeerHook, afterPeerHook, err := e.routeManager.Init()
if err != nil { if err != nil {
log.Errorf("Failed to initialize route manager: %s", err) log.Errorf("Failed to initialize route manager: %s", err)
@ -527,6 +539,10 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
// todo update signal // todo update signal
} }
if err := e.updateChecksIfNew(update.Checks); err != nil {
return err
}
if update.GetNetworkMap() != nil { if update.GetNetworkMap() != nil {
// only apply new changes and ignore old ones // only apply new changes and ignore old ones
err := e.updateNetworkMap(update.GetNetworkMap()) err := e.updateNetworkMap(update.GetNetworkMap())
@ -534,7 +550,27 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return err return err
} }
} }
return nil
}
// updateChecksIfNew updates checks if there are changes and sync new meta with management
func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
// if checks are equal, we skip the update
if isChecksEqual(e.checks, checks) {
return nil
}
e.checks = checks
info, err := system.GetInfoWithChecks(e.ctx, checks)
if err != nil {
log.Warnf("failed to get system info with checks: %v", err)
info = system.GetInfo(e.ctx)
}
if err := e.mgmClient.SyncMeta(info); err != nil {
log.Errorf("could not sync meta: error %s", err)
return err
}
return nil return nil
} }
@ -550,8 +586,8 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
} else { } else {
if sshConf.GetSshEnabled() { if sshConf.GetSshEnabled() {
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" || runtime.GOOS == "freebsd" {
log.Warnf("running SSH server on Windows is not supported") log.Warnf("running SSH server on %s is not supported", runtime.GOOS)
return nil return nil
} }
// start SSH server if it wasn't running // start SSH server if it wasn't running
@ -624,7 +660,14 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
// E.g. when a new peer has been registered and we are allowed to connect to it. // E.g. when a new peer has been registered and we are allowed to connect to it.
func (e *Engine) receiveManagementEvents() { func (e *Engine) receiveManagementEvents() {
go func() { go func() {
err := e.mgmClient.Sync(e.ctx, e.handleSync) info, err := system.GetInfoWithChecks(e.ctx, e.checks)
if err != nil {
log.Warnf("failed to get system info with checks: %v", err)
info = system.GetInfo(e.ctx)
}
// err = e.mgmClient.Sync(info, e.handleSync)
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
if err != nil { if err != nil {
// happens if management is unavailable for a long time. // happens if management is unavailable for a long time.
// We want to cancel the operation of the whole client // We want to cancel the operation of the whole client
@ -772,15 +815,24 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route { func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
routes := make([]*route.Route, 0) routes := make([]*route.Route, 0)
for _, protoRoute := range protoRoutes { for _, protoRoute := range protoRoutes {
_, prefix, _ := route.ParseNetwork(protoRoute.Network) var prefix netip.Prefix
if len(protoRoute.Domains) == 0 {
var err error
if prefix, err = netip.ParsePrefix(protoRoute.Network); err != nil {
log.Errorf("Failed to parse prefix %s: %v", protoRoute.Network, err)
continue
}
}
convertedRoute := &route.Route{ convertedRoute := &route.Route{
ID: route.ID(protoRoute.ID), ID: route.ID(protoRoute.ID),
Network: prefix, Network: prefix,
Domains: domain.FromPunycodeList(protoRoute.Domains),
NetID: route.NetID(protoRoute.NetID), NetID: route.NetID(protoRoute.NetID),
NetworkType: route.NetworkType(protoRoute.NetworkType), NetworkType: route.NetworkType(protoRoute.NetworkType),
Peer: protoRoute.Peer, Peer: protoRoute.Peer,
Metric: int(protoRoute.Metric), Metric: int(protoRoute.Metric),
Masquerade: protoRoute.Masquerade, Masquerade: protoRoute.Masquerade,
KeepRoute: protoRoute.KeepRoute,
} }
routes = append(routes, convertedRoute) routes = append(routes, convertedRoute)
} }
@ -1204,7 +1256,8 @@ func (e *Engine) close() {
} }
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) { func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
netMap, err := e.mgmClient.GetNetworkMap() info := system.GetInfo(e.ctx)
netMap, err := e.mgmClient.GetNetworkMap(info)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -1430,3 +1483,10 @@ func (e *Engine) startNetworkMonitor() {
} }
}() }()
} }
// isChecksEqual checks if two slices of checks are equal.
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {
return slices.Equal(checks.Files, oChecks.Files)
})
}

View File

@ -78,7 +78,7 @@ func TestEngine_SSH(t *testing.T) {
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
ServerSSHAllowed: true, ServerSSHAllowed: true,
}, MobileDependency{}, peer.NewRecorder("https://mgm")) }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.dnsServer = &dns.MockServer{ engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
@ -212,7 +212,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
WgAddr: "100.64.0.1/24", WgAddr: "100.64.0.1/24",
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm")) }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
newNet, err := stdnet.NewNet() newNet, err := stdnet.NewNet()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -221,7 +221,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder, nil) engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, nil)
engine.dnsServer = &dns.MockServer{ engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
} }
@ -394,7 +394,7 @@ func TestEngine_Sync(t *testing.T) {
// feed updates to Engine via mocked Management client // feed updates to Engine via mocked Management client
updates := make(chan *mgmtProto.SyncResponse) updates := make(chan *mgmtProto.SyncResponse)
defer close(updates) defer close(updates)
syncFunc := func(ctx context.Context, msgHandler func(msg *mgmtProto.SyncResponse) error) error { syncFunc := func(ctx context.Context, info *system.Info, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
for msg := range updates { for msg := range updates {
err := msgHandler(msg) err := msgHandler(msg)
if err != nil { if err != nil {
@ -409,7 +409,7 @@ func TestEngine_Sync(t *testing.T) {
WgAddr: "100.64.0.1/24", WgAddr: "100.64.0.1/24",
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm")) }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.ctx = ctx engine.ctx = ctx
engine.dnsServer = &dns.MockServer{ engine.dnsServer = &dns.MockServer{
@ -568,7 +568,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
WgAddr: wgAddr, WgAddr: wgAddr,
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm")) }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.ctx = ctx engine.ctx = ctx
newNet, err := stdnet.NewNet() newNet, err := stdnet.NewNet()
if err != nil { if err != nil {
@ -738,7 +738,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
WgAddr: wgAddr, WgAddr: wgAddr,
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm")) }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.ctx = ctx engine.ctx = ctx
newNet, err := stdnet.NewNet() newNet, err := stdnet.NewNet()
@ -1009,7 +1009,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
WgPort: wgPort, WgPort: wgPort,
} }
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, conf, MobileDependency{}, peer.NewRecorder("https://mgm")), nil e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil
e.ctx = ctx e.ctx = ctx
return e, err return e, err
} }

View File

@ -5,8 +5,6 @@ package networkmonitor
import ( import (
"context" "context"
"fmt" "fmt"
"net"
"net/netip"
"syscall" "syscall"
"unsafe" "unsafe"
@ -14,10 +12,10 @@ import (
"golang.org/x/net/route" "golang.org/x/net/route"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
) )
func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interface, nexthopv6 netip.Addr, intfv6 *net.Interface, callback func()) error { func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error {
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC) fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
if err != nil { if err != nil {
return fmt.Errorf("failed to open routing socket: %v", err) return fmt.Errorf("failed to open routing socket: %v", err)
@ -58,7 +56,7 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
if msg.Flags&unix.IFF_UP != 0 { if msg.Flags&unix.IFF_UP != 0 {
continue continue
} }
if (intfv4 == nil || ifinfo.Index != intfv4.Index) && (intfv6 == nil || ifinfo.Index != intfv6.Index) { if (nexthopv4.Intf == nil || ifinfo.Index != nexthopv4.Intf.Index) && (nexthopv6.Intf == nil || ifinfo.Index != nexthopv6.Intf.Index) {
continue continue
} }
@ -86,7 +84,7 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf) log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
go callback() go callback()
case unix.RTM_DELETE: case unix.RTM_DELETE:
if intfv4 != nil && route.Gw.Compare(nexthopv4) == 0 || intfv6 != nil && route.Gw.Compare(nexthopv6) == 0 { if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf) log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
go callback() go callback()
} }
@ -114,7 +112,7 @@ func parseInterfaceMessage(buf []byte) (*route.InterfaceMessage, error) {
return msg, nil return msg, nil
} }
func parseRouteMessage(buf []byte) (*routemanager.Route, error) { func parseRouteMessage(buf []byte) (*systemops.Route, error) {
msgs, err := route.ParseRIB(route.RIBTypeRoute, buf) msgs, err := route.ParseRIB(route.RIBTypeRoute, buf)
if err != nil { if err != nil {
return nil, fmt.Errorf("parse RIB: %v", err) return nil, fmt.Errorf("parse RIB: %v", err)
@ -129,5 +127,5 @@ func parseRouteMessage(buf []byte) (*routemanager.Route, error) {
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0]) return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
} }
return routemanager.MsgToRoute(msg) return systemops.MsgToRoute(msg)
} }

View File

@ -6,14 +6,13 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"net"
"net/netip" "net/netip"
"runtime/debug" "runtime/debug"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
) )
// Start begins monitoring network changes. When a change is detected, it calls the callback asynchronously and returns. // Start begins monitoring network changes. When a change is detected, it calls the callback asynchronously and returns.
@ -29,23 +28,22 @@ func (nw *NetworkMonitor) Start(ctx context.Context, callback func()) (err error
nw.wg.Add(1) nw.wg.Add(1)
defer nw.wg.Done() defer nw.wg.Done()
var nexthop4, nexthop6 netip.Addr var nexthop4, nexthop6 systemops.Nexthop
var intf4, intf6 *net.Interface
operation := func() error { operation := func() error {
var errv4, errv6 error var errv4, errv6 error
nexthop4, intf4, errv4 = routemanager.GetNextHop(netip.IPv4Unspecified()) nexthop4, errv4 = systemops.GetNextHop(netip.IPv4Unspecified())
nexthop6, intf6, errv6 = routemanager.GetNextHop(netip.IPv6Unspecified()) nexthop6, errv6 = systemops.GetNextHop(netip.IPv6Unspecified())
if errv4 != nil && errv6 != nil { if errv4 != nil && errv6 != nil {
return errors.New("failed to get default next hops") return errors.New("failed to get default next hops")
} }
if errv4 == nil { if errv4 == nil {
log.Debugf("Network monitor: IPv4 default route: %s, interface: %s", nexthop4, intf4.Name) log.Debugf("Network monitor: IPv4 default route: %s, interface: %s", nexthop4.IP, nexthop4.Intf.Name)
} }
if errv6 == nil { if errv6 == nil {
log.Debugf("Network monitor: IPv6 default route: %s, interface: %s", nexthop6, intf6.Name) log.Debugf("Network monitor: IPv6 default route: %s, interface: %s", nexthop6.IP, nexthop6.Intf.Name)
} }
// continue if either route was found // continue if either route was found
@ -65,7 +63,7 @@ func (nw *NetworkMonitor) Start(ctx context.Context, callback func()) (err error
} }
}() }()
if err := checkChange(ctx, nexthop4, intf4, nexthop6, intf6, callback); err != nil { if err := checkChange(ctx, nexthop4, nexthop6, callback); err != nil {
return fmt.Errorf("check change: %w", err) return fmt.Errorf("check change: %w", err)
} }

View File

@ -6,16 +6,16 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"net"
"net/netip"
"syscall" "syscall"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink" "github.com/vishvananda/netlink"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
) )
func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interface, nexthop6 netip.Addr, intfv6 *net.Interface, callback func()) error { func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error {
if intfv4 == nil && intfv6 == nil { if nexthopv4.Intf == nil && nexthopv6.Intf == nil {
return errors.New("no interfaces available") return errors.New("no interfaces available")
} }
@ -40,7 +40,7 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
// handle interface state changes // handle interface state changes
case update := <-linkChan: case update := <-linkChan:
if (intfv4 == nil || update.Index != int32(intfv4.Index)) && (intfv6 == nil || update.Index != int32(intfv6.Index)) { if (nexthopv4.Intf == nil || update.Index != int32(nexthopv4.Intf.Index)) && (nexthopv6.Intf == nil || update.Index != int32(nexthopv6.Intf.Index)) {
continue continue
} }
@ -70,7 +70,7 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
go callback() go callback()
return nil return nil
case syscall.RTM_DELROUTE: case syscall.RTM_DELROUTE:
if intfv4 != nil && route.Gw.Equal(nexthopv4.AsSlice()) || intfv6 != nil && route.Gw.Equal(nexthop6.AsSlice()) { if nexthopv4.Intf != nil && route.Gw.Equal(nexthopv4.IP.AsSlice()) || nexthopv6.Intf != nil && route.Gw.Equal(nexthopv6.IP.AsSlice()) {
log.Infof("Network monitor: default route removed: via %s, interface %d", route.Gw, route.LinkIndex) log.Infof("Network monitor: default route removed: via %s, interface %d", route.Gw, route.LinkIndex)
go callback() go callback()
return nil return nil

View File

@ -9,7 +9,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
) )
const ( const (
@ -25,18 +25,18 @@ const (
const interval = 10 * time.Second const interval = 10 * time.Second
func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interface, nexthopv6 netip.Addr, intfv6 *net.Interface, callback func()) error { func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error {
var neighborv4, neighborv6 *routemanager.Neighbor var neighborv4, neighborv6 *systemops.Neighbor
{ {
initialNeighbors, err := getNeighbors() initialNeighbors, err := getNeighbors()
if err != nil { if err != nil {
return fmt.Errorf("get neighbors: %w", err) return fmt.Errorf("get neighbors: %w", err)
} }
if n, ok := initialNeighbors[nexthopv4]; ok { if n, ok := initialNeighbors[nexthopv4.IP]; ok {
neighborv4 = &n neighborv4 = &n
} }
if n, ok := initialNeighbors[nexthopv6]; ok { if n, ok := initialNeighbors[nexthopv6.IP]; ok {
neighborv6 = &n neighborv6 = &n
} }
} }
@ -50,7 +50,7 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
case <-ctx.Done(): case <-ctx.Done():
return ErrStopped return ErrStopped
case <-ticker.C: case <-ticker.C:
if changed(nexthopv4, intfv4, neighborv4, nexthopv6, intfv6, neighborv6) { if changed(nexthopv4, neighborv4, nexthopv6, neighborv6) {
go callback() go callback()
return nil return nil
} }
@ -59,12 +59,10 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
} }
func changed( func changed(
nexthopv4 netip.Addr, nexthopv4 systemops.Nexthop,
intfv4 *net.Interface, neighborv4 *systemops.Neighbor,
neighborv4 *routemanager.Neighbor, nexthopv6 systemops.Nexthop,
nexthopv6 netip.Addr, neighborv6 *systemops.Neighbor,
intfv6 *net.Interface,
neighborv6 *routemanager.Neighbor,
) bool { ) bool {
neighbors, err := getNeighbors() neighbors, err := getNeighbors()
if err != nil { if err != nil {
@ -81,7 +79,7 @@ func changed(
return false return false
} }
if routeChanged(nexthopv4, intfv4, routes) || routeChanged(nexthopv6, intfv6, routes) { if routeChanged(nexthopv4, nexthopv4.Intf, routes) || routeChanged(nexthopv6, nexthopv6.Intf, routes) {
return true return true
} }
@ -89,20 +87,20 @@ func changed(
} }
// routeChanged checks if the default routes still point to our nexthop/interface // routeChanged checks if the default routes still point to our nexthop/interface
func routeChanged(nexthop netip.Addr, intf *net.Interface, routes map[netip.Prefix]routemanager.Route) bool { func routeChanged(nexthop systemops.Nexthop, intf *net.Interface, routes map[netip.Prefix]systemops.Route) bool {
if !nexthop.IsValid() { if !nexthop.IP.IsValid() {
return false return false
} }
var unspec netip.Prefix var unspec netip.Prefix
if nexthop.Is6() { if nexthop.IP.Is6() {
unspec = netip.PrefixFrom(netip.IPv6Unspecified(), 0) unspec = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
} else { } else {
unspec = netip.PrefixFrom(netip.IPv4Unspecified(), 0) unspec = netip.PrefixFrom(netip.IPv4Unspecified(), 0)
} }
if r, ok := routes[unspec]; ok { if r, ok := routes[unspec]; ok {
if r.Nexthop != nexthop || compareIntf(r.Interface, intf) != 0 { if r.Nexthop != nexthop.IP || compareIntf(r.Interface, intf) != 0 {
intf := "<nil>" intf := "<nil>"
if r.Interface != nil { if r.Interface != nil {
intf = r.Interface.Name intf = r.Interface.Name
@ -119,13 +117,13 @@ func routeChanged(nexthop netip.Addr, intf *net.Interface, routes map[netip.Pref
} }
func neighborChanged(nexthop netip.Addr, neighbor *routemanager.Neighbor, neighbors map[netip.Addr]routemanager.Neighbor) bool { func neighborChanged(nexthop systemops.Nexthop, neighbor *systemops.Neighbor, neighbors map[netip.Addr]systemops.Neighbor) bool {
if neighbor == nil { if neighbor == nil {
return false return false
} }
// TODO: consider non-local nexthops, e.g. on point-to-point interfaces // TODO: consider non-local nexthops, e.g. on point-to-point interfaces
if n, ok := neighbors[nexthop]; ok { if n, ok := neighbors[nexthop.IP]; ok {
if n.State != reachable && n.State != permanent { if n.State != reachable && n.State != permanent {
log.Infof("network monitor: neighbor %s (%s) is not reachable: %s", neighbor.IPAddress, neighbor.LinkLayerAddress, stateFromInt(n.State)) log.Infof("network monitor: neighbor %s (%s) is not reachable: %s", neighbor.IPAddress, neighbor.LinkLayerAddress, stateFromInt(n.State))
return true return true
@ -150,13 +148,13 @@ func neighborChanged(nexthop netip.Addr, neighbor *routemanager.Neighbor, neighb
return false return false
} }
func getNeighbors() (map[netip.Addr]routemanager.Neighbor, error) { func getNeighbors() (map[netip.Addr]systemops.Neighbor, error) {
entries, err := routemanager.GetNeighbors() entries, err := systemops.GetNeighbors()
if err != nil { if err != nil {
return nil, fmt.Errorf("get neighbors: %w", err) return nil, fmt.Errorf("get neighbors: %w", err)
} }
neighbours := make(map[netip.Addr]routemanager.Neighbor, len(entries)) neighbours := make(map[netip.Addr]systemops.Neighbor, len(entries))
for _, entry := range entries { for _, entry := range entries {
neighbours[entry.IPAddress] = entry neighbours[entry.IPAddress] = entry
} }
@ -164,13 +162,13 @@ func getNeighbors() (map[netip.Addr]routemanager.Neighbor, error) {
return neighbours, nil return neighbours, nil
} }
func getRoutes() (map[netip.Prefix]routemanager.Route, error) { func getRoutes() (map[netip.Prefix]systemops.Route, error) {
entries, err := routemanager.GetRoutes() entries, err := systemops.GetRoutes()
if err != nil { if err != nil {
return nil, fmt.Errorf("get routes: %w", err) return nil, fmt.Errorf("get routes: %w", err)
} }
routes := make(map[netip.Prefix]routemanager.Route, len(entries)) routes := make(map[netip.Prefix]systemops.Route, len(entries))
for _, entry := range entries { for _, entry := range entries {
routes[entry.Destination] = entry routes[entry.Destination] = entry
} }

View File

@ -2,14 +2,17 @@ package peer
import ( import (
"errors" "errors"
"net/netip"
"sync" "sync"
"time" "time"
"golang.org/x/exp/maps"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status" gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/management/domain"
) )
// State contains the latest state of a peer // State contains the latest state of a peer
@ -37,25 +40,25 @@ type State struct {
// AddRoute add a single route to routes map // AddRoute add a single route to routes map
func (s *State) AddRoute(network string) { func (s *State) AddRoute(network string) {
s.Mux.Lock() s.Mux.Lock()
defer s.Mux.Unlock()
if s.routes == nil { if s.routes == nil {
s.routes = make(map[string]struct{}) s.routes = make(map[string]struct{})
} }
s.routes[network] = struct{}{} s.routes[network] = struct{}{}
s.Mux.Unlock()
} }
// SetRoutes set state routes // SetRoutes set state routes
func (s *State) SetRoutes(routes map[string]struct{}) { func (s *State) SetRoutes(routes map[string]struct{}) {
s.Mux.Lock() s.Mux.Lock()
defer s.Mux.Unlock()
s.routes = routes s.routes = routes
s.Mux.Unlock()
} }
// DeleteRoute removes a route from the network amp // DeleteRoute removes a route from the network amp
func (s *State) DeleteRoute(network string) { func (s *State) DeleteRoute(network string) {
s.Mux.Lock() s.Mux.Lock()
defer s.Mux.Unlock()
delete(s.routes, network) delete(s.routes, network)
s.Mux.Unlock()
} }
// GetRoutes return routes map // GetRoutes return routes map
@ -133,6 +136,7 @@ type Status struct {
rosenpassEnabled bool rosenpassEnabled bool
rosenpassPermissive bool rosenpassPermissive bool
nsGroupStates []NSGroupState nsGroupStates []NSGroupState
resolvedDomainsStates map[domain.Domain][]netip.Prefix
// To reduce the number of notification invocation this bool will be true when need to call the notification // To reduce the number of notification invocation this bool will be true when need to call the notification
// Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events // Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events
@ -148,6 +152,7 @@ func NewRecorder(mgmAddress string) *Status {
offlinePeers: make([]State, 0), offlinePeers: make([]State, 0),
notifier: newNotifier(), notifier: newNotifier(),
mgmAddress: mgmAddress, mgmAddress: mgmAddress,
resolvedDomainsStates: make(map[domain.Domain][]netip.Prefix),
} }
} }
@ -188,7 +193,7 @@ func (d *Status) GetPeer(peerPubKey string) (State, error) {
state, ok := d.peers[peerPubKey] state, ok := d.peers[peerPubKey]
if !ok { if !ok {
return State{}, errors.New("peer not found") return State{}, iface.ErrPeerNotFound
} }
return state, nil return state, nil
} }
@ -429,6 +434,18 @@ func (d *Status) UpdateDNSStates(dnsStates []NSGroupState) {
d.nsGroupStates = dnsStates d.nsGroupStates = dnsStates
} }
func (d *Status) UpdateResolvedDomainsStates(domain domain.Domain, prefixes []netip.Prefix) {
d.mux.Lock()
defer d.mux.Unlock()
d.resolvedDomainsStates[domain] = prefixes
}
func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
d.mux.Lock()
defer d.mux.Unlock()
delete(d.resolvedDomainsStates, domain)
}
func (d *Status) GetRosenpassState() RosenpassState { func (d *Status) GetRosenpassState() RosenpassState {
return RosenpassState{ return RosenpassState{
d.rosenpassEnabled, d.rosenpassEnabled,
@ -493,6 +510,12 @@ func (d *Status) GetDNSStates() []NSGroupState {
return d.nsGroupStates return d.nsGroupStates
} }
func (d *Status) GetResolvedDomainsStates() map[domain.Domain][]netip.Prefix {
d.mux.Lock()
defer d.mux.Unlock()
return maps.Clone(d.resolvedDomainsStates)
}
// GetFullStatus gets full status // GetFullStatus gets full status
func (d *Status) GetFullStatus() FullStatus { func (d *Status) GetFullStatus() FullStatus {
d.mux.Lock() d.mux.Lock()

View File

@ -3,19 +3,20 @@ package routemanager
import ( import (
"context" "context"
"fmt" "fmt"
"net"
"net/netip"
"time" "time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/static"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
const minRangeBits = 7
type routerPeerStatus struct { type routerPeerStatus struct {
connected bool connected bool
relayed bool relayed bool
@ -28,33 +29,42 @@ type routesUpdate struct {
routes []*route.Route routes []*route.Route
} }
// RouteHandler defines the interface for handling routes
type RouteHandler interface {
String() string
AddRoute(ctx context.Context) error
RemoveRoute() error
AddAllowedIPs(peerKey string) error
RemoveAllowedIPs() error
}
type clientNetwork struct { type clientNetwork struct {
ctx context.Context ctx context.Context
stop context.CancelFunc cancel context.CancelFunc
statusRecorder *peer.Status statusRecorder *peer.Status
wgInterface *iface.WGIface wgInterface *iface.WGIface
routes map[route.ID]*route.Route routes map[route.ID]*route.Route
routeUpdate chan routesUpdate routeUpdate chan routesUpdate
peerStateUpdate chan struct{} peerStateUpdate chan struct{}
routePeersNotifiers map[string]chan struct{} routePeersNotifiers map[string]chan struct{}
chosenRoute *route.Route currentChosen *route.Route
network netip.Prefix handler RouteHandler
updateSerial uint64 updateSerial uint64
} }
func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, statusRecorder *peer.Status, network netip.Prefix) *clientNetwork { func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration, wgInterface *iface.WGIface, statusRecorder *peer.Status, rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *clientNetwork {
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
client := &clientNetwork{ client := &clientNetwork{
ctx: ctx, ctx: ctx,
stop: cancel, cancel: cancel,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
wgInterface: wgInterface, wgInterface: wgInterface,
routes: make(map[route.ID]*route.Route), routes: make(map[route.ID]*route.Route),
routePeersNotifiers: make(map[string]chan struct{}), routePeersNotifiers: make(map[string]chan struct{}),
routeUpdate: make(chan routesUpdate), routeUpdate: make(chan routesUpdate),
peerStateUpdate: make(chan struct{}), peerStateUpdate: make(chan struct{}),
network: network, handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder),
} }
return client return client
} }
@ -86,8 +96,8 @@ func (c *clientNetwork) getRouterPeerStatuses() map[route.ID]routerPeerStatus {
// * Metric: Routes with lower metrics (better) are prioritized. // * Metric: Routes with lower metrics (better) are prioritized.
// * Non-relayed: Routes without relays are preferred. // * Non-relayed: Routes without relays are preferred.
// * Direct connections: Routes with direct peer connections are favored. // * Direct connections: Routes with direct peer connections are favored.
// * Stability: In case of equal scores, the currently active route (if any) is maintained.
// * Latency: Routes with lower latency are prioritized. // * Latency: Routes with lower latency are prioritized.
// * Stability: In case of equal scores, the currently active route (if any) is maintained.
// //
// It returns the ID of the selected optimal route. // It returns the ID of the selected optimal route.
func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]routerPeerStatus) route.ID { func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]routerPeerStatus) route.ID {
@ -96,8 +106,8 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
currScore := float64(0) currScore := float64(0)
currID := route.ID("") currID := route.ID("")
if c.chosenRoute != nil { if c.currentChosen != nil {
currID = c.chosenRoute.ID currID = c.currentChosen.ID
} }
for _, r := range c.routes { for _, r := range c.routes {
@ -151,18 +161,18 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
peers = append(peers, r.Peer) peers = append(peers, r.Peer)
} }
log.Warnf("the network %s has not been assigned a routing peer as no peers from the list %s are currently connected", c.network, peers) log.Warnf("The network [%v] has not been assigned a routing peer as no peers from the list %s are currently connected", c.handler, peers)
case chosen != currID: case chosen != currID:
// we compare the current score + 10ms to the chosen score to avoid flapping between routes // we compare the current score + 10ms to the chosen score to avoid flapping between routes
if currScore != 0 && currScore+0.01 > chosenScore { if currScore != 0 && currScore+0.01 > chosenScore {
log.Debugf("keeping current routing peer because the score difference with latency is less than 0.01(10ms), current: %f, new: %f", currScore, chosenScore) log.Debugf("Keeping current routing peer because the score difference with latency is less than 0.01(10ms), current: %f, new: %f", currScore, chosenScore)
return currID return currID
} }
var p string var p string
if rt := c.routes[chosen]; rt != nil { if rt := c.routes[chosen]; rt != nil {
p = rt.Peer p = rt.Peer
} }
log.Infof("new chosen route is %s with peer %s with score %f for network %s", chosen, p, chosenScore, c.network) log.Infof("New chosen route is %s with peer %s with score %f for network [%v]", chosen, p, chosenScore, c.handler)
} }
return chosen return chosen
@ -196,96 +206,101 @@ func (c *clientNetwork) startPeersStatusChangeWatcher() {
} }
} }
func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error { func (c *clientNetwork) removeRouteFromWireguardPeer() error {
state, err := c.statusRecorder.GetPeer(peerKey) c.removeStateRoute()
if err != nil {
return fmt.Errorf("get peer state: %v", err)
}
state.DeleteRoute(c.network.String()) if err := c.handler.RemoveAllowedIPs(); err != nil {
if err := c.statusRecorder.UpdatePeerState(state); err != nil { return fmt.Errorf("remove allowed IPs: %w", err)
log.Warnf("Failed to update peer state: %v", err)
}
if state.ConnStatus != peer.StatusConnected {
return nil
}
err = c.wgInterface.RemoveAllowedIP(peerKey, c.network.String())
if err != nil {
return fmt.Errorf("remove allowed IP %s removed for peer %s, err: %v",
c.network, c.chosenRoute.Peer, err)
} }
return nil return nil
} }
func (c *clientNetwork) removeRouteFromPeerAndSystem() error { func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
if c.chosenRoute != nil { if c.currentChosen == nil {
if err := removeVPNRoute(c.network, c.getAsInterface()); err != nil { return nil
return fmt.Errorf("remove route %s from system, err: %v", c.network, err)
} }
if err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer); err != nil { var merr *multierror.Error
return fmt.Errorf("remove route: %v", err)
if err := c.removeRouteFromWireguardPeer(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err))
} }
if err := c.handler.RemoveRoute(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove route: %w", err))
} }
return nil
return nberrors.FormatErrorOrNil(merr)
} }
func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
routerPeerStatuses := c.getRouterPeerStatuses() routerPeerStatuses := c.getRouterPeerStatuses()
chosen := c.getBestRouteFromStatuses(routerPeerStatuses) newChosenID := c.getBestRouteFromStatuses(routerPeerStatuses)
// If no route is chosen, remove the route from the peer and system // If no route is chosen, remove the route from the peer and system
if chosen == "" { if newChosenID == "" {
if err := c.removeRouteFromPeerAndSystem(); err != nil { if err := c.removeRouteFromPeerAndSystem(); err != nil {
return fmt.Errorf("remove route from peer and system: %v", err) return fmt.Errorf("remove route for peer %s: %w", c.currentChosen.Peer, err)
} }
c.chosenRoute = nil c.currentChosen = nil
return nil return nil
} }
// If the chosen route is the same as the current route, do nothing // If the chosen route is the same as the current route, do nothing
if c.chosenRoute != nil && c.chosenRoute.ID == chosen { if c.currentChosen != nil && c.currentChosen.ID == newChosenID &&
if c.chosenRoute.IsEqual(c.routes[chosen]) { c.currentChosen.IsEqual(c.routes[newChosenID]) {
return nil return nil
} }
}
if c.chosenRoute != nil { if c.currentChosen == nil {
// If a previous route exists, remove it from the peer // If they were not previously assigned to another peer, add routes to the system first
if err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer); err != nil { if err := c.handler.AddRoute(c.ctx); err != nil {
return fmt.Errorf("remove route from peer: %v", err) return fmt.Errorf("add route: %w", err)
} }
} else { } else {
// otherwise add the route to the system // Otherwise, remove the allowed IPs from the previous peer first
if err := addVPNRoute(c.network, c.getAsInterface()); err != nil { if err := c.removeRouteFromWireguardPeer(); err != nil {
return fmt.Errorf("route %s couldn't be added for peer %s, err: %v", return fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
c.network.String(), c.wgInterface.Address().IP.String(), err)
} }
} }
c.chosenRoute = c.routes[chosen] c.currentChosen = c.routes[newChosenID]
state, err := c.statusRecorder.GetPeer(c.chosenRoute.Peer) if err := c.handler.AddAllowedIPs(c.currentChosen.Peer); err != nil {
return fmt.Errorf("add allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
}
c.addStateRoute()
return nil
}
func (c *clientNetwork) addStateRoute() {
state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer)
if err != nil { if err != nil {
log.Errorf("Failed to get peer state: %v", err) log.Errorf("Failed to get peer state: %v", err)
} else { return
state.AddRoute(c.network.String()) }
state.AddRoute(c.handler.String())
if err := c.statusRecorder.UpdatePeerState(state); err != nil { if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err) log.Warnf("Failed to update peer state: %v", err)
} }
}
func (c *clientNetwork) removeStateRoute() {
state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer)
if err != nil {
log.Errorf("Failed to get peer state: %v", err)
return
} }
if err := c.wgInterface.AddAllowedIP(c.chosenRoute.Peer, c.network.String()); err != nil { state.DeleteRoute(c.handler.String())
log.Errorf("couldn't add allowed IP %s added for peer %s, err: %v", if err := c.statusRecorder.UpdatePeerState(state); err != nil {
c.network, c.chosenRoute.Peer, err) log.Warnf("Failed to update peer state: %v", err)
} }
return nil
} }
func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) { func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) {
@ -318,24 +333,23 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
for { for {
select { select {
case <-c.ctx.Done(): case <-c.ctx.Done():
log.Debugf("stopping watcher for network %s", c.network) log.Debugf("Stopping watcher for network [%v]", c.handler)
err := c.removeRouteFromPeerAndSystem() if err := c.removeRouteFromPeerAndSystem(); err != nil {
if err != nil { log.Errorf("Failed to remove routes for [%v]: %v", c.handler, err)
log.Errorf("Couldn't remove route from peer and system for network %s: %v", c.network, err)
} }
return return
case <-c.peerStateUpdate: case <-c.peerStateUpdate:
err := c.recalculateRouteAndUpdatePeerAndSystem() err := c.recalculateRouteAndUpdatePeerAndSystem()
if err != nil { if err != nil {
log.Errorf("Couldn't recalculate route and update peer and system: %v", err) log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err)
} }
case update := <-c.routeUpdate: case update := <-c.routeUpdate:
if update.updateSerial < c.updateSerial { if update.updateSerial < c.updateSerial {
log.Warnf("Received a routes update with smaller serial number, ignoring it") log.Warnf("Received a routes update with smaller serial number (%d -> %d), ignoring it", c.updateSerial, update.updateSerial)
continue continue
} }
log.Debugf("Received a new client network route update for %s", c.network) log.Debugf("Received a new client network route update for [%v]", c.handler)
c.handleUpdate(update) c.handleUpdate(update)
@ -343,7 +357,7 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
err := c.recalculateRouteAndUpdatePeerAndSystem() err := c.recalculateRouteAndUpdatePeerAndSystem()
if err != nil { if err != nil {
log.Errorf("Couldn't recalculate route and update peer and system for network %s: %v", c.network, err) log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err)
} }
c.startPeersStatusChangeWatcher() c.startPeersStatusChangeWatcher()
@ -351,14 +365,9 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
} }
} }
func (c *clientNetwork) getAsInterface() *net.Interface { func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status) RouteHandler {
intf, err := net.InterfaceByName(c.wgInterface.Name()) if rt.IsDynamic() {
if err != nil { return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder)
log.Warnf("Couldn't get interface by name %s: %v", c.wgInterface.Name(), err)
intf = &net.Interface{
Name: c.wgInterface.Name(),
} }
} return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter)
return intf
} }

View File

@ -5,6 +5,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/netbirdio/netbird/client/internal/routemanager/static"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@ -340,9 +341,9 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
// create new clientNetwork // create new clientNetwork
client := &clientNetwork{ client := &clientNetwork{
network: netip.MustParsePrefix("192.168.0.0/24"), handler: static.NewRoute(&route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")}, nil, nil),
routes: tc.existingRoutes, routes: tc.existingRoutes,
chosenRoute: currentRoute, currentChosen: currentRoute,
} }
chosenRoute := client.getBestRouteFromStatuses(tc.statuses) chosenRoute := client.getBestRouteFromStatuses(tc.statuses)

View File

@ -0,0 +1,361 @@
package dynamic
import (
"context"
"fmt"
"net"
"net/netip"
"strings"
"sync"
"time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route"
)
const (
DefaultInterval = time.Minute
minInterval = 2 * time.Second
addAllowedIP = "add allowed IP %s: %w"
)
type domainMap map[domain.Domain][]netip.Prefix
type resolveResult struct {
domain domain.Domain
prefix netip.Prefix
err error
}
type Route struct {
route *route.Route
routeRefCounter *refcounter.RouteRefCounter
allowedIPsRefcounter *refcounter.AllowedIPsRefCounter
interval time.Duration
dynamicDomains domainMap
mu sync.Mutex
currentPeerKey string
cancel context.CancelFunc
statusRecorder *peer.Status
}
func NewRoute(
rt *route.Route,
routeRefCounter *refcounter.RouteRefCounter,
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
interval time.Duration,
statusRecorder *peer.Status,
) *Route {
return &Route{
route: rt,
routeRefCounter: routeRefCounter,
allowedIPsRefcounter: allowedIPsRefCounter,
interval: interval,
dynamicDomains: domainMap{},
statusRecorder: statusRecorder,
}
}
func (r *Route) String() string {
s, err := r.route.Domains.String()
if err != nil {
return r.route.Domains.PunycodeString()
}
return s
}
func (r *Route) AddRoute(ctx context.Context) error {
r.mu.Lock()
defer r.mu.Unlock()
if r.cancel != nil {
r.cancel()
}
ctx, r.cancel = context.WithCancel(ctx)
go r.startResolver(ctx)
return nil
}
// RemoveRoute will stop the dynamic resolver and remove all dynamic routes.
// It doesn't touch allowed IPs, these should be removed separately and before calling this method.
func (r *Route) RemoveRoute() error {
r.mu.Lock()
defer r.mu.Unlock()
if r.cancel != nil {
r.cancel()
}
var merr *multierror.Error
for domain, prefixes := range r.dynamicDomains {
for _, prefix := range prefixes {
if _, err := r.routeRefCounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %w", prefix, err))
}
}
log.Debugf("Removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", "))
r.statusRecorder.DeleteResolvedDomainsStates(domain)
}
r.dynamicDomains = domainMap{}
return nberrors.FormatErrorOrNil(merr)
}
func (r *Route) AddAllowedIPs(peerKey string) error {
r.mu.Lock()
defer r.mu.Unlock()
var merr *multierror.Error
for domain, domainPrefixes := range r.dynamicDomains {
for _, prefix := range domainPrefixes {
if err := r.incrementAllowedIP(domain, prefix, peerKey); err != nil {
merr = multierror.Append(merr, fmt.Errorf(addAllowedIP, prefix, err))
}
}
}
r.currentPeerKey = peerKey
return nberrors.FormatErrorOrNil(merr)
}
func (r *Route) RemoveAllowedIPs() error {
r.mu.Lock()
defer r.mu.Unlock()
var merr *multierror.Error
for _, domainPrefixes := range r.dynamicDomains {
for _, prefix := range domainPrefixes {
if _, err := r.allowedIPsRefcounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %w", prefix, err))
}
}
}
r.currentPeerKey = ""
return nberrors.FormatErrorOrNil(merr)
}
func (r *Route) startResolver(ctx context.Context) {
log.Debugf("Starting dynamic route resolver for domains [%v]", r)
interval := r.interval
if interval < minInterval {
interval = minInterval
log.Warnf("Dynamic route resolver interval %s is too low, setting to minimum value %s", r.interval, minInterval)
}
ticker := time.NewTicker(interval)
defer ticker.Stop()
r.update(ctx)
for {
select {
case <-ctx.Done():
log.Debugf("Stopping dynamic route resolver for domains [%v]", r)
return
case <-ticker.C:
r.update(ctx)
}
}
}
func (r *Route) update(ctx context.Context) {
if resolved, err := r.resolveDomains(); err != nil {
log.Errorf("Failed to resolve domains for route [%v]: %v", r, err)
} else if err := r.updateDynamicRoutes(ctx, resolved); err != nil {
log.Errorf("Failed to update dynamic routes for [%v]: %v", r, err)
}
}
func (r *Route) resolveDomains() (domainMap, error) {
results := make(chan resolveResult)
go r.resolve(results)
resolved := domainMap{}
var merr *multierror.Error
for result := range results {
if result.err != nil {
merr = multierror.Append(merr, result.err)
} else {
resolved[result.domain] = append(resolved[result.domain], result.prefix)
}
}
return resolved, nberrors.FormatErrorOrNil(merr)
}
func (r *Route) resolve(results chan resolveResult) {
var wg sync.WaitGroup
for _, d := range r.route.Domains {
wg.Add(1)
go func(domain domain.Domain) {
defer wg.Done()
ips, err := net.LookupIP(string(domain))
if err != nil {
results <- resolveResult{domain: domain, err: fmt.Errorf("resolve d %s: %w", domain.SafeString(), err)}
return
}
for _, ip := range ips {
prefix, err := util.GetPrefixFromIP(ip)
if err != nil {
results <- resolveResult{domain: domain, err: fmt.Errorf("get prefix from IP %s: %w", ip.String(), err)}
return
}
results <- resolveResult{domain: domain, prefix: prefix}
}
}(d)
}
wg.Wait()
close(results)
}
func (r *Route) updateDynamicRoutes(ctx context.Context, newDomains domainMap) error {
r.mu.Lock()
defer r.mu.Unlock()
if ctx.Err() != nil {
return ctx.Err()
}
var merr *multierror.Error
for domain, newPrefixes := range newDomains {
oldPrefixes := r.dynamicDomains[domain]
toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes)
addedPrefixes, err := r.addRoutes(domain, toAdd)
if err != nil {
merr = multierror.Append(merr, err)
} else if len(addedPrefixes) > 0 {
log.Debugf("Added dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", addedPrefixes), " ", ", "))
}
removedPrefixes, err := r.removeRoutes(toRemove)
if err != nil {
merr = multierror.Append(merr, err)
} else if len(removedPrefixes) > 0 {
log.Debugf("Removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", removedPrefixes), " ", ", "))
}
updatedPrefixes := combinePrefixes(oldPrefixes, removedPrefixes, addedPrefixes)
r.dynamicDomains[domain] = updatedPrefixes
r.statusRecorder.UpdateResolvedDomainsStates(domain, updatedPrefixes)
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *Route) addRoutes(domain domain.Domain, prefixes []netip.Prefix) ([]netip.Prefix, error) {
var addedPrefixes []netip.Prefix
var merr *multierror.Error
for _, prefix := range prefixes {
if _, err := r.routeRefCounter.Increment(prefix, nil); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add dynamic route for IP %s: %w", prefix, err))
continue
}
if r.currentPeerKey != "" {
if err := r.incrementAllowedIP(domain, prefix, r.currentPeerKey); err != nil {
merr = multierror.Append(merr, fmt.Errorf(addAllowedIP, prefix, err))
}
}
addedPrefixes = append(addedPrefixes, prefix)
}
return addedPrefixes, merr.ErrorOrNil()
}
func (r *Route) removeRoutes(prefixes []netip.Prefix) ([]netip.Prefix, error) {
if r.route.KeepRoute {
return nil, nil
}
var removedPrefixes []netip.Prefix
var merr *multierror.Error
for _, prefix := range prefixes {
if _, err := r.routeRefCounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %w", prefix, err))
}
if r.currentPeerKey != "" {
if _, err := r.allowedIPsRefcounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %w", prefix, err))
}
}
removedPrefixes = append(removedPrefixes, prefix)
}
return removedPrefixes, merr.ErrorOrNil()
}
func (r *Route) incrementAllowedIP(domain domain.Domain, prefix netip.Prefix, peerKey string) error {
if ref, err := r.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil {
return fmt.Errorf(addAllowedIP, prefix, err)
} else if ref.Count > 1 && ref.Out != peerKey {
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
prefix.Addr(),
domain.SafeString(),
ref.Out,
)
}
return nil
}
func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toRemove []netip.Prefix) {
prefixSet := make(map[netip.Prefix]bool)
for _, prefix := range oldPrefixes {
prefixSet[prefix] = false
}
for _, prefix := range newPrefixes {
if _, exists := prefixSet[prefix]; exists {
prefixSet[prefix] = true
} else {
toAdd = append(toAdd, prefix)
}
}
for prefix, inUse := range prefixSet {
if !inUse {
toRemove = append(toRemove, prefix)
}
}
return
}
func combinePrefixes(oldPrefixes, removedPrefixes, addedPrefixes []netip.Prefix) []netip.Prefix {
prefixSet := make(map[netip.Prefix]struct{})
for _, prefix := range oldPrefixes {
prefixSet[prefix] = struct{}{}
}
for _, prefix := range removedPrefixes {
delete(prefixSet, prefix)
}
for _, prefix := range addedPrefixes {
prefixSet[prefix] = struct{}{}
}
var combinedPrefixes []netip.Prefix
for prefix := range prefixSet {
combinedPrefixes = append(combinedPrefixes, prefix)
}
return combinedPrefixes
}

View File

@ -2,18 +2,23 @@ package routemanager
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"net/url" "net/url"
"runtime" "runtime"
"sync" "sync"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
"github.com/netbirdio/netbird/client/internal/routeselector" "github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
@ -21,11 +26,6 @@ import (
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )
var defaultv4 = netip.PrefixFrom(netip.IPv4Unspecified(), 0)
// nolint:unused
var defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
// Manager is a route manager interface // Manager is a route manager interface
type Manager interface { type Manager interface {
Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error)
@ -46,25 +46,65 @@ type DefaultManager struct {
clientNetworks map[route.HAUniqueID]*clientNetwork clientNetworks map[route.HAUniqueID]*clientNetwork
routeSelector *routeselector.RouteSelector routeSelector *routeselector.RouteSelector
serverRouter serverRouter serverRouter serverRouter
sysOps *systemops.SysOps
statusRecorder *peer.Status statusRecorder *peer.Status
wgInterface *iface.WGIface wgInterface *iface.WGIface
pubKey string pubKey string
notifier *notifier notifier *notifier
routeRefCounter *refcounter.RouteRefCounter
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter
dnsRouteInterval time.Duration
} }
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status, initialRoutes []*route.Route) *DefaultManager { func NewManager(
ctx context.Context,
pubKey string,
dnsRouteInterval time.Duration,
wgInterface *iface.WGIface,
statusRecorder *peer.Status,
initialRoutes []*route.Route,
) *DefaultManager {
mCTX, cancel := context.WithCancel(ctx) mCTX, cancel := context.WithCancel(ctx)
sysOps := systemops.NewSysOps(wgInterface)
dm := &DefaultManager{ dm := &DefaultManager{
ctx: mCTX, ctx: mCTX,
stop: cancel, stop: cancel,
dnsRouteInterval: dnsRouteInterval,
clientNetworks: make(map[route.HAUniqueID]*clientNetwork), clientNetworks: make(map[route.HAUniqueID]*clientNetwork),
routeSelector: routeselector.NewRouteSelector(), routeSelector: routeselector.NewRouteSelector(),
sysOps: sysOps,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
wgInterface: wgInterface, wgInterface: wgInterface,
pubKey: pubKey, pubKey: pubKey,
notifier: newNotifier(), notifier: newNotifier(),
} }
dm.routeRefCounter = refcounter.New(
func(prefix netip.Prefix, _ any) (any, error) {
return nil, sysOps.AddVPNRoute(prefix, wgInterface.ToInterface())
},
func(prefix netip.Prefix, _ any) error {
return sysOps.RemoveVPNRoute(prefix, wgInterface.ToInterface())
},
)
dm.allowedIPsRefCounter = refcounter.New(
func(prefix netip.Prefix, peerKey string) (string, error) {
// save peerKey to use it in the remove function
return peerKey, wgInterface.AddAllowedIP(peerKey, prefix.String())
},
func(prefix netip.Prefix, peerKey string) error {
if err := wgInterface.RemoveAllowedIP(peerKey, prefix.String()); err != nil {
if !errors.Is(err, iface.ErrPeerNotFound) && !errors.Is(err, iface.ErrAllowedIPNotFound) {
return err
}
log.Tracef("Remove allowed IPs %s for %s: %v", prefix, peerKey, err)
}
return nil
},
)
if runtime.GOOS == "android" { if runtime.GOOS == "android" {
cr := dm.clientRoutes(initialRoutes) cr := dm.clientRoutes(initialRoutes)
dm.notifier.setInitialClientRoutes(cr) dm.notifier.setInitialClientRoutes(cr)
@ -78,7 +118,7 @@ func (m *DefaultManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePee
return nil, nil, nil return nil, nil, nil
} }
if err := cleanupRouting(); err != nil { if err := m.sysOps.CleanupRouting(); err != nil {
log.Warnf("Failed cleaning up routing: %v", err) log.Warnf("Failed cleaning up routing: %v", err)
} }
@ -86,7 +126,7 @@ func (m *DefaultManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePee
signalAddress := m.statusRecorder.GetSignalState().URL signalAddress := m.statusRecorder.GetSignalState().URL
ips := resolveURLsToIPs([]string{mgmtAddress, signalAddress}) ips := resolveURLsToIPs([]string{mgmtAddress, signalAddress})
beforePeerHook, afterPeerHook, err := setupRouting(ips, m.wgInterface) beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("setup routing: %w", err) return nil, nil, fmt.Errorf("setup routing: %w", err)
} }
@ -110,8 +150,19 @@ func (m *DefaultManager) Stop() {
m.serverRouter.cleanUp() m.serverRouter.cleanUp()
} }
if m.routeRefCounter != nil {
if err := m.routeRefCounter.Flush(); err != nil {
log.Errorf("Error flushing route ref counter: %v", err)
}
}
if m.allowedIPsRefCounter != nil {
if err := m.allowedIPsRefCounter.Flush(); err != nil {
log.Errorf("Error flushing allowed IPs ref counter: %v", err)
}
}
if !nbnet.CustomRoutingDisabled() { if !nbnet.CustomRoutingDisabled() {
if err := cleanupRouting(); err != nil { if err := m.sysOps.CleanupRouting(); err != nil {
log.Errorf("Error cleaning up routing: %v", err) log.Errorf("Error cleaning up routing: %v", err)
} else { } else {
log.Info("Routing cleanup complete") log.Info("Routing cleanup complete")
@ -185,7 +236,7 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
continue continue
} }
clientNetworkWatcher := newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network) clientNetworkWatcher := newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter)
m.clientNetworks[id] = clientNetworkWatcher m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher() go clientNetworkWatcher.peersStateAndUpdateWatcher()
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes}) clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes})
@ -197,7 +248,7 @@ func (m *DefaultManager) stopObsoleteClients(networks route.HAMap) {
for id, client := range m.clientNetworks { for id, client := range m.clientNetworks {
if _, ok := networks[id]; !ok { if _, ok := networks[id]; !ok {
log.Debugf("Stopping client network watcher, %s", id) log.Debugf("Stopping client network watcher, %s", id)
client.stop() client.cancel()
delete(m.clientNetworks, id) delete(m.clientNetworks, id)
} }
} }
@ -210,7 +261,7 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout
for id, routes := range networks { for id, routes := range networks {
clientNetworkWatcher, found := m.clientNetworks[id] clientNetworkWatcher, found := m.clientNetworks[id]
if !found { if !found {
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network) clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter)
m.clientNetworks[id] = clientNetworkWatcher m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher() go clientNetworkWatcher.peersStateAndUpdateWatcher()
} }
@ -228,7 +279,7 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]
ownNetworkIDs := make(map[route.HAUniqueID]bool) ownNetworkIDs := make(map[route.HAUniqueID]bool)
for _, newRoute := range newRoutes { for _, newRoute := range newRoutes {
haID := route.GetHAUniqueID(newRoute) haID := newRoute.GetHAUniqueID()
if newRoute.Peer == m.pubKey { if newRoute.Peer == m.pubKey {
ownNetworkIDs[haID] = true ownNetworkIDs[haID] = true
// only linux is supported for now // only linux is supported for now
@ -241,9 +292,9 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]
} }
for _, newRoute := range newRoutes { for _, newRoute := range newRoutes {
haID := route.GetHAUniqueID(newRoute) haID := newRoute.GetHAUniqueID()
if !ownNetworkIDs[haID] { if !ownNetworkIDs[haID] {
if !isPrefixSupported(newRoute.Network) { if !isRouteSupported(newRoute) {
continue continue
} }
newClientRoutesIDMap[haID] = append(newClientRoutesIDMap[haID], newRoute) newClientRoutesIDMap[haID] = append(newClientRoutesIDMap[haID], newRoute)
@ -255,23 +306,23 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]
func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Route { func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Route {
_, crMap := m.classifyRoutes(initialRoutes) _, crMap := m.classifyRoutes(initialRoutes)
rs := make([]*route.Route, 0) rs := make([]*route.Route, len(crMap))
for _, routes := range crMap { for _, routes := range crMap {
rs = append(rs, routes...) rs = append(rs, routes...)
} }
return rs return rs
} }
func isPrefixSupported(prefix netip.Prefix) bool { func isRouteSupported(route *route.Route) bool {
if !nbnet.CustomRoutingDisabled() { if !nbnet.CustomRoutingDisabled() || route.IsDynamic() {
return true return true
} }
// If prefix is too small, lets assume it is a possible default prefix which is not yet supported // If prefix is too small, lets assume it is a possible default prefix which is not yet supported
// we skip this prefix management // we skip this prefix management
if prefix.Bits() <= minRangeBits { if route.Network.Bits() <= vars.MinRangeBits {
log.Warnf("This agent version: %s, doesn't support default routes, received %s, skipping this prefix", log.Warnf("This agent version: %s, doesn't support default routes, received %s, skipping this prefix",
version.NetbirdVersion(), prefix) version.NetbirdVersion(), route.Network)
return false return false
} }
return true return true

View File

@ -416,7 +416,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
statusRecorder := peer.NewRecorder("https://mgm") statusRecorder := peer.NewRecorder("https://mgm")
ctx := context.TODO() ctx := context.TODO()
routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil) routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil)
_, _, err = routeManager.Init() _, _, err = routeManager.Init()

View File

@ -0,0 +1,155 @@
package refcounter
import (
"errors"
"fmt"
"net/netip"
"sync"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
)
// ErrIgnore can be returned by AddFunc to indicate that the counter not be incremented for the given prefix.
var ErrIgnore = errors.New("ignore")
type Ref[O any] struct {
Count int
Out O
}
type AddFunc[I, O any] func(prefix netip.Prefix, in I) (out O, err error)
type RemoveFunc[I, O any] func(prefix netip.Prefix, out O) error
type Counter[I, O any] struct {
// refCountMap keeps track of the reference Ref for prefixes
refCountMap map[netip.Prefix]Ref[O]
refCountMu sync.Mutex
// idMap keeps track of the prefixes associated with an ID for removal
idMap map[string][]netip.Prefix
idMu sync.Mutex
add AddFunc[I, O]
remove RemoveFunc[I, O]
}
// New creates a new Counter instance
func New[I, O any](add AddFunc[I, O], remove RemoveFunc[I, O]) *Counter[I, O] {
return &Counter[I, O]{
refCountMap: map[netip.Prefix]Ref[O]{},
idMap: map[string][]netip.Prefix{},
add: add,
remove: remove,
}
}
// Increment increments the reference count for the given prefix.
// If this is the first reference to the prefix, the AddFunc is called.
func (rm *Counter[I, O]) Increment(prefix netip.Prefix, in I) (Ref[O], error) {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
ref := rm.refCountMap[prefix]
log.Tracef("Increasing ref count %d for prefix %s with [%v]", ref.Count, prefix, ref.Out)
// Call AddFunc only if it's a new prefix
if ref.Count == 0 {
log.Tracef("Adding for prefix %s with [%v]", prefix, ref.Out)
out, err := rm.add(prefix, in)
if errors.Is(err, ErrIgnore) {
return ref, nil
}
if err != nil {
return ref, fmt.Errorf("failed to add for prefix %s: %w", prefix, err)
}
ref.Out = out
}
ref.Count++
rm.refCountMap[prefix] = ref
return ref, nil
}
// IncrementWithID increments the reference count for the given prefix and groups it under the given ID.
// If this is the first reference to the prefix, the AddFunc is called.
func (rm *Counter[I, O]) IncrementWithID(id string, prefix netip.Prefix, in I) (Ref[O], error) {
rm.idMu.Lock()
defer rm.idMu.Unlock()
ref, err := rm.Increment(prefix, in)
if err != nil {
return ref, fmt.Errorf("with ID: %w", err)
}
rm.idMap[id] = append(rm.idMap[id], prefix)
return ref, nil
}
// Decrement decrements the reference count for the given prefix.
// If the reference count reaches 0, the RemoveFunc is called.
func (rm *Counter[I, O]) Decrement(prefix netip.Prefix) (Ref[O], error) {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
ref, ok := rm.refCountMap[prefix]
if !ok {
log.Tracef("No reference found for prefix %s", prefix)
return ref, nil
}
log.Tracef("Decreasing ref count %d for prefix %s with [%v]", ref.Count, prefix, ref.Out)
if ref.Count == 1 {
log.Tracef("Removing for prefix %s with [%v]", prefix, ref.Out)
if err := rm.remove(prefix, ref.Out); err != nil {
return ref, fmt.Errorf("remove for prefix %s: %w", prefix, err)
}
delete(rm.refCountMap, prefix)
} else {
ref.Count--
rm.refCountMap[prefix] = ref
}
return ref, nil
}
// DecrementWithID decrements the reference count for all prefixes associated with the given ID.
// If the reference count reaches 0, the RemoveFunc is called.
func (rm *Counter[I, O]) DecrementWithID(id string) error {
rm.idMu.Lock()
defer rm.idMu.Unlock()
var merr *multierror.Error
for _, prefix := range rm.idMap[id] {
if _, err := rm.Decrement(prefix); err != nil {
merr = multierror.Append(merr, err)
}
}
delete(rm.idMap, id)
return nberrors.FormatErrorOrNil(merr)
}
// Flush removes all references and calls RemoveFunc for each prefix.
func (rm *Counter[I, O]) Flush() error {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
var merr *multierror.Error
for prefix := range rm.refCountMap {
log.Tracef("Removing for prefix %s", prefix)
ref := rm.refCountMap[prefix]
if err := rm.remove(prefix, ref.Out); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove for prefix %s: %w", prefix, err))
}
}
rm.refCountMap = map[netip.Prefix]Ref[O]{}
rm.idMap = map[string][]netip.Prefix{}
return nberrors.FormatErrorOrNil(merr)
}

View File

@ -0,0 +1,7 @@
package refcounter
// RouteRefCounter is a Counter for Route, it doesn't take any input on Increment and doesn't use any output on Decrement
type RouteRefCounter = Counter[any, any]
// AllowedIPsRefCounter is a Counter for AllowedIPs, it takes a peer key on Increment and passes it back to Decrement
type AllowedIPsRefCounter = Counter[string, string]

View File

@ -1,127 +0,0 @@
//go:build !android && !ios
package routemanager
import (
"errors"
"fmt"
"net"
"net/netip"
"sync"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nbnet "github.com/netbirdio/netbird/util/net"
)
type ref struct {
count int
nexthop netip.Addr
intf *net.Interface
}
type RouteManager struct {
// refCountMap keeps track of the reference ref for prefixes
refCountMap map[netip.Prefix]ref
// prefixMap keeps track of the prefixes associated with a connection ID for removal
prefixMap map[nbnet.ConnectionID][]netip.Prefix
addRoute AddRouteFunc
removeRoute RemoveRouteFunc
mutex sync.Mutex
}
type AddRouteFunc func(prefix netip.Prefix) (nexthop netip.Addr, intf *net.Interface, err error)
type RemoveRouteFunc func(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error
func NewRouteManager(addRoute AddRouteFunc, removeRoute RemoveRouteFunc) *RouteManager {
// TODO: read initial routing table into refCountMap
return &RouteManager{
refCountMap: map[netip.Prefix]ref{},
prefixMap: map[nbnet.ConnectionID][]netip.Prefix{},
addRoute: addRoute,
removeRoute: removeRoute,
}
}
func (rm *RouteManager) AddRouteRef(connID nbnet.ConnectionID, prefix netip.Prefix) error {
rm.mutex.Lock()
defer rm.mutex.Unlock()
ref := rm.refCountMap[prefix]
log.Debugf("Increasing route ref count %d for prefix %s", ref.count, prefix)
// Add route to the system, only if it's a new prefix
if ref.count == 0 {
log.Debugf("Adding route for prefix %s", prefix)
nexthop, intf, err := rm.addRoute(prefix)
if errors.Is(err, ErrRouteNotFound) {
return nil
}
if errors.Is(err, ErrRouteNotAllowed) {
log.Debugf("Adding route for prefix %s: %s", prefix, err)
}
if err != nil {
return fmt.Errorf("failed to add route for prefix %s: %w", prefix, err)
}
ref.nexthop = nexthop
ref.intf = intf
}
ref.count++
rm.refCountMap[prefix] = ref
rm.prefixMap[connID] = append(rm.prefixMap[connID], prefix)
return nil
}
func (rm *RouteManager) RemoveRouteRef(connID nbnet.ConnectionID) error {
rm.mutex.Lock()
defer rm.mutex.Unlock()
prefixes, ok := rm.prefixMap[connID]
if !ok {
log.Debugf("No prefixes found for connection ID %s", connID)
return nil
}
var result *multierror.Error
for _, prefix := range prefixes {
ref := rm.refCountMap[prefix]
log.Debugf("Decreasing route ref count %d for prefix %s", ref.count, prefix)
if ref.count == 1 {
log.Debugf("Removing route for prefix %s", prefix)
// TODO: don't fail if the route is not found
if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil {
result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err))
continue
}
delete(rm.refCountMap, prefix)
} else {
ref.count--
rm.refCountMap[prefix] = ref
}
}
delete(rm.prefixMap, connID)
return result.ErrorOrNil()
}
// Flush removes all references and routes from the system
func (rm *RouteManager) Flush() error {
rm.mutex.Lock()
defer rm.mutex.Unlock()
var result *multierror.Error
for prefix := range rm.refCountMap {
log.Debugf("Removing route for prefix %s", prefix)
ref := rm.refCountMap[prefix]
if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil {
result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err))
}
}
rm.refCountMap = map[netip.Prefix]ref{}
rm.prefixMap = map[nbnet.ConnectionID][]netip.Prefix{}
return result.ErrorOrNil()
}

View File

@ -5,13 +5,14 @@ package routemanager
import ( import (
"context" "context"
"fmt" "fmt"
"net/netip" "net"
"sync" "sync"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@ -70,7 +71,7 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[route.ID]*route.Route)
} }
if len(m.routes) > 0 { if len(m.routes) > 0 {
err := enableIPForwarding() err := systemops.EnableIPForwarding()
if err != nil { if err != nil {
return err return err
} }
@ -88,7 +89,7 @@ func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error
m.mux.Lock() m.mux.Lock()
defer m.mux.Unlock() defer m.mux.Unlock()
routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), route) routerPair, err := routeToRouterPair(m.wgInterface.Address().Network, route)
if err != nil { if err != nil {
return fmt.Errorf("parse prefix: %w", err) return fmt.Errorf("parse prefix: %w", err)
} }
@ -117,7 +118,7 @@ func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error {
m.mux.Lock() m.mux.Lock()
defer m.mux.Unlock() defer m.mux.Unlock()
routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), route) routerPair, err := routeToRouterPair(m.wgInterface.Address().Network, route)
if err != nil { if err != nil {
return fmt.Errorf("parse prefix: %w", err) return fmt.Errorf("parse prefix: %w", err)
} }
@ -133,7 +134,13 @@ func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error {
if state.Routes == nil { if state.Routes == nil {
state.Routes = map[string]struct{}{} state.Routes = map[string]struct{}{}
} }
state.Routes[route.Network.String()] = struct{}{}
routeStr := route.Network.String()
if route.IsDynamic() {
routeStr = route.Domains.SafeString()
}
state.Routes[routeStr] = struct{}{}
m.statusRecorder.UpdateLocalPeerState(state) m.statusRecorder.UpdateLocalPeerState(state)
return nil return nil
@ -144,7 +151,7 @@ func (m *defaultServerRouter) cleanUp() {
m.mux.Lock() m.mux.Lock()
defer m.mux.Unlock() defer m.mux.Unlock()
for _, r := range m.routes { for _, r := range m.routes {
routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), r) routerPair, err := routeToRouterPair(m.wgInterface.Address().Network, r)
if err != nil { if err != nil {
log.Errorf("Failed to convert route to router pair: %v", err) log.Errorf("Failed to convert route to router pair: %v", err)
continue continue
@ -162,15 +169,17 @@ func (m *defaultServerRouter) cleanUp() {
m.statusRecorder.UpdateLocalPeerState(state) m.statusRecorder.UpdateLocalPeerState(state)
} }
func routeToRouterPair(source string, route *route.Route) (firewall.RouterPair, error) { func routeToRouterPair(source *net.IPNet, route *route.Route) (firewall.RouterPair, error) {
parsed, err := netip.ParsePrefix(source) destination := route.Network.Masked().String()
if err != nil { if route.IsDynamic() {
return firewall.RouterPair{}, err // TODO: add ipv6
destination = "0.0.0.0/0"
} }
return firewall.RouterPair{ return firewall.RouterPair{
ID: string(route.ID), ID: string(route.ID),
Source: parsed.String(), Source: source.String(),
Destination: route.Network.Masked().String(), Destination: destination,
Masquerade: route.Masquerade, Masquerade: route.Masquerade,
}, nil }, nil
} }

View File

@ -0,0 +1,57 @@
package static
import (
"context"
"fmt"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/route"
)
type Route struct {
route *route.Route
routeRefCounter *refcounter.RouteRefCounter
allowedIPsRefcounter *refcounter.AllowedIPsRefCounter
}
func NewRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *Route {
return &Route{
route: rt,
routeRefCounter: routeRefCounter,
allowedIPsRefcounter: allowedIPsRefCounter,
}
}
// Route route methods
func (r *Route) String() string {
return r.route.Network.String()
}
func (r *Route) AddRoute(context.Context) error {
_, err := r.routeRefCounter.Increment(r.route.Network, nil)
return err
}
func (r *Route) RemoveRoute() error {
_, err := r.routeRefCounter.Decrement(r.route.Network)
return err
}
func (r *Route) AddAllowedIPs(peerKey string) error {
if ref, err := r.allowedIPsRefcounter.Increment(r.route.Network, peerKey); err != nil {
return fmt.Errorf("add allowed IP %s: %w", r.route.Network, err)
} else if ref.Count > 1 && ref.Out != peerKey {
log.Warnf("Prefix [%s] is already routed by peer [%s]. HA routing disabled",
r.route.Network,
ref.Out,
)
}
return nil
}
func (r *Route) RemoveAllowedIPs() error {
_, err := r.allowedIPsRefcounter.Decrement(r.route.Network)
return err
}

View File

@ -0,0 +1,103 @@
// go:build !android
package sysctl
import (
"fmt"
"net"
"os"
"strconv"
"strings"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/iface"
)
const (
rpFilterPath = "net.ipv4.conf.all.rp_filter"
rpFilterInterfacePath = "net.ipv4.conf.%s.rp_filter"
srcValidMarkPath = "net.ipv4.conf.all.src_valid_mark"
)
// Setup configures sysctl settings for RP filtering and source validation.
func Setup(wgIface *iface.WGIface) (map[string]int, error) {
keys := map[string]int{}
var result *multierror.Error
oldVal, err := Set(srcValidMarkPath, 1, false)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[srcValidMarkPath] = oldVal
}
oldVal, err = Set(rpFilterPath, 2, true)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[rpFilterPath] = oldVal
}
interfaces, err := net.Interfaces()
if err != nil {
result = multierror.Append(result, fmt.Errorf("list interfaces: %w", err))
}
for _, intf := range interfaces {
if intf.Name == "lo" || wgIface != nil && intf.Name == wgIface.Name() {
continue
}
i := fmt.Sprintf(rpFilterInterfacePath, intf.Name)
oldVal, err := Set(i, 2, true)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[i] = oldVal
}
}
return keys, nberrors.FormatErrorOrNil(result)
}
// Set sets a sysctl configuration, if onlyIfOne is true it will only set the new value if it's set to 1
func Set(key string, desiredValue int, onlyIfOne bool) (int, error) {
path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/"))
currentValue, err := os.ReadFile(path)
if err != nil {
return -1, fmt.Errorf("read sysctl %s: %w", key, err)
}
currentV, err := strconv.Atoi(strings.TrimSpace(string(currentValue)))
if err != nil && len(currentValue) > 0 {
return -1, fmt.Errorf("convert current desiredValue to int: %w", err)
}
if currentV == desiredValue || onlyIfOne && currentV != 1 {
return currentV, nil
}
//nolint:gosec
if err := os.WriteFile(path, []byte(strconv.Itoa(desiredValue)), 0644); err != nil {
return currentV, fmt.Errorf("write sysctl %s: %w", key, err)
}
log.Debugf("Set sysctl %s from %d to %d", key, currentV, desiredValue)
return currentV, nil
}
// Cleanup resets sysctl settings to their original values.
func Cleanup(originalSettings map[string]int) error {
var result *multierror.Error
for key, value := range originalSettings {
_, err := Set(key, value, false)
if err != nil {
result = multierror.Append(result, err)
}
}
return nberrors.FormatErrorOrNil(result)
}

View File

@ -0,0 +1,18 @@
//go:build darwin || dragonfly || netbsd || openbsd
package systemops
import "syscall"
// filterRoutesByFlags - return true if need to ignore such route message because it consists specific flags.
func filterRoutesByFlags(routeMessageFlags int) bool {
if routeMessageFlags&syscall.RTF_UP == 0 {
return true
}
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE|syscall.RTF_WASCLONED) != 0 {
return true
}
return false
}

View File

@ -0,0 +1,19 @@
//go:build: freebsd
package systemops
import "syscall"
// filterRoutesByFlags - return true if need to ignore such route message because it consists specific flags.
func filterRoutesByFlags(routeMessageFlags int) bool {
if routeMessageFlags&syscall.RTF_UP == 0 {
return true
}
// NOTE: syscall.RTF_WASCLONED deprecated in FreeBSD 8.0 (https://www.freebsd.org/releases/8.0R/relnotes-detailed/)
// a concept of cloned route (a route generated by an entry with RTF_CLONING flag) is deprecated.
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE) != 0 {
return true
}
return false
}

View File

@ -0,0 +1,27 @@
package systemops
import (
"net"
"net/netip"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/iface"
)
type Nexthop struct {
IP netip.Addr
Intf *net.Interface
}
type ExclusionCounter = refcounter.Counter[any, Nexthop]
type SysOps struct {
refCounter *ExclusionCounter
wgInterface *iface.WGIface
}
func NewSysOps(wgInterface *iface.WGIface) *SysOps {
return &SysOps{
wgInterface: wgInterface,
}
}

View File

@ -1,6 +1,6 @@
//go:build darwin || dragonfly || freebsd || netbsd || openbsd //go:build darwin || dragonfly || freebsd || netbsd || openbsd
package routemanager package systemops
import ( import (
"errors" "errors"
@ -43,8 +43,7 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
return nil, fmt.Errorf("unexpected RIB message type: %d", m.Type) return nil, fmt.Errorf("unexpected RIB message type: %d", m.Type)
} }
if m.Flags&syscall.RTF_UP == 0 || if filterRoutesByFlags(m.Flags) {
m.Flags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE|syscall.RTF_WASCLONED) != 0 {
continue continue
} }
@ -101,6 +100,7 @@ func toNetIP(a route.Addr) netip.Addr {
} }
} }
// ones returns the number of leading ones in the mask.
func ones(a route.Addr) (int, error) { func ones(a route.Addr) (int, error) {
switch t := a.(type) { switch t := a.(type) {
case *route.Inet4Addr: case *route.Inet4Addr:
@ -114,6 +114,7 @@ func ones(a route.Addr) (int, error) {
} }
} }
// MsgToRoute converts a route message to a Route.
func MsgToRoute(msg *route.RouteMessage) (*Route, error) { func MsgToRoute(msg *route.RouteMessage) (*Route, error) {
dstIP, nexthop, dstMask := msg.Addrs[0], msg.Addrs[1], msg.Addrs[2] dstIP, nexthop, dstMask := msg.Addrs[0], msg.Addrs[1], msg.Addrs[2]

View File

@ -1,6 +1,6 @@
//go:build darwin || dragonfly || freebsd || netbsd || openbsd //go:build darwin || dragonfly || freebsd || netbsd || openbsd
package routemanager package systemops
import ( import (
"testing" "testing"

View File

@ -1,6 +1,6 @@
//go:build !ios //go:build !ios
package routemanager package systemops
import ( import (
"fmt" "fmt"
@ -35,13 +35,15 @@ func TestConcurrentRoutes(t *testing.T) {
baseIP := netip.MustParseAddr("192.0.2.0") baseIP := netip.MustParseAddr("192.0.2.0")
intf := &net.Interface{Name: "lo0"} intf := &net.Interface{Name: "lo0"}
r := NewSysOps(nil)
var wg sync.WaitGroup var wg sync.WaitGroup
for i := 0; i < 1024; i++ { for i := 0; i < 1024; i++ {
wg.Add(1) wg.Add(1)
go func(ip netip.Addr) { go func(ip netip.Addr) {
defer wg.Done() defer wg.Done()
prefix := netip.PrefixFrom(ip, 32) prefix := netip.PrefixFrom(ip, 32)
if err := addToRouteTable(prefix, netip.Addr{}, intf); err != nil { if err := r.addToRouteTable(prefix, Nexthop{netip.Addr{}, intf}); err != nil {
t.Errorf("Failed to add route for %s: %v", prefix, err) t.Errorf("Failed to add route for %s: %v", prefix, err)
} }
}(baseIP) }(baseIP)
@ -57,7 +59,7 @@ func TestConcurrentRoutes(t *testing.T) {
go func(ip netip.Addr) { go func(ip netip.Addr) {
defer wg.Done() defer wg.Done()
prefix := netip.PrefixFrom(ip, 32) prefix := netip.PrefixFrom(ip, 32)
if err := removeFromRouteTable(prefix, netip.Addr{}, intf); err != nil { if err := r.removeFromRouteTable(prefix, Nexthop{netip.Addr{}, intf}); err != nil {
t.Errorf("Failed to remove route for %s: %v", prefix, err) t.Errorf("Failed to remove route for %s: %v", prefix, err)
} }
}(baseIP) }(baseIP)

View File

@ -1,6 +1,6 @@
//go:build !android && !ios //go:build !android && !ios
package routemanager package systemops
import ( import (
"context" "context"
@ -15,7 +15,11 @@ import (
"github.com/libp2p/go-netroute" "github.com/libp2p/go-netroute"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
@ -25,29 +29,75 @@ var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1) var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1) var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
var ErrRouteNotFound = errors.New("route not found") func (r *SysOps) setupRefCounter(initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
var ErrRouteNotAllowed = errors.New("route not allowed") initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
log.Errorf("Unable to get initial v4 default next hop: %v", err)
}
initialNextHopV6, err := GetNextHop(netip.IPv6Unspecified())
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
log.Errorf("Unable to get initial v6 default next hop: %v", err)
}
refCounter := refcounter.New(
func(prefix netip.Prefix, _ any) (Nexthop, error) {
initialNexthop := initialNextHopV4
if prefix.Addr().Is6() {
initialNexthop = initialNextHopV6
}
nexthop, err := r.addRouteToNonVPNIntf(prefix, r.wgInterface, initialNexthop)
if errors.Is(err, vars.ErrRouteNotAllowed) || errors.Is(err, vars.ErrRouteNotFound) {
log.Tracef("Adding for prefix %s: %v", prefix, err)
// These errors are not critical but also we should not track and try to remove the routes either.
return nexthop, refcounter.ErrIgnore
}
return nexthop, err
},
r.removeFromRouteTable,
)
r.refCounter = refCounter
return r.setupHooks(initAddresses)
}
func (r *SysOps) cleanupRefCounter() error {
if r.refCounter == nil {
return nil
}
// TODO: Remove hooks selectively
nbnet.RemoveDialerHooks()
nbnet.RemoveListenerHooks()
if err := r.refCounter.Flush(); err != nil {
return fmt.Errorf("flush route manager: %w", err)
}
return nil
}
// TODO: fix: for default our wg address now appears as the default gw // TODO: fix: for default our wg address now appears as the default gw
func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error { func (r *SysOps) addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
addr := netip.IPv4Unspecified() addr := netip.IPv4Unspecified()
if prefix.Addr().Is6() { if prefix.Addr().Is6() {
addr = netip.IPv6Unspecified() addr = netip.IPv6Unspecified()
} }
defaultGateway, _, err := GetNextHop(addr) nexthop, err := GetNextHop(addr)
if err != nil && !errors.Is(err, ErrRouteNotFound) { if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
return fmt.Errorf("get existing route gateway: %s", err) return fmt.Errorf("get existing route gateway: %s", err)
} }
if !prefix.Contains(defaultGateway) { if !prefix.Contains(nexthop.IP) {
log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", defaultGateway, prefix) log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", nexthop.IP, prefix)
return nil return nil
} }
gatewayPrefix := netip.PrefixFrom(defaultGateway, 32) gatewayPrefix := netip.PrefixFrom(nexthop.IP, 32)
if defaultGateway.Is6() { if nexthop.IP.Is6() {
gatewayPrefix = netip.PrefixFrom(defaultGateway, 128) gatewayPrefix = netip.PrefixFrom(nexthop.IP, 128)
} }
ok, err := existsInRouteTable(gatewayPrefix) ok, err := existsInRouteTable(gatewayPrefix)
@ -60,46 +110,264 @@ func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
return nil return nil
} }
gatewayHop, intf, err := GetNextHop(defaultGateway) nexthop, err = GetNextHop(nexthop.IP)
if err != nil && !errors.Is(err, ErrRouteNotFound) { if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err) return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
} }
log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop) log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, nexthop.IP)
return addToRouteTable(gatewayPrefix, gatewayHop, intf) return r.addToRouteTable(gatewayPrefix, nexthop)
} }
func GetNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) { // addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
// If the next hop or interface is pointing to the VPN interface, it will return the initial values.
func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIface, initialNextHop Nexthop) (Nexthop, error) {
addr := prefix.Addr()
switch {
case addr.IsLoopback(),
addr.IsLinkLocalUnicast(),
addr.IsLinkLocalMulticast(),
addr.IsInterfaceLocalMulticast(),
addr.IsUnspecified(),
addr.IsMulticast():
return Nexthop{}, vars.ErrRouteNotAllowed
}
// Determine the exit interface and next hop for the prefix, so we can add a specific route
nexthop, err := GetNextHop(addr)
if err != nil {
return Nexthop{}, fmt.Errorf("get next hop: %w", err)
}
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop.IP, prefix, nexthop.IP)
exitNextHop := Nexthop{
IP: nexthop.IP,
Intf: nexthop.Intf,
}
vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP)
if !ok {
return Nexthop{}, fmt.Errorf("failed to convert vpn address to netip.Addr")
}
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() {
log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix)
exitNextHop = initialNextHop
}
log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop.IP)
if err := r.addToRouteTable(prefix, exitNextHop); err != nil {
return Nexthop{}, fmt.Errorf("add route to table: %w", err)
}
return exitNextHop, nil
}
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
// in two /1 prefixes to avoid replacing the existing default route
func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
nextHop := Nexthop{netip.Addr{}, intf}
if prefix == vars.Defaultv4 {
if err := r.addToRouteTable(splitDefaultv4_1, nextHop); err != nil {
return err
}
if err := r.addToRouteTable(splitDefaultv4_2, nextHop); err != nil {
if err2 := r.removeFromRouteTable(splitDefaultv4_1, nextHop); err2 != nil {
log.Warnf("Failed to rollback route addition: %s", err2)
}
return err
}
// TODO: remove once IPv6 is supported on the interface
if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil {
return fmt.Errorf("add unreachable route split 1: %w", err)
}
if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil {
if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil {
log.Warnf("Failed to rollback route addition: %s", err2)
}
return fmt.Errorf("add unreachable route split 2: %w", err)
}
return nil
} else if prefix == vars.Defaultv6 {
if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil {
return fmt.Errorf("add unreachable route split 1: %w", err)
}
if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil {
if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil {
log.Warnf("Failed to rollback route addition: %s", err2)
}
return fmt.Errorf("add unreachable route split 2: %w", err)
}
return nil
}
return r.addNonExistingRoute(prefix, intf)
}
// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table
func (r *SysOps) addNonExistingRoute(prefix netip.Prefix, intf *net.Interface) error {
ok, err := existsInRouteTable(prefix)
if err != nil {
return fmt.Errorf("exists in route table: %w", err)
}
if ok {
log.Warnf("Skipping adding a new route for network %s because it already exists", prefix)
return nil
}
ok, err = isSubRange(prefix)
if err != nil {
return fmt.Errorf("sub range: %w", err)
}
if ok {
if err := r.addRouteForCurrentDefaultGateway(prefix); err != nil {
log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err)
}
}
return r.addToRouteTable(prefix, Nexthop{netip.Addr{}, intf})
}
// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given,
// it will remove the split /1 prefixes
func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
nextHop := Nexthop{netip.Addr{}, intf}
if prefix == vars.Defaultv4 {
var result *multierror.Error
if err := r.removeFromRouteTable(splitDefaultv4_1, nextHop); err != nil {
result = multierror.Append(result, err)
}
if err := r.removeFromRouteTable(splitDefaultv4_2, nextHop); err != nil {
result = multierror.Append(result, err)
}
// TODO: remove once IPv6 is supported on the interface
if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil {
result = multierror.Append(result, err)
}
if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil {
result = multierror.Append(result, err)
}
return nberrors.FormatErrorOrNil(result)
} else if prefix == vars.Defaultv6 {
var result *multierror.Error
if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil {
result = multierror.Append(result, err)
}
if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil {
result = multierror.Append(result, err)
}
return nberrors.FormatErrorOrNil(result)
}
return r.removeFromRouteTable(prefix, nextHop)
}
func (r *SysOps) setupHooks(initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
prefix, err := util.GetPrefixFromIP(ip)
if err != nil {
return fmt.Errorf("convert ip to prefix: %w", err)
}
if _, err := r.refCounter.IncrementWithID(string(connID), prefix, nil); err != nil {
return fmt.Errorf("adding route reference: %v", err)
}
return nil
}
afterHook := func(connID nbnet.ConnectionID) error {
if err := r.refCounter.DecrementWithID(string(connID)); err != nil {
return fmt.Errorf("remove route reference: %w", err)
}
return nil
}
for _, ip := range initAddresses {
if err := beforeHook("init", ip); err != nil {
log.Errorf("Failed to add route reference: %v", err)
}
}
nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error {
if ctx.Err() != nil {
return ctx.Err()
}
var result *multierror.Error
for _, ip := range resolvedIPs {
result = multierror.Append(result, beforeHook(connID, ip.IP))
}
return nberrors.FormatErrorOrNil(result)
})
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
return afterHook(connID)
})
nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error {
return beforeHook(connID, ip.IP)
})
nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error {
return afterHook(connID)
})
return beforeHook, afterHook, nil
}
func GetNextHop(ip netip.Addr) (Nexthop, error) {
r, err := netroute.New() r, err := netroute.New()
if err != nil { if err != nil {
return netip.Addr{}, nil, fmt.Errorf("new netroute: %w", err) return Nexthop{}, fmt.Errorf("new netroute: %w", err)
} }
intf, gateway, preferredSrc, err := r.Route(ip.AsSlice()) intf, gateway, preferredSrc, err := r.Route(ip.AsSlice())
if err != nil { if err != nil {
log.Debugf("Failed to get route for %s: %v", ip, err) log.Debugf("Failed to get route for %s: %v", ip, err)
return netip.Addr{}, nil, ErrRouteNotFound return Nexthop{}, vars.ErrRouteNotFound
} }
log.Debugf("Route for %s: interface %v nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc) log.Debugf("Route for %s: interface %v nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc)
if gateway == nil { if gateway == nil {
if preferredSrc == nil { if runtime.GOOS == "freebsd" {
return netip.Addr{}, nil, ErrRouteNotFound return Nexthop{Intf: intf}, nil
} }
log.Debugf("No next hop found for ip %s, using preferred source %s", ip, preferredSrc)
if preferredSrc == nil {
return Nexthop{}, vars.ErrRouteNotFound
}
log.Debugf("No next hop found for IP %s, using preferred source %s", ip, preferredSrc)
addr, err := ipToAddr(preferredSrc, intf) addr, err := ipToAddr(preferredSrc, intf)
if err != nil { if err != nil {
return netip.Addr{}, nil, fmt.Errorf("convert preferred source to address: %w", err) return Nexthop{}, fmt.Errorf("convert preferred source to address: %w", err)
} }
return addr.Unmap(), intf, nil return Nexthop{
IP: addr.Unmap(),
Intf: intf,
}, nil
} }
addr, err := ipToAddr(gateway, intf) addr, err := ipToAddr(gateway, intf)
if err != nil { if err != nil {
return netip.Addr{}, nil, fmt.Errorf("convert gateway to address: %w", err) return Nexthop{}, fmt.Errorf("convert gateway to address: %w", err)
} }
return addr, intf, nil return Nexthop{
IP: addr,
Intf: intf,
}, nil
} }
// converts a net.IP to a netip.Addr including the zone based on the passed interface // converts a net.IP to a netip.Addr including the zone based on the passed interface
@ -140,275 +408,9 @@ func isSubRange(prefix netip.Prefix) (bool, error) {
return false, fmt.Errorf("get routes from table: %w", err) return false, fmt.Errorf("get routes from table: %w", err)
} }
for _, tableRoute := range routes { for _, tableRoute := range routes {
if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() { if tableRoute.Bits() > vars.MinRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() {
return true, nil return true, nil
} }
} }
return false, nil return false, nil
} }
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
// If the next hop or interface is pointing to the VPN interface, it will return the initial values.
func addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIface, initialNextHop netip.Addr, initialIntf *net.Interface) (netip.Addr, *net.Interface, error) {
addr := prefix.Addr()
switch {
case addr.IsLoopback(),
addr.IsLinkLocalUnicast(),
addr.IsLinkLocalMulticast(),
addr.IsInterfaceLocalMulticast(),
addr.IsUnspecified(),
addr.IsMulticast():
return netip.Addr{}, nil, ErrRouteNotAllowed
}
// Determine the exit interface and next hop for the prefix, so we can add a specific route
nexthop, intf, err := GetNextHop(addr)
if err != nil {
return netip.Addr{}, nil, fmt.Errorf("get next hop: %w", err)
}
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop, prefix, intf)
exitNextHop := nexthop
exitIntf := intf
vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP)
if !ok {
return netip.Addr{}, nil, fmt.Errorf("failed to convert vpn address to netip.Addr")
}
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
if exitNextHop == vpnAddr || exitIntf != nil && exitIntf.Name == vpnIntf.Name() {
log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix)
exitNextHop = initialNextHop
exitIntf = initialIntf
}
log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop)
if err := addToRouteTable(prefix, exitNextHop, exitIntf); err != nil {
return netip.Addr{}, nil, fmt.Errorf("add route to table: %w", err)
}
return exitNextHop, exitIntf, nil
}
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
// in two /1 prefixes to avoid replacing the existing default route
func genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if prefix == defaultv4 {
if err := addToRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
return err
}
if err := addToRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil {
if err2 := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err2 != nil {
log.Warnf("Failed to rollback route addition: %s", err2)
}
return err
}
// TODO: remove once IPv6 is supported on the interface
if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
return fmt.Errorf("add unreachable route split 1: %w", err)
}
if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil {
log.Warnf("Failed to rollback route addition: %s", err2)
}
return fmt.Errorf("add unreachable route split 2: %w", err)
}
return nil
} else if prefix == defaultv6 {
if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
return fmt.Errorf("add unreachable route split 1: %w", err)
}
if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil {
log.Warnf("Failed to rollback route addition: %s", err2)
}
return fmt.Errorf("add unreachable route split 2: %w", err)
}
return nil
}
return addNonExistingRoute(prefix, intf)
}
// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table
func addNonExistingRoute(prefix netip.Prefix, intf *net.Interface) error {
ok, err := existsInRouteTable(prefix)
if err != nil {
return fmt.Errorf("exists in route table: %w", err)
}
if ok {
log.Warnf("Skipping adding a new route for network %s because it already exists", prefix)
return nil
}
ok, err = isSubRange(prefix)
if err != nil {
return fmt.Errorf("sub range: %w", err)
}
if ok {
err := addRouteForCurrentDefaultGateway(prefix)
if err != nil {
log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err)
}
}
return addToRouteTable(prefix, netip.Addr{}, intf)
}
// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given,
// it will remove the split /1 prefixes
func genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if prefix == defaultv4 {
var result *multierror.Error
if err := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
if err := removeFromRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
// TODO: remove once IPv6 is supported on the interface
if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
return result.ErrorOrNil()
} else if prefix == defaultv6 {
var result *multierror.Error
if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
result = multierror.Append(result, err)
}
return result.ErrorOrNil()
}
return removeFromRouteTable(prefix, netip.Addr{}, intf)
}
func getPrefixFromIP(ip net.IP) (*netip.Prefix, error) {
addr, ok := netip.AddrFromSlice(ip)
if !ok {
return nil, fmt.Errorf("parse IP address: %s", ip)
}
addr = addr.Unmap()
var prefixLength int
switch {
case addr.Is4():
prefixLength = 32
case addr.Is6():
prefixLength = 128
default:
return nil, fmt.Errorf("invalid IP address: %s", addr)
}
prefix := netip.PrefixFrom(addr, prefixLength)
return &prefix, nil
}
func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
initialNextHopV4, initialIntfV4, err := GetNextHop(netip.IPv4Unspecified())
if err != nil && !errors.Is(err, ErrRouteNotFound) {
log.Errorf("Unable to get initial v4 default next hop: %v", err)
}
initialNextHopV6, initialIntfV6, err := GetNextHop(netip.IPv6Unspecified())
if err != nil && !errors.Is(err, ErrRouteNotFound) {
log.Errorf("Unable to get initial v6 default next hop: %v", err)
}
*routeManager = NewRouteManager(
func(prefix netip.Prefix) (netip.Addr, *net.Interface, error) {
addr := prefix.Addr()
nexthop, intf := initialNextHopV4, initialIntfV4
if addr.Is6() {
nexthop, intf = initialNextHopV6, initialIntfV6
}
return addRouteToNonVPNIntf(prefix, wgIface, nexthop, intf)
},
removeFromRouteTable,
)
return setupHooks(*routeManager, initAddresses)
}
func cleanupRoutingWithRouteManager(routeManager *RouteManager) error {
if routeManager == nil {
return nil
}
// TODO: Remove hooks selectively
nbnet.RemoveDialerHooks()
nbnet.RemoveListenerHooks()
if err := routeManager.Flush(); err != nil {
return fmt.Errorf("flush route manager: %w", err)
}
return nil
}
func setupHooks(routeManager *RouteManager, initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
prefix, err := getPrefixFromIP(ip)
if err != nil {
return fmt.Errorf("convert ip to prefix: %w", err)
}
if err := routeManager.AddRouteRef(connID, *prefix); err != nil {
return fmt.Errorf("adding route reference: %v", err)
}
return nil
}
afterHook := func(connID nbnet.ConnectionID) error {
if err := routeManager.RemoveRouteRef(connID); err != nil {
return fmt.Errorf("remove route reference: %w", err)
}
return nil
}
for _, ip := range initAddresses {
if err := beforeHook("init", ip); err != nil {
log.Errorf("Failed to add route reference: %v", err)
}
}
nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error {
if ctx.Err() != nil {
return ctx.Err()
}
var result *multierror.Error
for _, ip := range resolvedIPs {
result = multierror.Append(result, beforeHook(connID, ip.IP))
}
return result.ErrorOrNil()
})
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
return afterHook(connID)
})
nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error {
return beforeHook(connID, ip.IP)
})
nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error {
return afterHook(connID)
})
return beforeHook, afterHook, nil
}

View File

@ -1,6 +1,6 @@
//go:build !android && !ios //go:build !android && !ios
package routemanager package systemops
import ( import (
"bytes" "bytes"
@ -63,17 +63,20 @@ func TestAddRemoveRoutes(t *testing.T) {
err = wgInterface.Create() err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface") require.NoError(t, err, "should create testing wireguard interface")
_, _, err = setupRouting(nil, wgInterface)
r := NewSysOps(wgInterface)
_, _, err = r.SetupRouting(nil)
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { t.Cleanup(func() {
assert.NoError(t, cleanupRouting()) assert.NoError(t, r.CleanupRouting())
}) })
index, err := net.InterfaceByName(wgInterface.Name()) index, err := net.InterfaceByName(wgInterface.Name())
require.NoError(t, err, "InterfaceByName should not return err") require.NoError(t, err, "InterfaceByName should not return err")
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()} intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
err = addVPNRoute(testCase.prefix, intf) err = r.AddVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "genericAddVPNRoute should not return err") require.NoError(t, err, "genericAddVPNRoute should not return err")
if testCase.shouldRouteToWireguard { if testCase.shouldRouteToWireguard {
@ -84,19 +87,19 @@ func TestAddRemoveRoutes(t *testing.T) {
exists, err := existsInRouteTable(testCase.prefix) exists, err := existsInRouteTable(testCase.prefix)
require.NoError(t, err, "existsInRouteTable should not return err") require.NoError(t, err, "existsInRouteTable should not return err")
if exists && testCase.shouldRouteToWireguard { if exists && testCase.shouldRouteToWireguard {
err = removeVPNRoute(testCase.prefix, intf) err = r.RemoveVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "genericRemoveVPNRoute should not return err") require.NoError(t, err, "genericRemoveVPNRoute should not return err")
prefixGateway, _, err := GetNextHop(testCase.prefix.Addr()) prefixNexthop, err := GetNextHop(testCase.prefix.Addr())
require.NoError(t, err, "GetNextHop should not return err") require.NoError(t, err, "GetNextHop should not return err")
internetGateway, _, err := GetNextHop(netip.MustParseAddr("0.0.0.0")) internetNexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
require.NoError(t, err) require.NoError(t, err)
if testCase.shouldBeRemoved { if testCase.shouldBeRemoved {
require.Equal(t, internetGateway, prefixGateway, "route should be pointing to default internet gateway") require.Equal(t, internetNexthop.IP, prefixNexthop.IP, "route should be pointing to default internet gateway")
} else { } else {
require.NotEqual(t, internetGateway, prefixGateway, "route should be pointing to a different gateway than the internet gateway") require.NotEqual(t, internetNexthop.IP, prefixNexthop.IP, "route should be pointing to a different gateway than the internet gateway")
} }
} }
}) })
@ -104,11 +107,11 @@ func TestAddRemoveRoutes(t *testing.T) {
} }
func TestGetNextHop(t *testing.T) { func TestGetNextHop(t *testing.T) {
gateway, _, err := GetNextHop(netip.MustParseAddr("0.0.0.0")) nexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
if err != nil { if err != nil {
t.Fatal("shouldn't return error when fetching the gateway: ", err) t.Fatal("shouldn't return error when fetching the gateway: ", err)
} }
if !gateway.IsValid() { if !nexthop.IP.IsValid() {
t.Fatal("should return a gateway") t.Fatal("should return a gateway")
} }
addresses, err := net.InterfaceAddrs() addresses, err := net.InterfaceAddrs()
@ -130,24 +133,24 @@ func TestGetNextHop(t *testing.T) {
} }
} }
localIP, _, err := GetNextHop(testingPrefix.Addr()) localIP, err := GetNextHop(testingPrefix.Addr())
if err != nil { if err != nil {
t.Fatal("shouldn't return error: ", err) t.Fatal("shouldn't return error: ", err)
} }
if !localIP.IsValid() { if !localIP.IP.IsValid() {
t.Fatal("should return a gateway for local network") t.Fatal("should return a gateway for local network")
} }
if localIP.String() == gateway.String() { if localIP.IP.String() == nexthop.IP.String() {
t.Fatal("local ip should not match with gateway IP") t.Fatal("local IP should not match with gateway IP")
} }
if localIP.String() != testingIP { if localIP.IP.String() != testingIP {
t.Fatalf("local ip should match with testing IP: want %s got %s", testingIP, localIP.String()) t.Fatalf("local IP should match with testing IP: want %s got %s", testingIP, localIP.IP.String())
} }
} }
func TestAddExistAndRemoveRoute(t *testing.T) { func TestAddExistAndRemoveRoute(t *testing.T) {
defaultGateway, _, err := GetNextHop(netip.MustParseAddr("0.0.0.0")) defaultNexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
t.Log("defaultGateway: ", defaultGateway) t.Log("defaultNexthop: ", defaultNexthop)
if err != nil { if err != nil {
t.Fatal("shouldn't return error when fetching the gateway: ", err) t.Fatal("shouldn't return error when fetching the gateway: ", err)
} }
@ -164,7 +167,7 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
}, },
{ {
name: "Should Not Add Route if overlaps with default gateway", name: "Should Not Add Route if overlaps with default gateway",
prefix: netip.MustParsePrefix(defaultGateway.String() + "/31"), prefix: netip.MustParsePrefix(defaultNexthop.IP.String() + "/31"),
shouldAddRoute: false, shouldAddRoute: false,
}, },
{ {
@ -214,14 +217,16 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
require.NoError(t, err, "InterfaceByName should not return err") require.NoError(t, err, "InterfaceByName should not return err")
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()} intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
r := NewSysOps(wgInterface)
// Prepare the environment // Prepare the environment
if testCase.preExistingPrefix.IsValid() { if testCase.preExistingPrefix.IsValid() {
err := addVPNRoute(testCase.preExistingPrefix, intf) err := r.AddVPNRoute(testCase.preExistingPrefix, intf)
require.NoError(t, err, "should not return err when adding pre-existing route") require.NoError(t, err, "should not return err when adding pre-existing route")
} }
// Add the route // Add the route
err = addVPNRoute(testCase.prefix, intf) err = r.AddVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "should not return err when adding route") require.NoError(t, err, "should not return err when adding route")
if testCase.shouldAddRoute { if testCase.shouldAddRoute {
@ -231,7 +236,7 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
require.True(t, ok, "route should exist") require.True(t, ok, "route should exist")
// remove route again if added // remove route again if added
err = removeVPNRoute(testCase.prefix, intf) err = r.RemoveVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "should not return err") require.NoError(t, err, "should not return err")
} }
@ -343,65 +348,52 @@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen
return wgInterface return wgInterface
} }
func setupRouteAndCleanup(t *testing.T, r *SysOps, prefix netip.Prefix, intf *net.Interface) {
t.Helper()
err := r.AddVPNRoute(prefix, intf)
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = r.RemoveVPNRoute(prefix, intf)
assert.NoError(t, err, "removeVPNRoute should not return err")
})
}
func setupTestEnv(t *testing.T) { func setupTestEnv(t *testing.T) {
t.Helper() t.Helper()
setupDummyInterfacesAndRoutes(t) setupDummyInterfacesAndRoutes(t)
wgIface := createWGInterface(t, expectedVPNint, "100.64.0.1/24", 51820) wgInterface := createWGInterface(t, expectedVPNint, "100.64.0.1/24", 51820)
t.Cleanup(func() { t.Cleanup(func() {
assert.NoError(t, wgIface.Close()) assert.NoError(t, wgInterface.Close())
}) })
_, _, err := setupRouting(nil, wgIface) r := NewSysOps(wgInterface)
_, _, err := r.SetupRouting(nil)
require.NoError(t, err, "setupRouting should not return err") require.NoError(t, err, "setupRouting should not return err")
t.Cleanup(func() { t.Cleanup(func() {
assert.NoError(t, cleanupRouting()) assert.NoError(t, r.CleanupRouting())
}) })
index, err := net.InterfaceByName(wgIface.Name()) index, err := net.InterfaceByName(wgInterface.Name())
require.NoError(t, err, "InterfaceByName should not return err") require.NoError(t, err, "InterfaceByName should not return err")
intf := &net.Interface{Index: index.Index, Name: wgIface.Name()} intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
// default route exists in main table and vpn table // default route exists in main table and vpn table
err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), intf) setupRouteAndCleanup(t, r, netip.MustParsePrefix("0.0.0.0/0"), intf)
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), intf)
assert.NoError(t, err, "removeVPNRoute should not return err")
})
// 10.0.0.0/8 route exists in main table and vpn table // 10.0.0.0/8 route exists in main table and vpn table
err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), intf) setupRouteAndCleanup(t, r, netip.MustParsePrefix("10.0.0.0/8"), intf)
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), intf)
assert.NoError(t, err, "removeVPNRoute should not return err")
})
// 10.10.0.0/24 more specific route exists in vpn table // 10.10.0.0/24 more specific route exists in vpn table
err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), intf) setupRouteAndCleanup(t, r, netip.MustParsePrefix("10.10.0.0/24"), intf)
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), intf)
assert.NoError(t, err, "removeVPNRoute should not return err")
})
// 127.0.10.0/24 more specific route exists in vpn table // 127.0.10.0/24 more specific route exists in vpn table
err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), intf) setupRouteAndCleanup(t, r, netip.MustParsePrefix("127.0.10.0/24"), intf)
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), intf)
assert.NoError(t, err, "removeVPNRoute should not return err")
})
// unique route in vpn table // unique route in vpn table
err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), intf) setupRouteAndCleanup(t, r, netip.MustParsePrefix("172.16.0.0/12"), intf)
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), intf)
assert.NoError(t, err, "removeVPNRoute should not return err")
})
} }
func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) { func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) {
@ -410,11 +402,11 @@ func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIf
return return
} }
prefixGateway, _, err := GetNextHop(prefix.Addr()) prefixNexthop, err := GetNextHop(prefix.Addr())
require.NoError(t, err, "GetNextHop should not return err") require.NoError(t, err, "GetNextHop should not return err")
if invert { if invert {
assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP") assert.NotEqual(t, wgIface.Address().IP.String(), prefixNexthop.IP.String(), "route should not point to wireguard interface IP")
} else { } else {
assert.Equal(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP") assert.Equal(t, wgIface.Address().IP.String(), prefixNexthop.IP.String(), "route should point to wireguard interface IP")
} }
} }

View File

@ -1,6 +1,6 @@
//go:build !android //go:build !android
package routemanager package systemops
import ( import (
"bufio" "bufio"
@ -9,16 +9,16 @@ import (
"net" "net"
"net/netip" "net/netip"
"os" "os"
"strconv"
"strings"
"syscall" "syscall"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink" "github.com/vishvananda/netlink"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/client/internal/routemanager/sysctl"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
@ -33,16 +33,10 @@ const (
// ipv4ForwardingPath is the path to the file containing the IP forwarding setting. // ipv4ForwardingPath is the path to the file containing the IP forwarding setting.
ipv4ForwardingPath = "net.ipv4.ip_forward" ipv4ForwardingPath = "net.ipv4.ip_forward"
rpFilterPath = "net.ipv4.conf.all.rp_filter"
rpFilterInterfacePath = "net.ipv4.conf.%s.rp_filter"
srcValidMarkPath = "net.ipv4.conf.all.src_valid_mark"
) )
var ErrTableIDExists = errors.New("ID exists with different name") var ErrTableIDExists = errors.New("ID exists with different name")
var routeManager = &RouteManager{}
// originalSysctl stores the original sysctl values before they are modified // originalSysctl stores the original sysctl values before they are modified
var originalSysctl map[string]int var originalSysctl map[string]int
@ -82,7 +76,7 @@ func getSetupRules() []ruleParams {
} }
} }
// setupRouting establishes the routing configuration for the VPN, including essential rules // SetupRouting establishes the routing configuration for the VPN, including essential rules
// to ensure proper traffic flow for management, locally configured routes, and VPN traffic. // to ensure proper traffic flow for management, locally configured routes, and VPN traffic.
// //
// Rule 1 (Main Route Precedence): Safeguards locally installed routes by giving them precedence over // Rule 1 (Main Route Precedence): Safeguards locally installed routes by giving them precedence over
@ -92,17 +86,17 @@ func getSetupRules() []ruleParams {
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table. // Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
// This table is where a default route or other specific routes received from the management server are configured, // This table is where a default route or other specific routes received from the management server are configured,
// enabling VPN connectivity. // enabling VPN connectivity.
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) { func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) {
if isLegacy() { if isLegacy() {
log.Infof("Using legacy routing setup") log.Infof("Using legacy routing setup")
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) return r.setupRefCounter(initAddresses)
} }
if err = addRoutingTableName(); err != nil { if err = addRoutingTableName(); err != nil {
log.Errorf("Error adding routing table name: %v", err) log.Errorf("Error adding routing table name: %v", err)
} }
originalValues, err := setupSysctl(wgIface) originalValues, err := sysctl.Setup(r.wgInterface)
if err != nil { if err != nil {
log.Errorf("Error setting up sysctl: %v", err) log.Errorf("Error setting up sysctl: %v", err)
sysctlFailed = true sysctlFailed = true
@ -111,7 +105,7 @@ func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.Before
defer func() { defer func() {
if err != nil { if err != nil {
if cleanErr := cleanupRouting(); cleanErr != nil { if cleanErr := r.CleanupRouting(); cleanErr != nil {
log.Errorf("Error cleaning up routing: %v", cleanErr) log.Errorf("Error cleaning up routing: %v", cleanErr)
} }
} }
@ -123,7 +117,7 @@ func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.Before
if errors.Is(err, syscall.EOPNOTSUPP) { if errors.Is(err, syscall.EOPNOTSUPP) {
log.Warnf("Rule operations are not supported, falling back to the legacy routing setup") log.Warnf("Rule operations are not supported, falling back to the legacy routing setup")
setIsLegacy(true) setIsLegacy(true)
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) return r.setupRefCounter(initAddresses)
} }
return nil, nil, fmt.Errorf("%s: %w", rule.description, err) return nil, nil, fmt.Errorf("%s: %w", rule.description, err)
} }
@ -132,12 +126,12 @@ func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.Before
return nil, nil, nil return nil, nil, nil
} }
// cleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'. // CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.
// It systematically removes the three rules and any associated routing table entries to ensure a clean state. // It systematically removes the three rules and any associated routing table entries to ensure a clean state.
// The function uses error aggregation to report any errors encountered during the cleanup process. // The function uses error aggregation to report any errors encountered during the cleanup process.
func cleanupRouting() error { func (r *SysOps) CleanupRouting() error {
if isLegacy() { if isLegacy() {
return cleanupRoutingWithRouteManager(routeManager) return r.cleanupRefCounter()
} }
var result *multierror.Error var result *multierror.Error
@ -156,58 +150,58 @@ func cleanupRouting() error {
} }
} }
if err := cleanupSysctl(originalSysctl); err != nil { if err := sysctl.Cleanup(originalSysctl); err != nil {
result = multierror.Append(result, fmt.Errorf("cleanup sysctl: %w", err)) result = multierror.Append(result, fmt.Errorf("cleanup sysctl: %w", err))
} }
originalSysctl = nil originalSysctl = nil
sysctlFailed = false sysctlFailed = false
return result.ErrorOrNil() return nberrors.FormatErrorOrNil(result)
} }
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error { func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
return addRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN) return addRoute(prefix, nexthop, syscall.RT_TABLE_MAIN)
} }
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error { func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
return removeRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN) return removeRoute(prefix, nexthop, syscall.RT_TABLE_MAIN)
} }
func addVPNRoute(prefix netip.Prefix, intf *net.Interface) error { func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if isLegacy() { if isLegacy() {
return genericAddVPNRoute(prefix, intf) return r.genericAddVPNRoute(prefix, intf)
} }
if sysctlFailed && (prefix == defaultv4 || prefix == defaultv6) { if sysctlFailed && (prefix == vars.Defaultv4 || prefix == vars.Defaultv6) {
log.Warnf("Default route is configured but sysctl operations failed, VPN traffic may not be routed correctly, consider using NB_USE_LEGACY_ROUTING=true or setting net.ipv4.conf.*.rp_filter to 2 (loose) or 0 (off)") log.Warnf("Default route is configured but sysctl operations failed, VPN traffic may not be routed correctly, consider using NB_USE_LEGACY_ROUTING=true or setting net.ipv4.conf.*.rp_filter to 2 (loose) or 0 (off)")
} }
// No need to check if routes exist as main table takes precedence over the VPN table via Rule 1 // No need to check if routes exist as main table takes precedence over the VPN table via Rule 1
// TODO remove this once we have ipv6 support // TODO remove this once we have ipv6 support
if prefix == defaultv4 { if prefix == vars.Defaultv4 {
if err := addUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil { if err := addUnreachableRoute(vars.Defaultv6, NetbirdVPNTableID); err != nil {
return fmt.Errorf("add blackhole: %w", err) return fmt.Errorf("add blackhole: %w", err)
} }
} }
if err := addRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil { if err := addRoute(prefix, Nexthop{netip.Addr{}, intf}, NetbirdVPNTableID); err != nil {
return fmt.Errorf("add route: %w", err) return fmt.Errorf("add route: %w", err)
} }
return nil return nil
} }
func removeVPNRoute(prefix netip.Prefix, intf *net.Interface) error { func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if isLegacy() { if isLegacy() {
return genericRemoveVPNRoute(prefix, intf) return r.genericRemoveVPNRoute(prefix, intf)
} }
// TODO remove this once we have ipv6 support // TODO remove this once we have ipv6 support
if prefix == defaultv4 { if prefix == vars.Defaultv4 {
if err := removeUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil { if err := removeUnreachableRoute(vars.Defaultv6, NetbirdVPNTableID); err != nil {
return fmt.Errorf("remove unreachable route: %w", err) return fmt.Errorf("remove unreachable route: %w", err)
} }
} }
if err := removeRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil { if err := removeRoute(prefix, Nexthop{netip.Addr{}, intf}, NetbirdVPNTableID); err != nil {
return fmt.Errorf("remove route: %w", err) return fmt.Errorf("remove route: %w", err)
} }
return nil return nil
@ -255,7 +249,7 @@ func getRoutes(tableID, family int) ([]netip.Prefix, error) {
} }
// addRoute adds a route to a specific routing table identified by tableID. // addRoute adds a route to a specific routing table identified by tableID.
func addRoute(prefix netip.Prefix, addr netip.Addr, intf *net.Interface, tableID int) error { func addRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error {
route := &netlink.Route{ route := &netlink.Route{
Scope: netlink.SCOPE_UNIVERSE, Scope: netlink.SCOPE_UNIVERSE,
Table: tableID, Table: tableID,
@ -268,7 +262,7 @@ func addRoute(prefix netip.Prefix, addr netip.Addr, intf *net.Interface, tableID
} }
route.Dst = ipNet route.Dst = ipNet
if err := addNextHop(addr, intf, route); err != nil { if err := addNextHop(nexthop, route); err != nil {
return fmt.Errorf("add gateway and device: %w", err) return fmt.Errorf("add gateway and device: %w", err)
} }
@ -327,7 +321,7 @@ func removeUnreachableRoute(prefix netip.Prefix, tableID int) error {
} }
// removeRoute removes a route from a specific routing table identified by tableID. // removeRoute removes a route from a specific routing table identified by tableID.
func removeRoute(prefix netip.Prefix, addr netip.Addr, intf *net.Interface, tableID int) error { func removeRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error {
_, ipNet, err := net.ParseCIDR(prefix.String()) _, ipNet, err := net.ParseCIDR(prefix.String())
if err != nil { if err != nil {
return fmt.Errorf("parse prefix %s: %w", prefix, err) return fmt.Errorf("parse prefix %s: %w", prefix, err)
@ -340,7 +334,7 @@ func removeRoute(prefix netip.Prefix, addr netip.Addr, intf *net.Interface, tabl
Dst: ipNet, Dst: ipNet,
} }
if err := addNextHop(addr, intf, route); err != nil { if err := addNextHop(nexthop, route); err != nil {
return fmt.Errorf("add gateway and device: %w", err) return fmt.Errorf("add gateway and device: %w", err)
} }
@ -373,11 +367,11 @@ func flushRoutes(tableID, family int) error {
} }
} }
return result.ErrorOrNil() return nberrors.FormatErrorOrNil(result)
} }
func enableIPForwarding() error { func EnableIPForwarding() error {
_, err := setSysctl(ipv4ForwardingPath, 1, false) _, err := sysctl.Set(ipv4ForwardingPath, 1, false)
return err return err
} }
@ -481,19 +475,19 @@ func removeRule(params ruleParams) error {
} }
// addNextHop adds the gateway and device to the route. // addNextHop adds the gateway and device to the route.
func addNextHop(addr netip.Addr, intf *net.Interface, route *netlink.Route) error { func addNextHop(nexthop Nexthop, route *netlink.Route) error {
if intf != nil { if nexthop.Intf != nil {
route.LinkIndex = intf.Index route.LinkIndex = nexthop.Intf.Index
} }
if addr.IsValid() { if nexthop.IP.IsValid() {
route.Gw = addr.AsSlice() route.Gw = nexthop.IP.AsSlice()
// if zone is set, it means the gateway is a link-local address, so we set the link index // if zone is set, it means the gateway is a link-local address, so we set the link index
if addr.Zone() != "" && intf == nil { if nexthop.IP.Zone() != "" && nexthop.Intf == nil {
link, err := netlink.LinkByName(addr.Zone()) link, err := netlink.LinkByName(nexthop.IP.Zone())
if err != nil { if err != nil {
return fmt.Errorf("get link by name for zone %s: %w", addr.Zone(), err) return fmt.Errorf("get link by name for zone %s: %w", nexthop.IP.Zone(), err)
} }
route.LinkIndex = link.Attrs().Index route.LinkIndex = link.Attrs().Index
} }
@ -508,83 +502,3 @@ func getAddressFamily(prefix netip.Prefix) int {
} }
return netlink.FAMILY_V6 return netlink.FAMILY_V6
} }
// setupSysctl configures sysctl settings for RP filtering and source validation.
func setupSysctl(wgIface *iface.WGIface) (map[string]int, error) {
keys := map[string]int{}
var result *multierror.Error
oldVal, err := setSysctl(srcValidMarkPath, 1, false)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[srcValidMarkPath] = oldVal
}
oldVal, err = setSysctl(rpFilterPath, 2, true)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[rpFilterPath] = oldVal
}
interfaces, err := net.Interfaces()
if err != nil {
result = multierror.Append(result, fmt.Errorf("list interfaces: %w", err))
}
for _, intf := range interfaces {
if intf.Name == "lo" || wgIface != nil && intf.Name == wgIface.Name() {
continue
}
i := fmt.Sprintf(rpFilterInterfacePath, intf.Name)
oldVal, err := setSysctl(i, 2, true)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[i] = oldVal
}
}
return keys, result.ErrorOrNil()
}
// setSysctl sets a sysctl configuration, if onlyIfOne is true it will only set the new value if it's set to 1
func setSysctl(key string, desiredValue int, onlyIfOne bool) (int, error) {
path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/"))
currentValue, err := os.ReadFile(path)
if err != nil {
return -1, fmt.Errorf("read sysctl %s: %w", key, err)
}
currentV, err := strconv.Atoi(strings.TrimSpace(string(currentValue)))
if err != nil && len(currentValue) > 0 {
return -1, fmt.Errorf("convert current desiredValue to int: %w", err)
}
if currentV == desiredValue || onlyIfOne && currentV != 1 {
return currentV, nil
}
//nolint:gosec
if err := os.WriteFile(path, []byte(strconv.Itoa(desiredValue)), 0644); err != nil {
return currentV, fmt.Errorf("write sysctl %s: %w", key, err)
}
log.Debugf("Set sysctl %s from %d to %d", key, currentV, desiredValue)
return currentV, nil
}
func cleanupSysctl(originalSettings map[string]int) error {
var result *multierror.Error
for key, value := range originalSettings {
_, err := setSysctl(key, value, false)
if err != nil {
result = multierror.Append(result, err)
}
}
return result.ErrorOrNil()
}

View File

@ -1,6 +1,6 @@
//go:build !android //go:build !android
package routemanager package systemops
import ( import (
"errors" "errors"
@ -14,6 +14,8 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/vishvananda/netlink" "github.com/vishvananda/netlink"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
) )
var expectedVPNint = "wgtest0" var expectedVPNint = "wgtest0"
@ -138,7 +140,7 @@ func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, intf string) {
if dstIPNet.String() == "0.0.0.0/0" { if dstIPNet.String() == "0.0.0.0/0" {
var err error var err error
originalNexthop, originalLinkIndex, err = fetchOriginalGateway(netlink.FAMILY_V4) originalNexthop, originalLinkIndex, err = fetchOriginalGateway(netlink.FAMILY_V4)
if err != nil && !errors.Is(err, ErrRouteNotFound) { if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
t.Logf("Failed to fetch original gateway: %v", err) t.Logf("Failed to fetch original gateway: %v", err)
} }
@ -193,7 +195,7 @@ func fetchOriginalGateway(family int) (net.IP, int, error) {
} }
} }
return nil, 0, ErrRouteNotFound return nil, 0, vars.ErrRouteNotFound
} }
func setupDummyInterfacesAndRoutes(t *testing.T) { func setupDummyInterfacesAndRoutes(t *testing.T) {

View File

@ -0,0 +1,34 @@
//go:build ios || android
package systemops
import (
"net"
"net/netip"
"runtime"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer"
)
func (r *SysOps) SetupRouting([]net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
return nil, nil, nil
}
func (r *SysOps) CleanupRouting() error {
return nil
}
func (r *SysOps) AddVPNRoute(netip.Prefix, *net.Interface) error {
return nil
}
func (r *SysOps) RemoveVPNRoute(netip.Prefix, *net.Interface) error {
return nil
}
func EnableIPForwarding() error {
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil
}

View File

@ -0,0 +1,24 @@
//go:build !linux && !ios
package systemops
import (
"net"
"net/netip"
"runtime"
log "github.com/sirupsen/logrus"
)
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
return r.genericAddVPNRoute(prefix, intf)
}
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
return r.genericRemoveVPNRoute(prefix, intf)
}
func EnableIPForwarding() error {
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil
}

View File

@ -1,6 +1,6 @@
//go:build darwin && !ios //go:build (darwin && !ios) || dragonfly || freebsd || netbsd || openbsd
package routemanager package systemops
import ( import (
"fmt" "fmt"
@ -14,42 +14,40 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
) )
var routeManager *RouteManager func (r *SysOps) SetupRouting(initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
return r.setupRefCounter(initAddresses)
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
} }
func cleanupRouting() error { func (r *SysOps) CleanupRouting() error {
return cleanupRoutingWithRouteManager(routeManager) return r.cleanupRefCounter()
} }
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error { func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
return routeCmd("add", prefix, nexthop, intf) return r.routeCmd("add", prefix, nexthop)
} }
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error { func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
return routeCmd("delete", prefix, nexthop, intf) return r.routeCmd("delete", prefix, nexthop)
} }
func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error { func (r *SysOps) routeCmd(action string, prefix netip.Prefix, nexthop Nexthop) error {
inet := "-inet" inet := "-inet"
network := prefix.String()
if prefix.IsSingleIP() {
network = prefix.Addr().String()
}
if prefix.Addr().Is6() { if prefix.Addr().Is6() {
inet = "-inet6" inet = "-inet6"
} }
network := prefix.String()
if prefix.IsSingleIP() {
network = prefix.Addr().String()
}
args := []string{"-n", action, inet, network} args := []string{"-n", action, inet, network}
if nexthop.IsValid() { if nexthop.IP.IsValid() {
args = append(args, nexthop.Unmap().String()) args = append(args, nexthop.IP.Unmap().String())
} else if intf != nil { } else if nexthop.Intf != nil {
args = append(args, "-interface", intf.Name) args = append(args, "-interface", nexthop.Intf.Name)
} }
if err := retryRouteCmd(args); err != nil { if err := retryRouteCmd(args); err != nil {

View File

@ -1,6 +1,6 @@
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly //go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly
package routemanager package systemops
import ( import (
"fmt" "fmt"

View File

@ -1,6 +1,6 @@
//go:build windows //go:build windows
package routemanager package systemops
import ( import (
"fmt" "fmt"
@ -18,7 +18,6 @@ import (
"github.com/netbirdio/netbird/client/firewall/uspfilter" "github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
) )
type MSFT_NetRoute struct { type MSFT_NetRoute struct {
@ -57,14 +56,43 @@ var prefixList []netip.Prefix
var lastUpdate time.Time var lastUpdate time.Time
var mux = sync.Mutex{} var mux = sync.Mutex{}
var routeManager *RouteManager func (r *SysOps) SetupRouting(initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
return r.setupRefCounter(initAddresses)
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
} }
func cleanupRouting() error { func (r *SysOps) CleanupRouting() error {
return cleanupRoutingWithRouteManager(routeManager) return r.cleanupRefCounter()
}
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
if nexthop.IP.Zone() != "" && nexthop.Intf == nil {
zone, err := strconv.Atoi(nexthop.IP.Zone())
if err != nil {
return fmt.Errorf("invalid zone: %w", err)
}
nexthop.Intf = &net.Interface{Index: zone}
nexthop.IP.WithZone("")
}
return addRouteCmd(prefix, nexthop)
}
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
args := []string{"delete", prefix.String()}
if nexthop.IP.IsValid() {
nexthop.IP.WithZone("")
args = append(args, nexthop.IP.Unmap().String())
}
routeCmd := uspfilter.GetSystem32Command("route")
out, err := exec.Command(routeCmd, args...).CombinedOutput()
log.Tracef("route %s: %s", strings.Join(args, " "), out)
if err != nil {
return fmt.Errorf("remove route: %w", err)
}
return nil
} }
func getRoutesFromTable() ([]netip.Prefix, error) { func getRoutesFromTable() ([]netip.Prefix, error) {
@ -93,7 +121,7 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
func GetRoutes() ([]Route, error) { func GetRoutes() ([]Route, error) {
var entries []MSFT_NetRoute var entries []MSFT_NetRoute
query := `SELECT DestinationPrefix, NextHop, InterfaceIndex, InterfaceAlias, AddressFamily FROM MSFT_NetRoute` query := `SELECT DestinationPrefix, Nexthop, InterfaceIndex, InterfaceAlias, AddressFamily FROM MSFT_NetRoute`
if err := wmi.QueryNamespace(query, &entries, `ROOT\StandardCimv2`); err != nil { if err := wmi.QueryNamespace(query, &entries, `ROOT\StandardCimv2`); err != nil {
return nil, fmt.Errorf("get routes: %w", err) return nil, fmt.Errorf("get routes: %w", err)
} }
@ -157,11 +185,11 @@ func GetNeighbors() ([]Neighbor, error) {
return neighbors, nil return neighbors, nil
} }
func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error { func addRouteCmd(prefix netip.Prefix, nexthop Nexthop) error {
args := []string{"add", prefix.String()} args := []string{"add", prefix.String()}
if nexthop.IsValid() { if nexthop.IP.IsValid() {
args = append(args, nexthop.Unmap().String()) args = append(args, nexthop.IP.Unmap().String())
} else { } else {
addr := "0.0.0.0" addr := "0.0.0.0"
if prefix.Addr().Is6() { if prefix.Addr().Is6() {
@ -170,8 +198,8 @@ func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) e
args = append(args, addr) args = append(args, addr)
} }
if intf != nil { if nexthop.Intf != nil {
args = append(args, "if", strconv.Itoa(intf.Index)) args = append(args, "if", strconv.Itoa(nexthop.Intf.Index))
} }
routeCmd := uspfilter.GetSystem32Command("route") routeCmd := uspfilter.GetSystem32Command("route")
@ -185,37 +213,6 @@ func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) e
return nil return nil
} }
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
if nexthop.Zone() != "" && intf == nil {
zone, err := strconv.Atoi(nexthop.Zone())
if err != nil {
return fmt.Errorf("invalid zone: %w", err)
}
intf = &net.Interface{Index: zone}
nexthop.WithZone("")
}
return addRouteCmd(prefix, nexthop, intf)
}
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ *net.Interface) error {
args := []string{"delete", prefix.String()}
if nexthop.IsValid() {
nexthop.WithZone("")
args = append(args, nexthop.Unmap().String())
}
routeCmd := uspfilter.GetSystem32Command("route")
out, err := exec.Command(routeCmd, args...).CombinedOutput()
log.Tracef("route %s: %s", strings.Join(args, " "), out)
if err != nil {
return fmt.Errorf("remove route: %w", err)
}
return nil
}
func isCacheDisabled() bool { func isCacheDisabled() bool {
return os.Getenv("NB_DISABLE_ROUTE_CACHE") == "true" return os.Getenv("NB_DISABLE_ROUTE_CACHE") == "true"
} }

View File

@ -1,4 +1,4 @@
package routemanager package systemops
import ( import (
"context" "context"
@ -29,7 +29,7 @@ type FindNetRouteOutput struct {
InterfaceIndex int `json:"InterfaceIndex"` InterfaceIndex int `json:"InterfaceIndex"`
InterfaceAlias string `json:"InterfaceAlias"` InterfaceAlias string `json:"InterfaceAlias"`
AddressFamily int `json:"AddressFamily"` AddressFamily int `json:"AddressFamily"`
NextHop string `json:"NextHop"` NextHop string `json:"Nexthop"`
DestinationPrefix string `json:"DestinationPrefix"` DestinationPrefix string `json:"DestinationPrefix"`
} }
@ -166,7 +166,7 @@ func testRoute(t *testing.T, destination string, dialer dialer) *FindNetRouteOut
host, _, err := net.SplitHostPort(destination) host, _, err := net.SplitHostPort(destination)
require.NoError(t, err) require.NoError(t, err)
script := fmt.Sprintf(`Find-NetRoute -RemoteIPAddress "%s" | Select-Object -Property IPAddress, InterfaceIndex, InterfaceAlias, AddressFamily, NextHop, DestinationPrefix | ConvertTo-Json`, host) script := fmt.Sprintf(`Find-NetRoute -RemoteIPAddress "%s" | Select-Object -Property IPAddress, InterfaceIndex, InterfaceAlias, AddressFamily, Nexthop, DestinationPrefix | ConvertTo-Json`, host)
out, err := exec.Command("powershell", "-Command", script).Output() out, err := exec.Command("powershell", "-Command", script).Output()
require.NoError(t, err, "Failed to execute Find-NetRoute") require.NoError(t, err, "Failed to execute Find-NetRoute")
@ -207,7 +207,7 @@ func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR str
} }
func fetchOriginalGateway() (*RouteInfo, error) { func fetchOriginalGateway() (*RouteInfo, error) {
cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object NextHop, RouteMetric, InterfaceAlias | ConvertTo-Json") cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object Nexthop, RouteMetric, InterfaceAlias | ConvertTo-Json")
output, err := cmd.CombinedOutput() output, err := cmd.CombinedOutput()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to execute Get-NetRoute: %w", err) return nil, fmt.Errorf("failed to execute Get-NetRoute: %w", err)

View File

@ -1,33 +0,0 @@
package routemanager
import (
"net"
"net/netip"
"runtime"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
)
func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
return nil, nil, nil
}
func cleanupRouting() error {
return nil
}
func enableIPForwarding() error {
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil
}
func addVPNRoute(netip.Prefix, *net.Interface) error {
return nil
}
func removeVPNRoute(netip.Prefix, *net.Interface) error {
return nil
}

View File

@ -1,33 +0,0 @@
package routemanager
import (
"net"
"net/netip"
"runtime"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
)
func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
return nil, nil, nil
}
func cleanupRouting() error {
return nil
}
func enableIPForwarding() error {
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil
}
func addVPNRoute(netip.Prefix, *net.Interface) error {
return nil
}
func removeVPNRoute(netip.Prefix, *net.Interface) error {
return nil
}

View File

@ -1,24 +0,0 @@
//go:build !linux && !ios
package routemanager
import (
"net"
"net/netip"
"runtime"
log "github.com/sirupsen/logrus"
)
func enableIPForwarding() error {
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil
}
func addVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
return genericAddVPNRoute(prefix, intf)
}
func removeVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
return genericRemoveVPNRoute(prefix, intf)
}

View File

@ -0,0 +1,29 @@
package util
import (
"fmt"
"net"
"net/netip"
)
// GetPrefixFromIP returns a netip.Prefix from a net.IP address.
func GetPrefixFromIP(ip net.IP) (netip.Prefix, error) {
addr, ok := netip.AddrFromSlice(ip)
if !ok {
return netip.Prefix{}, fmt.Errorf("parse IP address: %s", ip)
}
addr = addr.Unmap()
var prefixLength int
switch {
case addr.Is4():
prefixLength = 32
case addr.Is6():
prefixLength = 128
default:
return netip.Prefix{}, fmt.Errorf("invalid IP address: %s", addr)
}
prefix := netip.PrefixFrom(addr, prefixLength)
return prefix, nil
}

View File

@ -0,0 +1,16 @@
package vars
import (
"errors"
"net/netip"
)
const MinRangeBits = 7
var (
ErrRouteNotFound = errors.New("route not found")
ErrRouteNotAllowed = errors.New("route not allowed")
Defaultv4 = netip.PrefixFrom(netip.IPv4Unspecified(), 0)
Defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
)

View File

@ -3,11 +3,11 @@ package routeselector
import ( import (
"fmt" "fmt"
"slices" "slices"
"strings"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/errors"
route "github.com/netbirdio/netbird/route" route "github.com/netbirdio/netbird/route"
) )
@ -30,10 +30,10 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al
rs.selectedRoutes = map[route.NetID]struct{}{} rs.selectedRoutes = map[route.NetID]struct{}{}
} }
var multiErr *multierror.Error var err *multierror.Error
for _, route := range routes { for _, route := range routes {
if !slices.Contains(allRoutes, route) { if !slices.Contains(allRoutes, route) {
multiErr = multierror.Append(multiErr, fmt.Errorf("route '%s' is not available", route)) err = multierror.Append(err, fmt.Errorf("route '%s' is not available", route))
continue continue
} }
@ -41,11 +41,7 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al
} }
rs.selectAll = false rs.selectAll = false
if multiErr != nil { return errors.FormatErrorOrNil(err)
multiErr.ErrorFormat = formatError
}
return multiErr.ErrorOrNil()
} }
// SelectAllRoutes sets the selector to select all routes. // SelectAllRoutes sets the selector to select all routes.
@ -65,21 +61,17 @@ func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.
} }
} }
var multiErr *multierror.Error var err *multierror.Error
for _, route := range routes { for _, route := range routes {
if !slices.Contains(allRoutes, route) { if !slices.Contains(allRoutes, route) {
multiErr = multierror.Append(multiErr, fmt.Errorf("route '%s' is not available", route)) err = multierror.Append(err, fmt.Errorf("route '%s' is not available", route))
continue continue
} }
delete(rs.selectedRoutes, route) delete(rs.selectedRoutes, route)
} }
if multiErr != nil { return errors.FormatErrorOrNil(err)
multiErr.ErrorFormat = formatError
}
return multiErr.ErrorOrNil()
} }
// DeselectAllRoutes deselects all routes, effectively disabling route selection. // DeselectAllRoutes deselects all routes, effectively disabling route selection.
@ -111,18 +103,3 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
} }
return filtered return filtered
} }
func formatError(es []error) string {
if len(es) == 1 {
return fmt.Sprintf("1 error occurred:\n\t* %s", es[0])
}
points := make([]string, len(es))
for i, err := range es {
points[i] = fmt.Sprintf("* %s", err)
}
return fmt.Sprintf(
"%d errors occurred:\n\t%s",
len(es), strings.Join(points, "\n\t"))
}

View File

@ -261,15 +261,15 @@ func TestRouteSelector_FilterSelected(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
routes := route.HAMap{ routes := route.HAMap{
"route1-10.0.0.0/8": {}, "route1|10.0.0.0/8": {},
"route2-192.168.0.0/16": {}, "route2|192.168.0.0/16": {},
"route3-172.16.0.0/12": {}, "route3|172.16.0.0/12": {},
} }
filtered := rs.FilterSelected(routes) filtered := rs.FilterSelected(routes)
assert.Equal(t, route.HAMap{ assert.Equal(t, route.HAMap{
"route1-10.0.0.0/8": {}, "route1|10.0.0.0/8": {},
"route2-192.168.0.0/16": {}, "route2|192.168.0.0/16": {},
}, filtered) }, filtered)
} }

File diff suppressed because it is too large Load Diff

View File

@ -92,6 +92,8 @@ message LoginRequest {
repeated string extraIFaceBlacklist = 17; repeated string extraIFaceBlacklist = 17;
optional bool networkMonitor = 18; optional bool networkMonitor = 18;
optional google.protobuf.Duration dnsRouteInterval = 19;
} }
message LoginResponse { message LoginResponse {
@ -233,10 +235,17 @@ message SelectRoutesRequest {
message SelectRoutesResponse { message SelectRoutesResponse {
} }
message IPList {
repeated string ips = 1;
}
message Route { message Route {
string ID = 1; string ID = 1;
string network = 2; string network = 2;
bool selected = 3; bool selected = 3;
repeated string domains = 4;
map<string, IPList> resolvedIPs = 5;
} }
message DebugBundleRequest { message DebugBundleRequest {

View File

@ -9,17 +9,19 @@ import (
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
type selectRoute struct { type selectRoute struct {
NetID route.NetID NetID route.NetID
Network netip.Prefix Network netip.Prefix
Domains domain.List
Selected bool Selected bool
} }
// ListRoutes returns a list of all available routes. // ListRoutes returns a list of all available routes.
func (s *Server) ListRoutes(ctx context.Context, req *proto.ListRoutesRequest) (*proto.ListRoutesResponse, error) { func (s *Server) ListRoutes(context.Context, *proto.ListRoutesRequest) (*proto.ListRoutesResponse, error) {
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
@ -43,6 +45,7 @@ func (s *Server) ListRoutes(ctx context.Context, req *proto.ListRoutesRequest) (
route := &selectRoute{ route := &selectRoute{
NetID: id, NetID: id,
Network: rt[0].Network, Network: rt[0].Network,
Domains: rt[0].Domains,
Selected: routeSelector.IsSelected(id), Selected: routeSelector.IsSelected(id),
} }
routes = append(routes, route) routes = append(routes, route)
@ -63,13 +66,29 @@ func (s *Server) ListRoutes(ctx context.Context, req *proto.ListRoutesRequest) (
return iPrefix < jPrefix return iPrefix < jPrefix
}) })
resolvedDomains := s.statusRecorder.GetResolvedDomainsStates()
var pbRoutes []*proto.Route var pbRoutes []*proto.Route
for _, route := range routes { for _, route := range routes {
pbRoutes = append(pbRoutes, &proto.Route{ pbRoute := &proto.Route{
ID: string(route.NetID), ID: string(route.NetID),
Network: route.Network.String(), Network: route.Network.String(),
Domains: route.Domains.ToSafeStringList(),
ResolvedIPs: map[string]*proto.IPList{},
Selected: route.Selected, Selected: route.Selected,
}) }
for _, domain := range route.Domains {
if prefixes, exists := resolvedDomains[domain]; exists {
var ipStrings []string
for _, prefix := range prefixes {
ipStrings = append(ipStrings, prefix.Addr().String())
}
pbRoute.ResolvedIPs[string(domain)] = &proto.IPList{
Ips: ipStrings,
}
}
}
pbRoutes = append(pbRoutes, pbRoute)
} }
return &proto.ListRoutesResponse{ return &proto.ListRoutesResponse{

View File

@ -365,6 +365,12 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
s.latestConfigInput.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist s.latestConfigInput.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist
} }
if msg.DnsRouteInterval != nil {
duration := msg.DnsRouteInterval.AsDuration()
inputConfig.DNSRouteInterval = &duration
s.latestConfigInput.DNSRouteInterval = &duration
}
s.mutex.Unlock() s.mutex.Unlock()
if msg.OptionalPreSharedKey != nil { if msg.OptionalPreSharedKey != nil {

View File

@ -0,0 +1,10 @@
//go:build freebsd
package ssh
import (
"os"
)
func setWinSize(file *os.File, width, height int) {
}

View File

@ -8,6 +8,7 @@ import (
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )
@ -33,6 +34,12 @@ type Environment struct {
Platform string Platform string
} }
type File struct {
Path string
Exist bool
ProcessIsRunning bool
}
// Info is an object that contains machine information // Info is an object that contains machine information
// Most of the code is taken from https://github.com/matishsiao/goInfo // Most of the code is taken from https://github.com/matishsiao/goInfo
type Info struct { type Info struct {
@ -51,6 +58,7 @@ type Info struct {
SystemProductName string SystemProductName string
SystemManufacturer string SystemManufacturer string
Environment Environment Environment Environment
Files []File
} }
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context // extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
@ -132,3 +140,21 @@ func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
} }
return false return false
} }
// GetInfoWithChecks retrieves and parses the system information with applied checks.
func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) {
processCheckPaths := make([]string, 0)
for _, check := range checks {
processCheckPaths = append(processCheckPaths, check.GetFiles()...)
}
files, err := checkFileAndProcess(processCheckPaths)
if err != nil {
return nil, err
}
info := GetInfo(ctx)
info.Files = files
return info, nil
}

View File

@ -44,6 +44,11 @@ func GetInfo(ctx context.Context) *Info {
return gio return gio
} }
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
func checkFileAndProcess(paths []string) ([]File, error) {
return []File{}, nil
}
func uname() []string { func uname() []string {
res := run("/system/bin/uname", "-a") res := run("/system/bin/uname", "-a")
return strings.Split(res, " ") return strings.Split(res, " ")

View File

@ -1,15 +1,18 @@
//go:build freebsd
package system package system
import ( import (
"bytes" "bytes"
"context" "context"
"fmt"
"os" "os"
"os/exec" "os/exec"
"runtime" "runtime"
"strings" "strings"
"time" "time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/system/detect_cloud" "github.com/netbirdio/netbird/client/system/detect_cloud"
"github.com/netbirdio/netbird/client/system/detect_platform" "github.com/netbirdio/netbird/client/system/detect_platform"
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
@ -22,8 +25,8 @@ func GetInfo(ctx context.Context) *Info {
out = _getInfo() out = _getInfo()
time.Sleep(500 * time.Millisecond) time.Sleep(500 * time.Millisecond)
} }
osStr := strings.Replace(out, "\n", "", -1) osStr := strings.ReplaceAll(out, "\n", "")
osStr = strings.Replace(osStr, "\r\n", "", -1) osStr = strings.ReplaceAll(osStr, "\r\n", "")
osInfo := strings.Split(osStr, " ") osInfo := strings.Split(osStr, " ")
env := Environment{ env := Environment{
@ -31,14 +34,23 @@ func GetInfo(ctx context.Context) *Info {
Platform: detect_platform.Detect(ctx), Platform: detect_platform.Detect(ctx),
} }
gio := &Info{Kernel: osInfo[0], Platform: runtime.GOARCH, OS: osInfo[2], GoOS: runtime.GOOS, CPUs: runtime.NumCPU(), KernelVersion: osInfo[1], Environment: env} osName, osVersion := readOsReleaseFile()
systemHostname, _ := os.Hostname() systemHostname, _ := os.Hostname()
gio.Hostname = extractDeviceName(ctx, systemHostname)
gio.WiretrusteeVersion = version.NetbirdVersion()
gio.UIVersion = extractUserAgent(ctx)
return gio return &Info{
GoOS: runtime.GOOS,
Kernel: osInfo[0],
Platform: runtime.GOARCH,
OS: osName,
OSVersion: osVersion,
Hostname: extractDeviceName(ctx, systemHostname),
CPUs: runtime.NumCPU(),
WiretrusteeVersion: version.NetbirdVersion(),
UIVersion: extractUserAgent(ctx),
KernelVersion: osInfo[1],
Environment: env,
}
} }
func _getInfo() string { func _getInfo() string {
@ -50,7 +62,8 @@ func _getInfo() string {
cmd.Stderr = &stderr cmd.Stderr = &stderr
err := cmd.Run() err := cmd.Run()
if err != nil { if err != nil {
fmt.Println("getInfo:", err) log.Warnf("getInfo: %s", err)
} }
return out.String() return out.String()
} }

View File

@ -25,6 +25,11 @@ func GetInfo(ctx context.Context) *Info {
return gio return gio
} }
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
func checkFileAndProcess(paths []string) ([]File, error) {
return []File{}, nil
}
// extractOsVersion extracts operating system version from context or returns the default // extractOsVersion extracts operating system version from context or returns the default
func extractOsVersion(ctx context.Context, defaultName string) string { func extractOsVersion(ctx context.Context, defaultName string) string {
v, ok := ctx.Value(OsVersionCtxKey).(string) v, ok := ctx.Value(OsVersionCtxKey).(string)

View File

@ -28,28 +28,11 @@ func GetInfo(ctx context.Context) *Info {
time.Sleep(500 * time.Millisecond) time.Sleep(500 * time.Millisecond)
} }
releaseInfo := _getReleaseInfo()
for strings.Contains(info, "broken pipe") {
releaseInfo = _getReleaseInfo()
time.Sleep(500 * time.Millisecond)
}
osRelease := strings.Split(releaseInfo, "\n")
var osName string
var osVer string
for _, s := range osRelease {
if strings.HasPrefix(s, "NAME=") {
osName = strings.Split(s, "=")[1]
osName = strings.ReplaceAll(osName, "\"", "")
} else if strings.HasPrefix(s, "VERSION_ID=") {
osVer = strings.Split(s, "=")[1]
osVer = strings.ReplaceAll(osVer, "\"", "")
}
}
osStr := strings.ReplaceAll(info, "\n", "") osStr := strings.ReplaceAll(info, "\n", "")
osStr = strings.ReplaceAll(osStr, "\r\n", "") osStr = strings.ReplaceAll(osStr, "\r\n", "")
osInfo := strings.Split(osStr, " ") osInfo := strings.Split(osStr, " ")
osName, osVersion := readOsReleaseFile()
if osName == "" { if osName == "" {
osName = osInfo[3] osName = osInfo[3]
} }
@ -72,7 +55,7 @@ func GetInfo(ctx context.Context) *Info {
Kernel: osInfo[0], Kernel: osInfo[0],
Platform: osInfo[2], Platform: osInfo[2],
OS: osName, OS: osName,
OSVersion: osVer, OSVersion: osVersion,
Hostname: extractDeviceName(ctx, systemHostname), Hostname: extractDeviceName(ctx, systemHostname),
GoOS: runtime.GOOS, GoOS: runtime.GOOS,
CPUs: runtime.NumCPU(), CPUs: runtime.NumCPU(),
@ -103,20 +86,6 @@ func _getInfo() string {
return out.String() return out.String()
} }
func _getReleaseInfo() string {
cmd := exec.Command("cat", "/etc/os-release")
cmd.Stdin = strings.NewReader("some")
var out bytes.Buffer
var stderr bytes.Buffer
cmd.Stdout = &out
cmd.Stderr = &stderr
err := cmd.Run()
if err != nil {
log.Warnf("geucwReleaseInfo: %s", err)
}
return out.String()
}
func sysInfo() (serialNumber string, productName string, manufacturer string) { func sysInfo() (serialNumber string, productName string, manufacturer string) {
var si sysinfo.SysInfo var si sysinfo.SysInfo
si.GetSysInfo() si.GetSysInfo()

View File

@ -0,0 +1,38 @@
//go:build (linux && !android) || freebsd
package system
import (
"bufio"
"os"
"strings"
log "github.com/sirupsen/logrus"
)
func readOsReleaseFile() (osName string, osVer string) {
file, err := os.Open("/etc/os-release")
if err != nil {
log.Warnf("failed to open file /etc/os-release: %s", err)
return "", ""
}
defer file.Close()
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := scanner.Text()
if strings.HasPrefix(line, "NAME=") {
osName = strings.ReplaceAll(strings.Split(line, "=")[1], "\"", "")
continue
}
if strings.HasPrefix(line, "VERSION_ID=") {
osVer = strings.ReplaceAll(strings.Split(line, "=")[1], "\"", "")
continue
}
if osName != "" && osVer != "" {
break
}
}
return
}

58
client/system/process.go Normal file
View File

@ -0,0 +1,58 @@
//go:build windows || (linux && !android) || (darwin && !ios)
package system
import (
"os"
"slices"
"github.com/shirou/gopsutil/v3/process"
)
// getRunningProcesses returns a list of running process paths.
func getRunningProcesses() ([]string, error) {
processes, err := process.Processes()
if err != nil {
return nil, err
}
processMap := make(map[string]bool)
for _, p := range processes {
path, _ := p.Exe()
if path != "" {
processMap[path] = true
}
}
uniqueProcesses := make([]string, 0, len(processMap))
for p := range processMap {
uniqueProcesses = append(uniqueProcesses, p)
}
return uniqueProcesses, nil
}
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
func checkFileAndProcess(paths []string) ([]File, error) {
files := make([]File, len(paths))
if len(paths) == 0 {
return files, nil
}
runningProcesses, err := getRunningProcesses()
if err != nil {
return nil, err
}
for i, path := range paths {
file := File{Path: path}
_, err := os.Stat(path)
file.Exist = !os.IsNotExist(err)
file.ProcessIsRunning = slices.Contains(runningProcesses, path)
files[i] = file
}
return files, nil
}

View File

@ -1,4 +1,4 @@
//go:build !(linux && 386) //go:build !(linux && 386) && !freebsd
package main package main

View File

@ -20,7 +20,7 @@ import (
func (s *serviceClient) showRoutesUI() { func (s *serviceClient) showRoutesUI() {
s.wRoutes = s.app.NewWindow("NetBird Routes") s.wRoutes = s.app.NewWindow("NetBird Routes")
grid := container.New(layout.NewGridLayout(2)) grid := container.New(layout.NewGridLayout(3))
go s.updateRoutes(grid) go s.updateRoutes(grid)
routeCheckContainer := container.NewVBox() routeCheckContainer := container.NewVBox()
routeCheckContainer.Add(grid) routeCheckContainer.Add(grid)
@ -61,14 +61,16 @@ func (s *serviceClient) updateRoutes(grid *fyne.Container) {
grid.Objects = nil grid.Objects = nil
idHeader := widget.NewLabelWithStyle(" ID", fyne.TextAlignLeading, fyne.TextStyle{Bold: true}) idHeader := widget.NewLabelWithStyle(" ID", fyne.TextAlignLeading, fyne.TextStyle{Bold: true})
networkHeader := widget.NewLabelWithStyle("Network", fyne.TextAlignLeading, fyne.TextStyle{Bold: true}) networkHeader := widget.NewLabelWithStyle("Network/Domains", fyne.TextAlignLeading, fyne.TextStyle{Bold: true})
resolvedIPsHeader := widget.NewLabelWithStyle("Resolved IPs", fyne.TextAlignLeading, fyne.TextStyle{Bold: true})
grid.Add(idHeader) grid.Add(idHeader)
grid.Add(networkHeader) grid.Add(networkHeader)
grid.Add(resolvedIPsHeader)
for _, route := range routes { for _, route := range routes {
r := route r := route
checkBox := widget.NewCheck(r.ID, func(checked bool) { checkBox := widget.NewCheck(r.GetID(), func(checked bool) {
s.selectRoute(r.ID, checked) s.selectRoute(r.ID, checked)
}) })
checkBox.Checked = route.Selected checkBox.Checked = route.Selected
@ -76,10 +78,31 @@ func (s *serviceClient) updateRoutes(grid *fyne.Container) {
checkBox.Refresh() checkBox.Refresh()
grid.Add(checkBox) grid.Add(checkBox)
grid.Add(widget.NewLabel(r.Network)) network := r.GetNetwork()
domains := r.GetDomains()
if len(domains) > 0 {
network = strings.Join(domains, ", ")
}
grid.Add(widget.NewLabel(network))
if len(domains) > 0 {
var resolvedIPsList []string
for _, domain := range r.GetDomains() {
if ipList, exists := r.GetResolvedIPs()[domain]; exists {
resolvedIPsList = append(resolvedIPsList, fmt.Sprintf("%s: %s", domain, strings.Join(ipList.GetIps(), ", ")))
}
}
// TODO: limit width
resolvedIPsLabel := widget.NewLabel(strings.Join(resolvedIPsList, ", "))
grid.Add(resolvedIPsLabel)
} else {
grid.Add(widget.NewLabel(""))
}
} }
s.wRoutes.Content().Refresh() s.wRoutes.Content().Refresh()
grid.Refresh()
} }
func (s *serviceClient) fetchRoutes() ([]*proto.Route, error) { func (s *serviceClient) fetchRoutes() ([]*proto.Route, error) {

2
go.mod
View File

@ -68,6 +68,7 @@ require (
github.com/pion/turn/v3 v3.0.1 github.com/pion/turn/v3 v3.0.1
github.com/prometheus/client_golang v1.19.1 github.com/prometheus/client_golang v1.19.1
github.com/rs/xid v1.3.0 github.com/rs/xid v1.3.0
github.com/shirou/gopsutil/v3 v3.24.4
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
github.com/stretchr/testify v1.9.0 github.com/stretchr/testify v1.9.0
github.com/testcontainers/testcontainers-go v0.31.0 github.com/testcontainers/testcontainers-go v0.31.0
@ -176,7 +177,6 @@ require (
github.com/prometheus/client_model v0.6.1 // indirect github.com/prometheus/client_model v0.6.1 // indirect
github.com/prometheus/common v0.53.0 // indirect github.com/prometheus/common v0.53.0 // indirect
github.com/prometheus/procfs v0.15.0 // indirect github.com/prometheus/procfs v0.15.0 // indirect
github.com/shirou/gopsutil/v3 v3.24.4 // indirect
github.com/shoenig/go-m1cpu v0.1.6 // indirect github.com/shoenig/go-m1cpu v0.1.6 // indirect
github.com/spf13/cast v1.5.0 // indirect github.com/spf13/cast v1.5.0 // indirect
github.com/srwiley/oksvg v0.0.0-20200311192757-870daf9aa564 // indirect github.com/srwiley/oksvg v0.0.0-20200311192757-870daf9aa564 // indirect

View File

@ -23,24 +23,6 @@ func parseWGAddress(address string) (WGAddress, error) {
}, nil }, nil
} }
// Masked returns the WGAddress with the IP address part masked according to its network mask.
func (addr WGAddress) Masked() WGAddress {
ip := addr.IP.To4()
if ip == nil {
ip = addr.IP.To16()
}
maskedIP := make(net.IP, len(ip))
for i := range ip {
maskedIP[i] = ip[i] & addr.Network.Mask[i]
}
return WGAddress{
IP: maskedIP,
Network: addr.Network,
}
}
func (addr WGAddress) String() string { func (addr WGAddress) String() string {
maskSize, _ := addr.Network.Mask.Size() maskSize, _ := addr.Network.Mask.Size()
return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize) return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize)

8
iface/freebsd/errors.go Normal file
View File

@ -0,0 +1,8 @@
package freebsd
import "errors"
var (
ErrDoesNotExist = errors.New("does not exist")
ErrNameDoesNotMatch = errors.New("name does not match")
)

108
iface/freebsd/iface.go Normal file
View File

@ -0,0 +1,108 @@
package freebsd
import (
"bufio"
"fmt"
"strconv"
"strings"
)
type iface struct {
Name string
MTU int
Group string
IPAddrs []string
}
func parseError(output []byte) error {
// TODO: implement without allocations
lines := string(output)
if strings.Contains(lines, "does not exist") {
return ErrDoesNotExist
}
return nil
}
func parseIfconfigOutput(output []byte) (*iface, error) {
// TODO: implement without allocations
lines := string(output)
scanner := bufio.NewScanner(strings.NewReader(lines))
var name, mtu, group string
var ips []string
for scanner.Scan() {
line := scanner.Text()
// If line contains ": flags", it's a line with interface information
if strings.Contains(line, ": flags") {
parts := strings.Fields(line)
if len(parts) < 4 {
return nil, fmt.Errorf("failed to parse line: %s", line)
}
name = strings.TrimSuffix(parts[0], ":")
if strings.Contains(line, "mtu") {
mtuIndex := 0
for i, part := range parts {
if part == "mtu" {
mtuIndex = i
break
}
}
mtu = parts[mtuIndex+1]
}
}
// If line contains "groups:", it's a line with interface group
if strings.Contains(line, "groups:") {
parts := strings.Fields(line)
if len(parts) < 2 {
return nil, fmt.Errorf("failed to parse line: %s", line)
}
group = parts[1]
}
// If line contains "inet ", it's a line with IP address
if strings.Contains(line, "inet ") {
parts := strings.Fields(line)
if len(parts) < 2 {
return nil, fmt.Errorf("failed to parse line: %s", line)
}
ips = append(ips, parts[1])
}
}
if name == "" {
return nil, fmt.Errorf("interface name not found in ifconfig output")
}
mtuInt, err := strconv.Atoi(mtu)
if err != nil {
return nil, fmt.Errorf("failed to parse MTU: %w", err)
}
return &iface{
Name: name,
MTU: mtuInt,
Group: group,
IPAddrs: ips,
}, nil
}
func parseIFName(output []byte) (string, error) {
// TODO: implement without allocations
lines := strings.Split(string(output), "\n")
if len(lines) == 0 || lines[0] == "" {
return "", fmt.Errorf("no output returned")
}
fields := strings.Fields(lines[0])
if len(fields) > 1 {
return "", fmt.Errorf("invalid output")
}
return fields[0], nil
}

View File

@ -0,0 +1,76 @@
package freebsd
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
func TestParseIfconfigOutput(t *testing.T) {
testOutput := `wg1: flags=8080<NOARP,MULTICAST> metric 0 mtu 1420
options=80000<LINKSTATE>
groups: wg
nd6 options=109<PERFORMNUD,IFDISABLED,NO_DAD>`
expected := &iface{
Name: "wg1",
MTU: 1420,
Group: "wg",
}
result, err := parseIfconfigOutput(([]byte)(testOutput))
if err != nil {
t.Errorf("Error parsing ifconfig output: %v", err)
return
}
assert.Equal(t, expected.Name, result.Name, "Name should match")
assert.Equal(t, expected.MTU, result.MTU, "MTU should match")
assert.Equal(t, expected.Group, result.Group, "Group should match")
}
func TestParseIFName(t *testing.T) {
tests := []struct {
name string
output string
expected string
expectedErr error
}{
{
name: "ValidOutput",
output: "eth0\n",
expected: "eth0",
},
{
name: "ValidOutputOneLine",
output: "eth0",
expected: "eth0",
},
{
name: "EmptyOutput",
output: "",
expectedErr: fmt.Errorf("no output returned"),
},
{
name: "InvalidOutput",
output: "This is an invalid output\n",
expectedErr: fmt.Errorf("invalid output"),
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
result, err := parseIFName(([]byte)(test.output))
assert.Equal(t, test.expected, result, "Interface names should match")
if test.expectedErr != nil {
assert.NotNil(t, err, "Error should not be nil")
assert.EqualError(t, err, test.expectedErr.Error(), "Error messages should match")
} else {
assert.Nil(t, err, "Error should be nil")
}
})
}
}

239
iface/freebsd/link.go Normal file
View File

@ -0,0 +1,239 @@
package freebsd
import (
"bytes"
"errors"
"fmt"
"os/exec"
"strconv"
log "github.com/sirupsen/logrus"
)
const wgIFGroup = "wg"
// Link represents a network interface.
type Link struct {
name string
}
func NewLink(name string) *Link {
return &Link{
name: name,
}
}
// LinkByName retrieves a network interface by its name.
func LinkByName(name string) (*Link, error) {
out, err := exec.Command("ifconfig", name).CombinedOutput()
if err != nil {
if pErr := parseError(out); pErr != nil {
return nil, pErr
}
log.Debugf("ifconfig out: %s", out)
return nil, fmt.Errorf("command run: %w", err)
}
i, err := parseIfconfigOutput(out)
if err != nil {
return nil, fmt.Errorf("parse ifconfig output: %w", err)
}
if i.Name != name {
return nil, ErrNameDoesNotMatch
}
return &Link{name: i.Name}, nil
}
// Recreate - create new interface, remove current before create if it exists
func (l *Link) Recreate() error {
ok, err := l.isExist()
if err != nil {
return fmt.Errorf("is exist: %w", err)
}
if ok {
if err := l.del(l.name); err != nil {
return fmt.Errorf("del: %w", err)
}
}
return l.Add()
}
// Add creates a new network interface.
func (l *Link) Add() error {
parsedName, err := l.create(wgIFGroup)
if err != nil {
return fmt.Errorf("create link: %w", err)
}
if parsedName == l.name {
return nil
}
parsedName, err = l.rename(parsedName, l.name)
if err != nil {
errDel := l.del(parsedName)
if errDel != nil {
return fmt.Errorf("del on rename link: %w: %w", err, errDel)
}
return fmt.Errorf("rename link: %w", err)
}
return nil
}
// Del removes an existing network interface.
func (l *Link) Del() error {
return l.del(l.name)
}
// SetMTU sets the MTU of the network interface.
func (l *Link) SetMTU(mtu int) error {
return l.setMTU(mtu)
}
// AssignAddr assigns an IP address and netmask to the network interface.
func (l *Link) AssignAddr(ip, netmask string) error {
return l.setAddr(ip, netmask)
}
func (l *Link) Up() error {
return l.up(l.name)
}
func (l *Link) Down() error {
return l.down(l.name)
}
func (l *Link) isExist() (bool, error) {
_, err := LinkByName(l.name)
if errors.Is(err, ErrDoesNotExist) {
return false, nil
}
if err != nil {
return false, fmt.Errorf("link by name: %w", err)
}
return true, nil
}
func (l *Link) create(groupName string) (string, error) {
cmd := exec.Command("ifconfig", groupName, "create")
output, err := cmd.CombinedOutput()
if err != nil {
log.Debugf("ifconfig out: %s", output)
return "", fmt.Errorf("create %s interface: %w", groupName, err)
}
interfaceName, err := parseIFName(output)
if err != nil {
return "", fmt.Errorf("parse interface name: %w", err)
}
return interfaceName, nil
}
func (l *Link) rename(oldName, newName string) (string, error) {
cmd := exec.Command("ifconfig", oldName, "name", newName)
output, err := cmd.CombinedOutput()
if err != nil {
log.Debugf("ifconfig out: %s", output)
return "", fmt.Errorf("change name %q -> %q: %w", oldName, newName, err)
}
interfaceName, err := parseIFName(output)
if err != nil {
return "", fmt.Errorf("parse new name: %w", err)
}
return interfaceName, nil
}
func (l *Link) del(name string) error {
var stderr bytes.Buffer
cmd := exec.Command("ifconfig", name, "destroy")
cmd.Stderr = &stderr
err := cmd.Run()
if err != nil {
log.Debugf("ifconfig out: %s", stderr.String())
return fmt.Errorf("destroy %s interface: %w", name, err)
}
return nil
}
func (l *Link) setMTU(mtu int) error {
var stderr bytes.Buffer
cmd := exec.Command("ifconfig", l.name, "mtu", strconv.Itoa(mtu))
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
log.Debugf("ifconfig out: %s", stderr.String())
return fmt.Errorf("set interface mtu: %w", err)
}
return nil
}
func (l *Link) setAddr(ip, netmask string) error {
var stderr bytes.Buffer
cmd := exec.Command("ifconfig", l.name, "inet", ip, "netmask", netmask)
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
log.Debugf("ifconfig out: %s", stderr.String())
return fmt.Errorf("set interface addr: %w", err)
}
return nil
}
func (l *Link) up(name string) error {
var stderr bytes.Buffer
cmd := exec.Command("ifconfig", name, "up")
cmd.Stderr = &stderr
err := cmd.Run()
if err != nil {
log.Debugf("ifconfig out: %s", stderr.String())
return fmt.Errorf("up %s interface: %w", name, err)
}
return nil
}
func (l *Link) down(name string) error {
var stderr bytes.Buffer
cmd := exec.Command("ifconfig", name, "down")
cmd.Stderr = &stderr
err := cmd.Run()
if err != nil {
log.Debugf("ifconfig out: %s", stderr.String())
return fmt.Errorf("down %s interface: %w", name, err)
}
return nil
}

View File

@ -48,6 +48,19 @@ func (w *WGIface) Address() WGAddress {
return w.tun.WgAddress() return w.tun.WgAddress()
} }
// ToInterface returns the net.Interface for the Wireguard interface
func (r *WGIface) ToInterface() *net.Interface {
name := r.tun.DeviceName()
intf, err := net.InterfaceByName(name)
if err != nil {
log.Warnf("Failed to get interface by name %s: %v", name, err)
intf = &net.Interface{
Name: name,
}
}
return intf
}
// Up configures a Wireguard interface // Up configures a Wireguard interface
// The interface must exist before calling this method (e.g. call interface.Create() before) // The interface must exist before calling this method (e.g. call interface.Create() before)
func (w *WGIface) Up() (*bind.UniversalUDPMuxDefault, error) { func (w *WGIface) Up() (*bind.UniversalUDPMuxDefault, error) {
@ -94,7 +107,7 @@ func (w *WGIface) AddAllowedIP(peerKey string, allowedIP string) error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
log.Debugf("adding allowed IP to interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP) log.Debugf("Adding allowed IP to interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
return w.configurer.addAllowedIP(peerKey, allowedIP) return w.configurer.addAllowedIP(peerKey, allowedIP)
} }
@ -103,7 +116,7 @@ func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP string) error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
log.Debugf("removing allowed IP from interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP) log.Debugf("Removing allowed IP from interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
return w.configurer.removeAllowedIP(peerKey, allowedIP) return w.configurer.removeAllowedIP(peerKey, allowedIP)
} }

View File

@ -1,5 +1,4 @@
//go:build !android //go:build !android
// +build !android
package iface package iface

View File

@ -1,5 +1,4 @@
//go:build !ios //go:build !ios
// +build !ios
package iface package iface

View File

@ -1,5 +1,4 @@
//go:build ios //go:build ios
// +build ios
package iface package iface

View File

@ -1,10 +1,10 @@
//go:build !android //go:build (linux && !android) || freebsd
// +build !android
package iface package iface
import ( import (
"fmt" "fmt"
"runtime"
"github.com/pion/transport/v3" "github.com/pion/transport/v3"
@ -43,5 +43,5 @@ func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string,
// CreateOnAndroid this function make sense on mobile only // CreateOnAndroid this function make sense on mobile only
func (w *WGIface) CreateOnAndroid([]string, string, []string) error { func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
return fmt.Errorf("this function has not implemented on this platform") return fmt.Errorf("CreateOnAndroid function has not implemented on %s platform", runtime.GOOS)
} }

View File

@ -1,5 +1,4 @@
//go:build !linux || android //go:build (!linux && !freebsd) || android
// +build !linux android
package iface package iface

18
iface/module_freebsd.go Normal file
View File

@ -0,0 +1,18 @@
package iface
// WireGuardModuleIsLoaded check if kernel support wireguard
func WireGuardModuleIsLoaded() bool {
// Despite the fact FreeBSD natively support Wireguard (https://github.com/WireGuard/wireguard-freebsd)
// we are currently do not use it, since it is required to add wireguard kernel support to
// - https://github.com/netbirdio/netbird/tree/main/sharedsock
// - https://github.com/mdlayher/socket
// TODO: implement kernel space
return false
}
// tunModuleIsLoaded check if tun module exist, if is not attempt to load it
func tunModuleIsLoaded() bool {
// Assume tun supported by freebsd kernel by default
// TODO: implement check for module loaded in kernel or build-it
return true
}

View File

@ -1,5 +1,4 @@
//go:build linux || windows //go:build linux || windows || freebsd
// +build linux windows
package iface package iface

View File

@ -1,5 +1,4 @@
//go:build darwin //go:build darwin
// +build darwin
package iface package iface

View File

@ -1,4 +1,4 @@
//go:build linux && !android //go:build (linux && !android) || freebsd
package iface package iface
@ -6,11 +6,9 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"os"
"github.com/pion/transport/v3" "github.com/pion/transport/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"github.com/netbirdio/netbird/iface/bind" "github.com/netbirdio/netbird/iface/bind"
"github.com/netbirdio/netbird/sharedsock" "github.com/netbirdio/netbird/sharedsock"
@ -32,6 +30,8 @@ type tunKernelDevice struct {
} }
func newTunDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) wgTunDevice { func newTunDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) wgTunDevice {
checkUser()
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
return &tunKernelDevice{ return &tunKernelDevice{
ctx: ctx, ctx: ctx,
@ -48,53 +48,29 @@ func newTunDevice(name string, address WGAddress, wgPort int, key string, mtu in
func (t *tunKernelDevice) Create() (wgConfigurer, error) { func (t *tunKernelDevice) Create() (wgConfigurer, error) {
link := newWGLink(t.name) link := newWGLink(t.name)
// check if interface exists if err := link.recreate(); err != nil {
l, err := netlink.LinkByName(t.name) return nil, fmt.Errorf("recreate: %w", err)
if err != nil {
switch err.(type) {
case netlink.LinkNotFoundError:
break
default:
return nil, err
}
}
// remove if interface exists
if l != nil {
err = netlink.LinkDel(link)
if err != nil {
return nil, err
}
}
log.Debugf("adding device: %s", t.name)
err = netlink.LinkAdd(link)
if os.IsExist(err) {
log.Infof("interface %s already exists. Will reuse.", t.name)
} else if err != nil {
return nil, err
} }
t.link = link t.link = link
err = t.assignAddr() if err := t.assignAddr(); err != nil {
if err != nil { return nil, fmt.Errorf("assign addr: %w", err)
return nil, err
} }
// todo do a discovery // TODO: do a MTU discovery
log.Debugf("setting MTU: %d interface: %s", t.mtu, t.name) log.Debugf("setting MTU: %d interface: %s", t.mtu, t.name)
err = netlink.LinkSetMTU(link, t.mtu)
if err != nil { if err := link.setMTU(t.mtu); err != nil {
log.Errorf("error setting MTU on interface: %s", t.name) return nil, fmt.Errorf("set mtu: %w", err)
return nil, err
} }
configurer := newWGConfigurer(t.name) configurer := newWGConfigurer(t.name)
err = configurer.configureInterface(t.key, t.wgPort)
if err != nil { if err := configurer.configureInterface(t.key, t.wgPort); err != nil {
return nil, err return nil, err
} }
return configurer, nil return configurer, nil
} }
@ -108,9 +84,10 @@ func (t *tunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
} }
log.Debugf("bringing up interface: %s", t.name) log.Debugf("bringing up interface: %s", t.name)
err := netlink.LinkSetUp(t.link)
if err != nil { if err := t.link.up(); err != nil {
log.Errorf("error bringing up interface: %s", t.name) log.Errorf("error bringing up interface: %s", t.name)
return nil, err return nil, err
} }
@ -178,32 +155,5 @@ func (t *tunKernelDevice) Wrapper() *DeviceWrapper {
// assignAddr Adds IP address to the tunnel interface // assignAddr Adds IP address to the tunnel interface
func (t *tunKernelDevice) assignAddr() error { func (t *tunKernelDevice) assignAddr() error {
link := newWGLink(t.name) return t.link.assignAddr(t.address)
//delete existing addresses
list, err := netlink.AddrList(link, 0)
if err != nil {
return err
}
if len(list) > 0 {
for _, a := range list {
addr := a
err = netlink.AddrDel(link, &addr)
if err != nil {
return err
}
}
}
log.Debugf("adding address %s to interface: %s", t.address.String(), t.name)
addr, _ := netlink.ParseAddr(t.address.String())
err = netlink.AddrAdd(link, addr)
if os.IsExist(err) {
log.Infof("interface %s already has the address: %s", t.name, t.address.String())
} else if err != nil {
return err
}
// On linux, the link must be brought up
err = netlink.LinkSetUp(link)
return err
} }

80
iface/tun_link_freebsd.go Normal file
View File

@ -0,0 +1,80 @@
package iface
import (
"fmt"
"github.com/netbirdio/netbird/iface/freebsd"
log "github.com/sirupsen/logrus"
)
type wgLink struct {
name string
link *freebsd.Link
}
func newWGLink(name string) *wgLink {
link := freebsd.NewLink(name)
return &wgLink{
name: name,
link: link,
}
}
// Type returns the interface type
func (l *wgLink) Type() string {
return "wireguard"
}
// Close deletes the link interface
func (l *wgLink) Close() error {
return l.link.Del()
}
func (l *wgLink) recreate() error {
if err := l.link.Recreate(); err != nil {
return fmt.Errorf("recreate: %w", err)
}
return nil
}
func (l *wgLink) setMTU(mtu int) error {
if err := l.link.SetMTU(mtu); err != nil {
return fmt.Errorf("set mtu: %w", err)
}
return nil
}
func (l *wgLink) up() error {
if err := l.link.Up(); err != nil {
return fmt.Errorf("up: %w", err)
}
return nil
}
func (l *wgLink) assignAddr(address WGAddress) error {
link, err := freebsd.LinkByName(l.name)
if err != nil {
return fmt.Errorf("link by name: %w", err)
}
ip := address.IP.String()
mask := "0x" + address.Network.Mask.String()
log.Infof("assign addr %s mask %s to %s interface", ip, mask, l.name)
err = link.AssignAddr(ip, mask)
if err != nil {
return fmt.Errorf("assign addr: %w", err)
}
err = link.Up()
if err != nil {
return fmt.Errorf("up: %w", err)
}
return nil
}

View File

@ -2,7 +2,13 @@
package iface package iface
import "github.com/vishvananda/netlink" import (
"fmt"
"os"
log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
)
type wgLink struct { type wgLink struct {
attrs *netlink.LinkAttrs attrs *netlink.LinkAttrs
@ -31,3 +37,97 @@ func (l *wgLink) Type() string {
func (l *wgLink) Close() error { func (l *wgLink) Close() error {
return netlink.LinkDel(l) return netlink.LinkDel(l)
} }
func (l *wgLink) recreate() error {
name := l.attrs.Name
// check if interface exists
link, err := netlink.LinkByName(name)
if err != nil {
switch err.(type) {
case netlink.LinkNotFoundError:
break
default:
return fmt.Errorf("link by name: %w", err)
}
}
// remove if interface exists
if link != nil {
err = netlink.LinkDel(l)
if err != nil {
return err
}
}
log.Debugf("adding device: %s", name)
err = netlink.LinkAdd(l)
if os.IsExist(err) {
log.Infof("interface %s already exists. Will reuse.", name)
} else if err != nil {
return fmt.Errorf("link add: %w", err)
}
return nil
}
func (l *wgLink) setMTU(mtu int) error {
if err := netlink.LinkSetMTU(l, mtu); err != nil {
log.Errorf("error setting MTU on interface: %s", l.attrs.Name)
return fmt.Errorf("link set mtu: %w", err)
}
return nil
}
func (l *wgLink) up() error {
if err := netlink.LinkSetUp(l); err != nil {
log.Errorf("error bringing up interface: %s", l.attrs.Name)
return fmt.Errorf("link setup: %w", err)
}
return nil
}
func (l *wgLink) assignAddr(address WGAddress) error {
//delete existing addresses
list, err := netlink.AddrList(l, 0)
if err != nil {
return fmt.Errorf("list addr: %w", err)
}
if len(list) > 0 {
for _, a := range list {
addr := a
err = netlink.AddrDel(l, &addr)
if err != nil {
return fmt.Errorf("del addr: %w", err)
}
}
}
name := l.attrs.Name
addrStr := address.String()
log.Debugf("adding address %s to interface: %s", addrStr, name)
addr, err := netlink.ParseAddr(addrStr)
if err != nil {
return fmt.Errorf("parse addr: %w", err)
}
err = netlink.AddrAdd(l, addr)
if os.IsExist(err) {
log.Infof("interface %s already has the address: %s", name, addrStr)
} else if err != nil {
return fmt.Errorf("add addr: %w", err)
}
// On linux, the link must be brought up
if err := netlink.LinkSetUp(l); err != nil {
return fmt.Errorf("link setup: %w", err)
}
return nil
}

View File

@ -1,14 +1,14 @@
//go:build linux && !android //go:build (linux && !android) || freebsd
package iface package iface
import ( import (
"fmt" "fmt"
"os" "os"
"runtime"
"github.com/pion/transport/v3" "github.com/pion/transport/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
@ -31,6 +31,9 @@ type tunUSPDevice struct {
func newTunUSPDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net) wgTunDevice { func newTunUSPDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net) wgTunDevice {
log.Infof("using userspace bind mode") log.Infof("using userspace bind mode")
checkUser()
return &tunUSPDevice{ return &tunUSPDevice{
name: name, name: name,
address: address, address: address,
@ -129,30 +132,14 @@ func (t *tunUSPDevice) Wrapper() *DeviceWrapper {
func (t *tunUSPDevice) assignAddr() error { func (t *tunUSPDevice) assignAddr() error {
link := newWGLink(t.name) link := newWGLink(t.name)
//delete existing addresses return link.assignAddr(t.address)
list, err := netlink.AddrList(link, 0) }
if err != nil {
return err func checkUser() {
} if runtime.GOOS == "freebsd" {
if len(list) > 0 { euid := os.Geteuid()
for _, a := range list { if euid != 0 {
addr := a log.Warn("newTunUSPDevice: on netbird must run as root to be able to assign address to the tun interface with ifconfig")
err = netlink.AddrDel(link, &addr) }
if err != nil { }
return err
}
}
}
log.Debugf("adding address %s to interface: %s", t.address.String(), t.name)
addr, _ := netlink.ParseAddr(t.address.String())
err = netlink.AddrAdd(link, addr)
if os.IsExist(err) {
log.Infof("interface %s already has the address: %s", t.name, t.address.String())
} else if err != nil {
return err
}
// On linux, the link must be brought up
err = netlink.LinkSetUp(link)
return err
} }

Some files were not shown because too many files have changed in this diff Show More