mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-23 23:29:15 +01:00
Release 0.28.0 (#2092)
* compile client under freebsd (#1620) Compile netbird client under freebsd and now support netstack and userspace modes. Refactoring linux specific code to share same code with FreeBSD, move to *_unix.go files. Not implemented yet: Kernel mode not supported DNS probably does not work yet Routing also probably does not work yet SSH support did not tested yet Lack of test environment for freebsd (dedicated VM for github runners under FreeBSD required) Lack of tests for freebsd specific code info reporting need to review and also implement, for example OS reported as GENERIC instead of FreeBSD (lack of FreeBSD icon in management interface) Lack of proper client setup under FreeBSD Lack of FreeBSD port/package * Add DNS routes (#1943) Given domains are resolved periodically and resolved IPs are replaced with the new ones. Unless the flag keep_route is set to true, then only new ones are added. This option is helpful if there are long-running connections that might still point to old IP addresses from changed DNS records. * Add process posture check (#1693) Introduces a process posture check to validate the existence and active status of specific binaries on peer systems. The check ensures that files are present at specified paths, and that corresponding processes are running. This check supports Linux, Windows, and macOS systems. Co-authored-by: Evgenii <mail@skillcoder.com> Co-authored-by: Pascal Fischer <pascal@netbird.io> Co-authored-by: Zoltan Papp <zoltan.pmail@gmail.com> Co-authored-by: Viktor Liu <17948409+lixmal@users.noreply.github.com> Co-authored-by: Bethuel Mmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
parent
95299be52d
commit
4fec709bb1
8
.github/workflows/golang-test-linux.yml
vendored
8
.github/workflows/golang-test-linux.yml
vendored
@ -86,7 +86,10 @@ jobs:
|
||||
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
|
||||
|
||||
- name: Generate RouteManager Test bin
|
||||
run: CGO_ENABLED=1 go test -c -o routemanager-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/...
|
||||
run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager
|
||||
|
||||
- name: Generate SystemOps Test bin
|
||||
run: CGO_ENABLED=1 go test -c -o systemops-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/systemops
|
||||
|
||||
- name: Generate nftables Manager Test bin
|
||||
run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/...
|
||||
@ -108,6 +111,9 @@ jobs:
|
||||
- name: Run RouteManager tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
- name: Run SystemOps tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager/systemops --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/systemops-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
- name: Run nftables Manager tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/firewall --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/nftablesmanager-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
|
@ -36,6 +36,7 @@ const (
|
||||
disableAutoConnectFlag = "disable-auto-connect"
|
||||
serverSSHAllowedFlag = "allow-server-ssh"
|
||||
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
||||
dnsRouteIntervalFlag = "dns-router-interval"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -68,7 +69,9 @@ var (
|
||||
autoConnectDisabled bool
|
||||
extraIFaceBlackList []string
|
||||
anonymizeFlag bool
|
||||
rootCmd = &cobra.Command{
|
||||
dnsRouteInterval time.Duration
|
||||
|
||||
rootCmd = &cobra.Command{
|
||||
Use: "netbird",
|
||||
Short: "",
|
||||
Long: "",
|
||||
|
@ -2,6 +2,7 @@ package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
@ -66,18 +67,60 @@ func routesList(cmd *cobra.Command, _ []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
cmd.Println("Available Routes:")
|
||||
for _, route := range resp.Routes {
|
||||
selectedStatus := "Not Selected"
|
||||
if route.GetSelected() {
|
||||
selectedStatus = "Selected"
|
||||
}
|
||||
cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetNetwork(), selectedStatus)
|
||||
}
|
||||
printRoutes(cmd, resp)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func printRoutes(cmd *cobra.Command, resp *proto.ListRoutesResponse) {
|
||||
cmd.Println("Available Routes:")
|
||||
for _, route := range resp.Routes {
|
||||
printRoute(cmd, route)
|
||||
}
|
||||
}
|
||||
|
||||
func printRoute(cmd *cobra.Command, route *proto.Route) {
|
||||
selectedStatus := getSelectedStatus(route)
|
||||
domains := route.GetDomains()
|
||||
|
||||
if len(domains) > 0 {
|
||||
printDomainRoute(cmd, route, domains, selectedStatus)
|
||||
} else {
|
||||
printNetworkRoute(cmd, route, selectedStatus)
|
||||
}
|
||||
}
|
||||
|
||||
func getSelectedStatus(route *proto.Route) string {
|
||||
if route.GetSelected() {
|
||||
return "Selected"
|
||||
}
|
||||
return "Not Selected"
|
||||
}
|
||||
|
||||
func printDomainRoute(cmd *cobra.Command, route *proto.Route, domains []string, selectedStatus string) {
|
||||
cmd.Printf("\n - ID: %s\n Domains: %s\n Status: %s\n", route.GetID(), strings.Join(domains, ", "), selectedStatus)
|
||||
resolvedIPs := route.GetResolvedIPs()
|
||||
|
||||
if len(resolvedIPs) > 0 {
|
||||
printResolvedIPs(cmd, domains, resolvedIPs)
|
||||
} else {
|
||||
cmd.Printf(" Resolved IPs: -\n")
|
||||
}
|
||||
}
|
||||
|
||||
func printNetworkRoute(cmd *cobra.Command, route *proto.Route, selectedStatus string) {
|
||||
cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetNetwork(), selectedStatus)
|
||||
}
|
||||
|
||||
func printResolvedIPs(cmd *cobra.Command, domains []string, resolvedIPs map[string]*proto.IPList) {
|
||||
cmd.Printf(" Resolved IPs:\n")
|
||||
for _, domain := range domains {
|
||||
if ipList, exists := resolvedIPs[domain]; exists {
|
||||
cmd.Printf(" [%s]: %s\n", domain, strings.Join(ipList.GetIps(), ", "))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func routesSelect(cmd *cobra.Command, args []string) error {
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
|
@ -807,11 +807,7 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
|
||||
}
|
||||
|
||||
for i, route := range peer.Routes {
|
||||
prefix, err := netip.ParsePrefix(route)
|
||||
if err == nil {
|
||||
ip := a.AnonymizeIPString(prefix.Addr().String())
|
||||
peer.Routes[i] = fmt.Sprintf("%s/%d", ip, prefix.Bits())
|
||||
}
|
||||
peer.Routes[i] = anonymizeRoute(a, route)
|
||||
}
|
||||
}
|
||||
|
||||
@ -847,12 +843,21 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview)
|
||||
}
|
||||
|
||||
for i, route := range overview.Routes {
|
||||
prefix, err := netip.ParsePrefix(route)
|
||||
if err == nil {
|
||||
ip := a.AnonymizeIPString(prefix.Addr().String())
|
||||
overview.Routes[i] = fmt.Sprintf("%s/%d", ip, prefix.Bits())
|
||||
}
|
||||
overview.Routes[i] = anonymizeRoute(a, route)
|
||||
}
|
||||
|
||||
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
|
||||
}
|
||||
|
||||
func anonymizeRoute(a *anonymize.Anonymizer, route string) string {
|
||||
prefix, err := netip.ParsePrefix(route)
|
||||
if err == nil {
|
||||
ip := a.AnonymizeIPString(prefix.Addr().String())
|
||||
return fmt.Sprintf("%s/%d", ip, prefix.Bits())
|
||||
}
|
||||
domains := strings.Split(route, ", ")
|
||||
for i, domain := range domains {
|
||||
domains[i] = a.AnonymizeDomain(domain)
|
||||
}
|
||||
return strings.Join(domains, ", ")
|
||||
}
|
||||
|
@ -7,11 +7,13 @@ import (
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
@ -42,6 +44,7 @@ func init() {
|
||||
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
|
||||
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", false, "Enable network monitoring")
|
||||
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
|
||||
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
|
||||
}
|
||||
|
||||
func upFunc(cmd *cobra.Command, args []string) error {
|
||||
@ -137,6 +140,10 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
}
|
||||
}
|
||||
|
||||
if cmd.Flag(dnsRouteIntervalFlag).Changed {
|
||||
ic.DNSRouteInterval = &dnsRouteInterval
|
||||
}
|
||||
|
||||
config, err := internal.UpdateOrCreateConfig(ic)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get config file: %v", err)
|
||||
@ -237,6 +244,10 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
loginRequest.NetworkMonitor = &networkMonitor
|
||||
}
|
||||
|
||||
if cmd.Flag(dnsRouteIntervalFlag).Changed {
|
||||
loginRequest.DnsRouteInterval = durationpb.New(dnsRouteInterval)
|
||||
}
|
||||
|
||||
var loginErr error
|
||||
|
||||
var loginResp *proto.LoginResponse
|
||||
|
30
client/errors/errors.go
Normal file
30
client/errors/errors.go
Normal file
@ -0,0 +1,30 @@
|
||||
package errors
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
)
|
||||
|
||||
func formatError(es []error) string {
|
||||
if len(es) == 0 {
|
||||
return fmt.Sprintf("0 error occurred:\n\t* %s", es[0])
|
||||
}
|
||||
|
||||
points := make([]string, len(es))
|
||||
for i, err := range es {
|
||||
points[i] = fmt.Sprintf("* %s", err)
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"%d errors occurred:\n\t%s",
|
||||
len(es), strings.Join(points, "\n\t"))
|
||||
}
|
||||
|
||||
func FormatErrorOrNil(err *multierror.Error) error {
|
||||
if err != nil {
|
||||
err.ErrorFormat = formatError
|
||||
}
|
||||
return err.ErrorOrNil()
|
||||
}
|
@ -7,12 +7,14 @@ import (
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
mgm "github.com/netbirdio/netbird/management/client"
|
||||
@ -53,6 +55,7 @@ type ConfigInput struct {
|
||||
NetworkMonitor *bool
|
||||
DisableAutoConnect *bool
|
||||
ExtraIFaceBlackList []string
|
||||
DNSRouteInterval *time.Duration
|
||||
}
|
||||
|
||||
// Config Configuration type
|
||||
@ -95,6 +98,9 @@ type Config struct {
|
||||
// DisableAutoConnect determines whether the client should not start with the service
|
||||
// it's set to false by default due to backwards compatibility
|
||||
DisableAutoConnect bool
|
||||
|
||||
// DNSRouteInterval is the interval in which the DNS routes are updated
|
||||
DNSRouteInterval time.Duration
|
||||
}
|
||||
|
||||
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
||||
@ -357,6 +363,18 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.DNSRouteInterval != nil && *input.DNSRouteInterval != config.DNSRouteInterval {
|
||||
log.Infof("updating DNS route interval to %s (old value %s)",
|
||||
input.DNSRouteInterval.String(), config.DNSRouteInterval.String())
|
||||
config.DNSRouteInterval = *input.DNSRouteInterval
|
||||
updated = true
|
||||
} else if config.DNSRouteInterval == 0 {
|
||||
config.DNSRouteInterval = dynamic.DefaultInterval
|
||||
log.Infof("using default DNS route interval %s", config.DNSRouteInterval)
|
||||
updated = true
|
||||
|
||||
}
|
||||
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
|
@ -252,8 +252,10 @@ func (c *ConnectClient) run(
|
||||
return wrapErr(err)
|
||||
}
|
||||
|
||||
checks := loginResp.GetChecks()
|
||||
|
||||
c.engineMutex.Lock()
|
||||
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, c.statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe)
|
||||
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, c.statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe, checks)
|
||||
c.engineMutex.Unlock()
|
||||
|
||||
err = c.engine.Start()
|
||||
@ -321,6 +323,7 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
|
||||
RosenpassEnabled: config.RosenpassEnabled,
|
||||
RosenpassPermissive: config.RosenpassPermissive,
|
||||
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
|
||||
DNSRouteInterval: config.DNSRouteInterval,
|
||||
}
|
||||
|
||||
if config.PreSharedKey != "" {
|
||||
|
6
client/internal/dns/consts_freebsd.go
Normal file
6
client/internal/dns/consts_freebsd.go
Normal file
@ -0,0 +1,6 @@
|
||||
package dns
|
||||
|
||||
const (
|
||||
fileUncleanShutdownResolvConfLocation = "/var/db/netbird/resolv.conf"
|
||||
fileUncleanShutdownManagerTypeLocation = "/var/db/netbird/manager"
|
||||
)
|
8
client/internal/dns/consts_linux.go
Normal file
8
client/internal/dns/consts_linux.go
Normal file
@ -0,0 +1,8 @@
|
||||
//go:build !android
|
||||
|
||||
package dns
|
||||
|
||||
const (
|
||||
fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf"
|
||||
fileUncleanShutdownManagerTypeLocation = "/var/lib/netbird/manager"
|
||||
)
|
@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
||||
@ -108,7 +108,7 @@ func getOSDNSManagerType() (osManagerType, error) {
|
||||
if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
|
||||
return networkManager, nil
|
||||
}
|
||||
if strings.Contains(text, "systemd-resolved") && isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) {
|
||||
if strings.Contains(text, "systemd-resolved") && isSystemdResolvedRunning() {
|
||||
if checkStub() {
|
||||
return systemdManager, nil
|
||||
} else {
|
||||
@ -116,16 +116,10 @@ func getOSDNSManagerType() (osManagerType, error) {
|
||||
}
|
||||
}
|
||||
if strings.Contains(text, "resolvconf") {
|
||||
if isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) {
|
||||
var value string
|
||||
err = getSystemdDbusProperty(systemdDbusResolvConfModeProperty, &value)
|
||||
if err == nil {
|
||||
if value == systemdDbusResolvConfModeForeign {
|
||||
return systemdManager, nil
|
||||
}
|
||||
}
|
||||
log.Errorf("got an error while checking systemd resolv conf mode, error: %s", err)
|
||||
if isSystemdResolveConfMode() {
|
||||
return systemdManager, nil
|
||||
}
|
||||
|
||||
return resolvConfManager, nil
|
||||
}
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
@ -39,6 +39,10 @@ func (w *mocWGIface) Address() iface.WGAddress {
|
||||
}
|
||||
}
|
||||
|
||||
func (w *mocWGIface) ToInterface() *net.Interface {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (w *mocWGIface) GetFilter() iface.PacketFilter {
|
||||
return w.filter
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
20
client/internal/dns/systemd_freebsd.go
Normal file
20
client/internal/dns/systemd_freebsd.go
Normal file
@ -0,0 +1,20 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
var errNotImplemented = errors.New("not implemented")
|
||||
|
||||
func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) {
|
||||
return nil, fmt.Errorf("systemd dns management: %w on freebsd", errNotImplemented)
|
||||
}
|
||||
|
||||
func isSystemdResolvedRunning() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func isSystemdResolveConfMode() bool {
|
||||
return false
|
||||
}
|
@ -242,3 +242,25 @@ func getSystemdDbusProperty(property string, store any) error {
|
||||
|
||||
return v.Store(store)
|
||||
}
|
||||
|
||||
func isSystemdResolvedRunning() bool {
|
||||
return isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode)
|
||||
}
|
||||
|
||||
func isSystemdResolveConfMode() bool {
|
||||
if !isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) {
|
||||
return false
|
||||
}
|
||||
|
||||
var value string
|
||||
if err := getSystemdDbusProperty(systemdDbusResolvConfModeProperty, &value); err != nil {
|
||||
log.Errorf("got an error while checking systemd resolv conf mode, error: %s", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if value == systemdDbusResolvConfModeForeign {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
||||
@ -14,11 +14,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf"
|
||||
fileUncleanShutdownManagerTypeLocation = "/var/lib/netbird/manager"
|
||||
)
|
||||
|
||||
func CheckUncleanShutdown(wgIface string) error {
|
||||
if _, err := os.Stat(fileUncleanShutdownResolvConfLocation); err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
@ -2,12 +2,17 @@
|
||||
|
||||
package dns
|
||||
|
||||
import "github.com/netbirdio/netbird/iface"
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
// WGIface defines subset methods of interface required for manager
|
||||
type WGIface interface {
|
||||
Name() string
|
||||
Address() iface.WGAddress
|
||||
ToInterface() *net.Interface
|
||||
IsUserspaceBind() bool
|
||||
GetFilter() iface.PacketFilter
|
||||
GetDevice() *iface.DeviceWrapper
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@ -30,10 +31,12 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/wgproxy"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/iface/bind"
|
||||
mgm "github.com/netbirdio/netbird/management/client"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
signal "github.com/netbirdio/netbird/signal/client"
|
||||
@ -89,6 +92,8 @@ type EngineConfig struct {
|
||||
RosenpassPermissive bool
|
||||
|
||||
ServerSSHAllowed bool
|
||||
|
||||
DNSRouteInterval time.Duration
|
||||
}
|
||||
|
||||
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
|
||||
@ -154,6 +159,9 @@ type Engine struct {
|
||||
wgProbe *Probe
|
||||
|
||||
wgConnWorker sync.WaitGroup
|
||||
|
||||
// checks are the client-applied posture checks that need to be evaluated on the client
|
||||
checks []*mgmProto.Checks
|
||||
}
|
||||
|
||||
// Peer is an instance of the Connection Peer
|
||||
@ -171,6 +179,7 @@ func NewEngine(
|
||||
config *EngineConfig,
|
||||
mobileDep MobileDependency,
|
||||
statusRecorder *peer.Status,
|
||||
checks []*mgmProto.Checks,
|
||||
) *Engine {
|
||||
return NewEngineWithProbes(
|
||||
clientCtx,
|
||||
@ -184,6 +193,7 @@ func NewEngine(
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
checks,
|
||||
)
|
||||
}
|
||||
|
||||
@ -200,6 +210,7 @@ func NewEngineWithProbes(
|
||||
signalProbe *Probe,
|
||||
relayProbe *Probe,
|
||||
wgProbe *Probe,
|
||||
checks []*mgmProto.Checks,
|
||||
) *Engine {
|
||||
|
||||
return &Engine{
|
||||
@ -220,6 +231,7 @@ func NewEngineWithProbes(
|
||||
signalProbe: signalProbe,
|
||||
relayProbe: relayProbe,
|
||||
wgProbe: wgProbe,
|
||||
checks: checks,
|
||||
}
|
||||
}
|
||||
|
||||
@ -301,7 +313,7 @@ func (e *Engine) Start() error {
|
||||
}
|
||||
e.dnsServer = dnsServer
|
||||
|
||||
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, initialRoutes)
|
||||
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, initialRoutes)
|
||||
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
|
||||
if err != nil {
|
||||
log.Errorf("Failed to initialize route manager: %s", err)
|
||||
@ -527,6 +539,10 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
// todo update signal
|
||||
}
|
||||
|
||||
if err := e.updateChecksIfNew(update.Checks); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if update.GetNetworkMap() != nil {
|
||||
// only apply new changes and ignore old ones
|
||||
err := e.updateNetworkMap(update.GetNetworkMap())
|
||||
@ -534,7 +550,27 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateChecksIfNew updates checks if there are changes and sync new meta with management
|
||||
func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
||||
// if checks are equal, we skip the update
|
||||
if isChecksEqual(e.checks, checks) {
|
||||
return nil
|
||||
}
|
||||
e.checks = checks
|
||||
|
||||
info, err := system.GetInfoWithChecks(e.ctx, checks)
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system info with checks: %v", err)
|
||||
info = system.GetInfo(e.ctx)
|
||||
}
|
||||
|
||||
if err := e.mgmClient.SyncMeta(info); err != nil {
|
||||
log.Errorf("could not sync meta: error %s", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -550,8 +586,8 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
||||
} else {
|
||||
|
||||
if sshConf.GetSshEnabled() {
|
||||
if runtime.GOOS == "windows" {
|
||||
log.Warnf("running SSH server on Windows is not supported")
|
||||
if runtime.GOOS == "windows" || runtime.GOOS == "freebsd" {
|
||||
log.Warnf("running SSH server on %s is not supported", runtime.GOOS)
|
||||
return nil
|
||||
}
|
||||
// start SSH server if it wasn't running
|
||||
@ -624,7 +660,14 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
// E.g. when a new peer has been registered and we are allowed to connect to it.
|
||||
func (e *Engine) receiveManagementEvents() {
|
||||
go func() {
|
||||
err := e.mgmClient.Sync(e.ctx, e.handleSync)
|
||||
info, err := system.GetInfoWithChecks(e.ctx, e.checks)
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system info with checks: %v", err)
|
||||
info = system.GetInfo(e.ctx)
|
||||
}
|
||||
|
||||
// err = e.mgmClient.Sync(info, e.handleSync)
|
||||
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
|
||||
if err != nil {
|
||||
// happens if management is unavailable for a long time.
|
||||
// We want to cancel the operation of the whole client
|
||||
@ -772,15 +815,24 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
|
||||
routes := make([]*route.Route, 0)
|
||||
for _, protoRoute := range protoRoutes {
|
||||
_, prefix, _ := route.ParseNetwork(protoRoute.Network)
|
||||
var prefix netip.Prefix
|
||||
if len(protoRoute.Domains) == 0 {
|
||||
var err error
|
||||
if prefix, err = netip.ParsePrefix(protoRoute.Network); err != nil {
|
||||
log.Errorf("Failed to parse prefix %s: %v", protoRoute.Network, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
convertedRoute := &route.Route{
|
||||
ID: route.ID(protoRoute.ID),
|
||||
Network: prefix,
|
||||
Domains: domain.FromPunycodeList(protoRoute.Domains),
|
||||
NetID: route.NetID(protoRoute.NetID),
|
||||
NetworkType: route.NetworkType(protoRoute.NetworkType),
|
||||
Peer: protoRoute.Peer,
|
||||
Metric: int(protoRoute.Metric),
|
||||
Masquerade: protoRoute.Masquerade,
|
||||
KeepRoute: protoRoute.KeepRoute,
|
||||
}
|
||||
routes = append(routes, convertedRoute)
|
||||
}
|
||||
@ -1204,7 +1256,8 @@ func (e *Engine) close() {
|
||||
}
|
||||
|
||||
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
|
||||
netMap, err := e.mgmClient.GetNetworkMap()
|
||||
info := system.GetInfo(e.ctx)
|
||||
netMap, err := e.mgmClient.GetNetworkMap(info)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@ -1430,3 +1483,10 @@ func (e *Engine) startNetworkMonitor() {
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// isChecksEqual checks if two slices of checks are equal.
|
||||
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
|
||||
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {
|
||||
return slices.Equal(checks.Files, oChecks.Files)
|
||||
})
|
||||
}
|
||||
|
@ -78,7 +78,7 @@ func TestEngine_SSH(t *testing.T) {
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
ServerSSHAllowed: true,
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||
@ -212,7 +212,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
WgAddr: "100.64.0.1/24",
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@ -221,7 +221,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder, nil)
|
||||
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, nil)
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||
}
|
||||
@ -394,7 +394,7 @@ func TestEngine_Sync(t *testing.T) {
|
||||
// feed updates to Engine via mocked Management client
|
||||
updates := make(chan *mgmtProto.SyncResponse)
|
||||
defer close(updates)
|
||||
syncFunc := func(ctx context.Context, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
|
||||
syncFunc := func(ctx context.Context, info *system.Info, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
|
||||
for msg := range updates {
|
||||
err := msgHandler(msg)
|
||||
if err != nil {
|
||||
@ -409,7 +409,7 @@ func TestEngine_Sync(t *testing.T) {
|
||||
WgAddr: "100.64.0.1/24",
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||
engine.ctx = ctx
|
||||
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
@ -568,7 +568,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
WgAddr: wgAddr,
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||
engine.ctx = ctx
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
@ -738,7 +738,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
WgAddr: wgAddr,
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||
engine.ctx = ctx
|
||||
|
||||
newNet, err := stdnet.NewNet()
|
||||
@ -1009,7 +1009,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
||||
WgPort: wgPort,
|
||||
}
|
||||
|
||||
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, conf, MobileDependency{}, peer.NewRecorder("https://mgm")), nil
|
||||
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil
|
||||
e.ctx = ctx
|
||||
return e, err
|
||||
}
|
||||
|
@ -5,8 +5,6 @@ package networkmonitor
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
@ -14,10 +12,10 @@ import (
|
||||
"golang.org/x/net/route"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
)
|
||||
|
||||
func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interface, nexthopv6 netip.Addr, intfv6 *net.Interface, callback func()) error {
|
||||
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error {
|
||||
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open routing socket: %v", err)
|
||||
@ -58,7 +56,7 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
|
||||
if msg.Flags&unix.IFF_UP != 0 {
|
||||
continue
|
||||
}
|
||||
if (intfv4 == nil || ifinfo.Index != intfv4.Index) && (intfv6 == nil || ifinfo.Index != intfv6.Index) {
|
||||
if (nexthopv4.Intf == nil || ifinfo.Index != nexthopv4.Intf.Index) && (nexthopv6.Intf == nil || ifinfo.Index != nexthopv6.Intf.Index) {
|
||||
continue
|
||||
}
|
||||
|
||||
@ -86,7 +84,7 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
|
||||
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
|
||||
go callback()
|
||||
case unix.RTM_DELETE:
|
||||
if intfv4 != nil && route.Gw.Compare(nexthopv4) == 0 || intfv6 != nil && route.Gw.Compare(nexthopv6) == 0 {
|
||||
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
|
||||
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
|
||||
go callback()
|
||||
}
|
||||
@ -114,7 +112,7 @@ func parseInterfaceMessage(buf []byte) (*route.InterfaceMessage, error) {
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func parseRouteMessage(buf []byte) (*routemanager.Route, error) {
|
||||
func parseRouteMessage(buf []byte) (*systemops.Route, error) {
|
||||
msgs, err := route.ParseRIB(route.RIBTypeRoute, buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse RIB: %v", err)
|
||||
@ -129,5 +127,5 @@ func parseRouteMessage(buf []byte) (*routemanager.Route, error) {
|
||||
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
|
||||
}
|
||||
|
||||
return routemanager.MsgToRoute(msg)
|
||||
return systemops.MsgToRoute(msg)
|
||||
}
|
||||
|
@ -6,14 +6,13 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime/debug"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
)
|
||||
|
||||
// Start begins monitoring network changes. When a change is detected, it calls the callback asynchronously and returns.
|
||||
@ -29,23 +28,22 @@ func (nw *NetworkMonitor) Start(ctx context.Context, callback func()) (err error
|
||||
nw.wg.Add(1)
|
||||
defer nw.wg.Done()
|
||||
|
||||
var nexthop4, nexthop6 netip.Addr
|
||||
var intf4, intf6 *net.Interface
|
||||
var nexthop4, nexthop6 systemops.Nexthop
|
||||
|
||||
operation := func() error {
|
||||
var errv4, errv6 error
|
||||
nexthop4, intf4, errv4 = routemanager.GetNextHop(netip.IPv4Unspecified())
|
||||
nexthop6, intf6, errv6 = routemanager.GetNextHop(netip.IPv6Unspecified())
|
||||
nexthop4, errv4 = systemops.GetNextHop(netip.IPv4Unspecified())
|
||||
nexthop6, errv6 = systemops.GetNextHop(netip.IPv6Unspecified())
|
||||
|
||||
if errv4 != nil && errv6 != nil {
|
||||
return errors.New("failed to get default next hops")
|
||||
}
|
||||
|
||||
if errv4 == nil {
|
||||
log.Debugf("Network monitor: IPv4 default route: %s, interface: %s", nexthop4, intf4.Name)
|
||||
log.Debugf("Network monitor: IPv4 default route: %s, interface: %s", nexthop4.IP, nexthop4.Intf.Name)
|
||||
}
|
||||
if errv6 == nil {
|
||||
log.Debugf("Network monitor: IPv6 default route: %s, interface: %s", nexthop6, intf6.Name)
|
||||
log.Debugf("Network monitor: IPv6 default route: %s, interface: %s", nexthop6.IP, nexthop6.Intf.Name)
|
||||
}
|
||||
|
||||
// continue if either route was found
|
||||
@ -65,7 +63,7 @@ func (nw *NetworkMonitor) Start(ctx context.Context, callback func()) (err error
|
||||
}
|
||||
}()
|
||||
|
||||
if err := checkChange(ctx, nexthop4, intf4, nexthop6, intf6, callback); err != nil {
|
||||
if err := checkChange(ctx, nexthop4, nexthop6, callback); err != nil {
|
||||
return fmt.Errorf("check change: %w", err)
|
||||
}
|
||||
|
||||
|
@ -6,16 +6,16 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/vishvananda/netlink"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
)
|
||||
|
||||
func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interface, nexthop6 netip.Addr, intfv6 *net.Interface, callback func()) error {
|
||||
if intfv4 == nil && intfv6 == nil {
|
||||
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error {
|
||||
if nexthopv4.Intf == nil && nexthopv6.Intf == nil {
|
||||
return errors.New("no interfaces available")
|
||||
}
|
||||
|
||||
@ -40,7 +40,7 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
|
||||
|
||||
// handle interface state changes
|
||||
case update := <-linkChan:
|
||||
if (intfv4 == nil || update.Index != int32(intfv4.Index)) && (intfv6 == nil || update.Index != int32(intfv6.Index)) {
|
||||
if (nexthopv4.Intf == nil || update.Index != int32(nexthopv4.Intf.Index)) && (nexthopv6.Intf == nil || update.Index != int32(nexthopv6.Intf.Index)) {
|
||||
continue
|
||||
}
|
||||
|
||||
@ -70,7 +70,7 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
|
||||
go callback()
|
||||
return nil
|
||||
case syscall.RTM_DELROUTE:
|
||||
if intfv4 != nil && route.Gw.Equal(nexthopv4.AsSlice()) || intfv6 != nil && route.Gw.Equal(nexthop6.AsSlice()) {
|
||||
if nexthopv4.Intf != nil && route.Gw.Equal(nexthopv4.IP.AsSlice()) || nexthopv6.Intf != nil && route.Gw.Equal(nexthopv6.IP.AsSlice()) {
|
||||
log.Infof("Network monitor: default route removed: via %s, interface %d", route.Gw, route.LinkIndex)
|
||||
go callback()
|
||||
return nil
|
||||
|
@ -9,7 +9,7 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -25,18 +25,18 @@ const (
|
||||
|
||||
const interval = 10 * time.Second
|
||||
|
||||
func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interface, nexthopv6 netip.Addr, intfv6 *net.Interface, callback func()) error {
|
||||
var neighborv4, neighborv6 *routemanager.Neighbor
|
||||
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error {
|
||||
var neighborv4, neighborv6 *systemops.Neighbor
|
||||
{
|
||||
initialNeighbors, err := getNeighbors()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get neighbors: %w", err)
|
||||
}
|
||||
|
||||
if n, ok := initialNeighbors[nexthopv4]; ok {
|
||||
if n, ok := initialNeighbors[nexthopv4.IP]; ok {
|
||||
neighborv4 = &n
|
||||
}
|
||||
if n, ok := initialNeighbors[nexthopv6]; ok {
|
||||
if n, ok := initialNeighbors[nexthopv6.IP]; ok {
|
||||
neighborv6 = &n
|
||||
}
|
||||
}
|
||||
@ -50,7 +50,7 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
|
||||
case <-ctx.Done():
|
||||
return ErrStopped
|
||||
case <-ticker.C:
|
||||
if changed(nexthopv4, intfv4, neighborv4, nexthopv6, intfv6, neighborv6) {
|
||||
if changed(nexthopv4, neighborv4, nexthopv6, neighborv6) {
|
||||
go callback()
|
||||
return nil
|
||||
}
|
||||
@ -59,12 +59,10 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
|
||||
}
|
||||
|
||||
func changed(
|
||||
nexthopv4 netip.Addr,
|
||||
intfv4 *net.Interface,
|
||||
neighborv4 *routemanager.Neighbor,
|
||||
nexthopv6 netip.Addr,
|
||||
intfv6 *net.Interface,
|
||||
neighborv6 *routemanager.Neighbor,
|
||||
nexthopv4 systemops.Nexthop,
|
||||
neighborv4 *systemops.Neighbor,
|
||||
nexthopv6 systemops.Nexthop,
|
||||
neighborv6 *systemops.Neighbor,
|
||||
) bool {
|
||||
neighbors, err := getNeighbors()
|
||||
if err != nil {
|
||||
@ -81,7 +79,7 @@ func changed(
|
||||
return false
|
||||
}
|
||||
|
||||
if routeChanged(nexthopv4, intfv4, routes) || routeChanged(nexthopv6, intfv6, routes) {
|
||||
if routeChanged(nexthopv4, nexthopv4.Intf, routes) || routeChanged(nexthopv6, nexthopv6.Intf, routes) {
|
||||
return true
|
||||
}
|
||||
|
||||
@ -89,20 +87,20 @@ func changed(
|
||||
}
|
||||
|
||||
// routeChanged checks if the default routes still point to our nexthop/interface
|
||||
func routeChanged(nexthop netip.Addr, intf *net.Interface, routes map[netip.Prefix]routemanager.Route) bool {
|
||||
if !nexthop.IsValid() {
|
||||
func routeChanged(nexthop systemops.Nexthop, intf *net.Interface, routes map[netip.Prefix]systemops.Route) bool {
|
||||
if !nexthop.IP.IsValid() {
|
||||
return false
|
||||
}
|
||||
|
||||
var unspec netip.Prefix
|
||||
if nexthop.Is6() {
|
||||
if nexthop.IP.Is6() {
|
||||
unspec = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
|
||||
} else {
|
||||
unspec = netip.PrefixFrom(netip.IPv4Unspecified(), 0)
|
||||
}
|
||||
|
||||
if r, ok := routes[unspec]; ok {
|
||||
if r.Nexthop != nexthop || compareIntf(r.Interface, intf) != 0 {
|
||||
if r.Nexthop != nexthop.IP || compareIntf(r.Interface, intf) != 0 {
|
||||
intf := "<nil>"
|
||||
if r.Interface != nil {
|
||||
intf = r.Interface.Name
|
||||
@ -119,13 +117,13 @@ func routeChanged(nexthop netip.Addr, intf *net.Interface, routes map[netip.Pref
|
||||
|
||||
}
|
||||
|
||||
func neighborChanged(nexthop netip.Addr, neighbor *routemanager.Neighbor, neighbors map[netip.Addr]routemanager.Neighbor) bool {
|
||||
func neighborChanged(nexthop systemops.Nexthop, neighbor *systemops.Neighbor, neighbors map[netip.Addr]systemops.Neighbor) bool {
|
||||
if neighbor == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// TODO: consider non-local nexthops, e.g. on point-to-point interfaces
|
||||
if n, ok := neighbors[nexthop]; ok {
|
||||
if n, ok := neighbors[nexthop.IP]; ok {
|
||||
if n.State != reachable && n.State != permanent {
|
||||
log.Infof("network monitor: neighbor %s (%s) is not reachable: %s", neighbor.IPAddress, neighbor.LinkLayerAddress, stateFromInt(n.State))
|
||||
return true
|
||||
@ -150,13 +148,13 @@ func neighborChanged(nexthop netip.Addr, neighbor *routemanager.Neighbor, neighb
|
||||
return false
|
||||
}
|
||||
|
||||
func getNeighbors() (map[netip.Addr]routemanager.Neighbor, error) {
|
||||
entries, err := routemanager.GetNeighbors()
|
||||
func getNeighbors() (map[netip.Addr]systemops.Neighbor, error) {
|
||||
entries, err := systemops.GetNeighbors()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get neighbors: %w", err)
|
||||
}
|
||||
|
||||
neighbours := make(map[netip.Addr]routemanager.Neighbor, len(entries))
|
||||
neighbours := make(map[netip.Addr]systemops.Neighbor, len(entries))
|
||||
for _, entry := range entries {
|
||||
neighbours[entry.IPAddress] = entry
|
||||
}
|
||||
@ -164,13 +162,13 @@ func getNeighbors() (map[netip.Addr]routemanager.Neighbor, error) {
|
||||
return neighbours, nil
|
||||
}
|
||||
|
||||
func getRoutes() (map[netip.Prefix]routemanager.Route, error) {
|
||||
entries, err := routemanager.GetRoutes()
|
||||
func getRoutes() (map[netip.Prefix]systemops.Route, error) {
|
||||
entries, err := systemops.GetRoutes()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get routes: %w", err)
|
||||
}
|
||||
|
||||
routes := make(map[netip.Prefix]routemanager.Route, len(entries))
|
||||
routes := make(map[netip.Prefix]systemops.Route, len(entries))
|
||||
for _, entry := range entries {
|
||||
routes[entry.Destination] = entry
|
||||
}
|
||||
|
@ -2,14 +2,17 @@ package peer
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/exp/maps"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/relay"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
)
|
||||
|
||||
// State contains the latest state of a peer
|
||||
@ -37,25 +40,25 @@ type State struct {
|
||||
// AddRoute add a single route to routes map
|
||||
func (s *State) AddRoute(network string) {
|
||||
s.Mux.Lock()
|
||||
defer s.Mux.Unlock()
|
||||
if s.routes == nil {
|
||||
s.routes = make(map[string]struct{})
|
||||
}
|
||||
s.routes[network] = struct{}{}
|
||||
s.Mux.Unlock()
|
||||
}
|
||||
|
||||
// SetRoutes set state routes
|
||||
func (s *State) SetRoutes(routes map[string]struct{}) {
|
||||
s.Mux.Lock()
|
||||
defer s.Mux.Unlock()
|
||||
s.routes = routes
|
||||
s.Mux.Unlock()
|
||||
}
|
||||
|
||||
// DeleteRoute removes a route from the network amp
|
||||
func (s *State) DeleteRoute(network string) {
|
||||
s.Mux.Lock()
|
||||
defer s.Mux.Unlock()
|
||||
delete(s.routes, network)
|
||||
s.Mux.Unlock()
|
||||
}
|
||||
|
||||
// GetRoutes return routes map
|
||||
@ -117,22 +120,23 @@ type FullStatus struct {
|
||||
|
||||
// Status holds a state of peers, signal, management connections and relays
|
||||
type Status struct {
|
||||
mux sync.Mutex
|
||||
peers map[string]State
|
||||
changeNotify map[string]chan struct{}
|
||||
signalState bool
|
||||
signalError error
|
||||
managementState bool
|
||||
managementError error
|
||||
relayStates []relay.ProbeResult
|
||||
localPeer LocalPeerState
|
||||
offlinePeers []State
|
||||
mgmAddress string
|
||||
signalAddress string
|
||||
notifier *notifier
|
||||
rosenpassEnabled bool
|
||||
rosenpassPermissive bool
|
||||
nsGroupStates []NSGroupState
|
||||
mux sync.Mutex
|
||||
peers map[string]State
|
||||
changeNotify map[string]chan struct{}
|
||||
signalState bool
|
||||
signalError error
|
||||
managementState bool
|
||||
managementError error
|
||||
relayStates []relay.ProbeResult
|
||||
localPeer LocalPeerState
|
||||
offlinePeers []State
|
||||
mgmAddress string
|
||||
signalAddress string
|
||||
notifier *notifier
|
||||
rosenpassEnabled bool
|
||||
rosenpassPermissive bool
|
||||
nsGroupStates []NSGroupState
|
||||
resolvedDomainsStates map[domain.Domain][]netip.Prefix
|
||||
|
||||
// To reduce the number of notification invocation this bool will be true when need to call the notification
|
||||
// Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events
|
||||
@ -143,11 +147,12 @@ type Status struct {
|
||||
// NewRecorder returns a new Status instance
|
||||
func NewRecorder(mgmAddress string) *Status {
|
||||
return &Status{
|
||||
peers: make(map[string]State),
|
||||
changeNotify: make(map[string]chan struct{}),
|
||||
offlinePeers: make([]State, 0),
|
||||
notifier: newNotifier(),
|
||||
mgmAddress: mgmAddress,
|
||||
peers: make(map[string]State),
|
||||
changeNotify: make(map[string]chan struct{}),
|
||||
offlinePeers: make([]State, 0),
|
||||
notifier: newNotifier(),
|
||||
mgmAddress: mgmAddress,
|
||||
resolvedDomainsStates: make(map[domain.Domain][]netip.Prefix),
|
||||
}
|
||||
}
|
||||
|
||||
@ -188,7 +193,7 @@ func (d *Status) GetPeer(peerPubKey string) (State, error) {
|
||||
|
||||
state, ok := d.peers[peerPubKey]
|
||||
if !ok {
|
||||
return State{}, errors.New("peer not found")
|
||||
return State{}, iface.ErrPeerNotFound
|
||||
}
|
||||
return state, nil
|
||||
}
|
||||
@ -429,6 +434,18 @@ func (d *Status) UpdateDNSStates(dnsStates []NSGroupState) {
|
||||
d.nsGroupStates = dnsStates
|
||||
}
|
||||
|
||||
func (d *Status) UpdateResolvedDomainsStates(domain domain.Domain, prefixes []netip.Prefix) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.resolvedDomainsStates[domain] = prefixes
|
||||
}
|
||||
|
||||
func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
delete(d.resolvedDomainsStates, domain)
|
||||
}
|
||||
|
||||
func (d *Status) GetRosenpassState() RosenpassState {
|
||||
return RosenpassState{
|
||||
d.rosenpassEnabled,
|
||||
@ -493,6 +510,12 @@ func (d *Status) GetDNSStates() []NSGroupState {
|
||||
return d.nsGroupStates
|
||||
}
|
||||
|
||||
func (d *Status) GetResolvedDomainsStates() map[domain.Domain][]netip.Prefix {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
return maps.Clone(d.resolvedDomainsStates)
|
||||
}
|
||||
|
||||
// GetFullStatus gets full status
|
||||
func (d *Status) GetFullStatus() FullStatus {
|
||||
d.mux.Lock()
|
||||
|
@ -3,19 +3,20 @@ package routemanager
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/static"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
const minRangeBits = 7
|
||||
|
||||
type routerPeerStatus struct {
|
||||
connected bool
|
||||
relayed bool
|
||||
@ -28,33 +29,42 @@ type routesUpdate struct {
|
||||
routes []*route.Route
|
||||
}
|
||||
|
||||
// RouteHandler defines the interface for handling routes
|
||||
type RouteHandler interface {
|
||||
String() string
|
||||
AddRoute(ctx context.Context) error
|
||||
RemoveRoute() error
|
||||
AddAllowedIPs(peerKey string) error
|
||||
RemoveAllowedIPs() error
|
||||
}
|
||||
|
||||
type clientNetwork struct {
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
cancel context.CancelFunc
|
||||
statusRecorder *peer.Status
|
||||
wgInterface *iface.WGIface
|
||||
routes map[route.ID]*route.Route
|
||||
routeUpdate chan routesUpdate
|
||||
peerStateUpdate chan struct{}
|
||||
routePeersNotifiers map[string]chan struct{}
|
||||
chosenRoute *route.Route
|
||||
network netip.Prefix
|
||||
currentChosen *route.Route
|
||||
handler RouteHandler
|
||||
updateSerial uint64
|
||||
}
|
||||
|
||||
func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, statusRecorder *peer.Status, network netip.Prefix) *clientNetwork {
|
||||
func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration, wgInterface *iface.WGIface, statusRecorder *peer.Status, rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *clientNetwork {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
client := &clientNetwork{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
cancel: cancel,
|
||||
statusRecorder: statusRecorder,
|
||||
wgInterface: wgInterface,
|
||||
routes: make(map[route.ID]*route.Route),
|
||||
routePeersNotifiers: make(map[string]chan struct{}),
|
||||
routeUpdate: make(chan routesUpdate),
|
||||
peerStateUpdate: make(chan struct{}),
|
||||
network: network,
|
||||
handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder),
|
||||
}
|
||||
return client
|
||||
}
|
||||
@ -86,8 +96,8 @@ func (c *clientNetwork) getRouterPeerStatuses() map[route.ID]routerPeerStatus {
|
||||
// * Metric: Routes with lower metrics (better) are prioritized.
|
||||
// * Non-relayed: Routes without relays are preferred.
|
||||
// * Direct connections: Routes with direct peer connections are favored.
|
||||
// * Stability: In case of equal scores, the currently active route (if any) is maintained.
|
||||
// * Latency: Routes with lower latency are prioritized.
|
||||
// * Stability: In case of equal scores, the currently active route (if any) is maintained.
|
||||
//
|
||||
// It returns the ID of the selected optimal route.
|
||||
func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]routerPeerStatus) route.ID {
|
||||
@ -96,8 +106,8 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
|
||||
currScore := float64(0)
|
||||
|
||||
currID := route.ID("")
|
||||
if c.chosenRoute != nil {
|
||||
currID = c.chosenRoute.ID
|
||||
if c.currentChosen != nil {
|
||||
currID = c.currentChosen.ID
|
||||
}
|
||||
|
||||
for _, r := range c.routes {
|
||||
@ -151,18 +161,18 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
|
||||
peers = append(peers, r.Peer)
|
||||
}
|
||||
|
||||
log.Warnf("the network %s has not been assigned a routing peer as no peers from the list %s are currently connected", c.network, peers)
|
||||
log.Warnf("The network [%v] has not been assigned a routing peer as no peers from the list %s are currently connected", c.handler, peers)
|
||||
case chosen != currID:
|
||||
// we compare the current score + 10ms to the chosen score to avoid flapping between routes
|
||||
if currScore != 0 && currScore+0.01 > chosenScore {
|
||||
log.Debugf("keeping current routing peer because the score difference with latency is less than 0.01(10ms), current: %f, new: %f", currScore, chosenScore)
|
||||
log.Debugf("Keeping current routing peer because the score difference with latency is less than 0.01(10ms), current: %f, new: %f", currScore, chosenScore)
|
||||
return currID
|
||||
}
|
||||
var p string
|
||||
if rt := c.routes[chosen]; rt != nil {
|
||||
p = rt.Peer
|
||||
}
|
||||
log.Infof("new chosen route is %s with peer %s with score %f for network %s", chosen, p, chosenScore, c.network)
|
||||
log.Infof("New chosen route is %s with peer %s with score %f for network [%v]", chosen, p, chosenScore, c.handler)
|
||||
}
|
||||
|
||||
return chosen
|
||||
@ -196,98 +206,103 @@ func (c *clientNetwork) startPeersStatusChangeWatcher() {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
|
||||
state, err := c.statusRecorder.GetPeer(peerKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get peer state: %v", err)
|
||||
}
|
||||
func (c *clientNetwork) removeRouteFromWireguardPeer() error {
|
||||
c.removeStateRoute()
|
||||
|
||||
state.DeleteRoute(c.network.String())
|
||||
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
|
||||
log.Warnf("Failed to update peer state: %v", err)
|
||||
}
|
||||
|
||||
if state.ConnStatus != peer.StatusConnected {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = c.wgInterface.RemoveAllowedIP(peerKey, c.network.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("remove allowed IP %s removed for peer %s, err: %v",
|
||||
c.network, c.chosenRoute.Peer, err)
|
||||
if err := c.handler.RemoveAllowedIPs(); err != nil {
|
||||
return fmt.Errorf("remove allowed IPs: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
|
||||
if c.chosenRoute != nil {
|
||||
if err := removeVPNRoute(c.network, c.getAsInterface()); err != nil {
|
||||
return fmt.Errorf("remove route %s from system, err: %v", c.network, err)
|
||||
}
|
||||
|
||||
if err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer); err != nil {
|
||||
return fmt.Errorf("remove route: %v", err)
|
||||
}
|
||||
if c.currentChosen == nil {
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := c.removeRouteFromWireguardPeer(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err))
|
||||
}
|
||||
if err := c.handler.RemoveRoute(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove route: %w", err))
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
|
||||
routerPeerStatuses := c.getRouterPeerStatuses()
|
||||
|
||||
chosen := c.getBestRouteFromStatuses(routerPeerStatuses)
|
||||
newChosenID := c.getBestRouteFromStatuses(routerPeerStatuses)
|
||||
|
||||
// If no route is chosen, remove the route from the peer and system
|
||||
if chosen == "" {
|
||||
if newChosenID == "" {
|
||||
if err := c.removeRouteFromPeerAndSystem(); err != nil {
|
||||
return fmt.Errorf("remove route from peer and system: %v", err)
|
||||
return fmt.Errorf("remove route for peer %s: %w", c.currentChosen.Peer, err)
|
||||
}
|
||||
|
||||
c.chosenRoute = nil
|
||||
c.currentChosen = nil
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// If the chosen route is the same as the current route, do nothing
|
||||
if c.chosenRoute != nil && c.chosenRoute.ID == chosen {
|
||||
if c.chosenRoute.IsEqual(c.routes[chosen]) {
|
||||
return nil
|
||||
}
|
||||
if c.currentChosen != nil && c.currentChosen.ID == newChosenID &&
|
||||
c.currentChosen.IsEqual(c.routes[newChosenID]) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if c.chosenRoute != nil {
|
||||
// If a previous route exists, remove it from the peer
|
||||
if err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer); err != nil {
|
||||
return fmt.Errorf("remove route from peer: %v", err)
|
||||
if c.currentChosen == nil {
|
||||
// If they were not previously assigned to another peer, add routes to the system first
|
||||
if err := c.handler.AddRoute(c.ctx); err != nil {
|
||||
return fmt.Errorf("add route: %w", err)
|
||||
}
|
||||
} else {
|
||||
// otherwise add the route to the system
|
||||
if err := addVPNRoute(c.network, c.getAsInterface()); err != nil {
|
||||
return fmt.Errorf("route %s couldn't be added for peer %s, err: %v",
|
||||
c.network.String(), c.wgInterface.Address().IP.String(), err)
|
||||
// Otherwise, remove the allowed IPs from the previous peer first
|
||||
if err := c.removeRouteFromWireguardPeer(); err != nil {
|
||||
return fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
|
||||
}
|
||||
}
|
||||
|
||||
c.chosenRoute = c.routes[chosen]
|
||||
c.currentChosen = c.routes[newChosenID]
|
||||
|
||||
state, err := c.statusRecorder.GetPeer(c.chosenRoute.Peer)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get peer state: %v", err)
|
||||
} else {
|
||||
state.AddRoute(c.network.String())
|
||||
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
|
||||
log.Warnf("Failed to update peer state: %v", err)
|
||||
}
|
||||
if err := c.handler.AddAllowedIPs(c.currentChosen.Peer); err != nil {
|
||||
return fmt.Errorf("add allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
|
||||
}
|
||||
|
||||
if err := c.wgInterface.AddAllowedIP(c.chosenRoute.Peer, c.network.String()); err != nil {
|
||||
log.Errorf("couldn't add allowed IP %s added for peer %s, err: %v",
|
||||
c.network, c.chosenRoute.Peer, err)
|
||||
}
|
||||
c.addStateRoute()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientNetwork) addStateRoute() {
|
||||
state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get peer state: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
state.AddRoute(c.handler.String())
|
||||
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
|
||||
log.Warnf("Failed to update peer state: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientNetwork) removeStateRoute() {
|
||||
state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get peer state: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
state.DeleteRoute(c.handler.String())
|
||||
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
|
||||
log.Warnf("Failed to update peer state: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) {
|
||||
go func() {
|
||||
c.routeUpdate <- update
|
||||
@ -318,24 +333,23 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
log.Debugf("stopping watcher for network %s", c.network)
|
||||
err := c.removeRouteFromPeerAndSystem()
|
||||
if err != nil {
|
||||
log.Errorf("Couldn't remove route from peer and system for network %s: %v", c.network, err)
|
||||
log.Debugf("Stopping watcher for network [%v]", c.handler)
|
||||
if err := c.removeRouteFromPeerAndSystem(); err != nil {
|
||||
log.Errorf("Failed to remove routes for [%v]: %v", c.handler, err)
|
||||
}
|
||||
return
|
||||
case <-c.peerStateUpdate:
|
||||
err := c.recalculateRouteAndUpdatePeerAndSystem()
|
||||
if err != nil {
|
||||
log.Errorf("Couldn't recalculate route and update peer and system: %v", err)
|
||||
log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err)
|
||||
}
|
||||
case update := <-c.routeUpdate:
|
||||
if update.updateSerial < c.updateSerial {
|
||||
log.Warnf("Received a routes update with smaller serial number, ignoring it")
|
||||
log.Warnf("Received a routes update with smaller serial number (%d -> %d), ignoring it", c.updateSerial, update.updateSerial)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debugf("Received a new client network route update for %s", c.network)
|
||||
log.Debugf("Received a new client network route update for [%v]", c.handler)
|
||||
|
||||
c.handleUpdate(update)
|
||||
|
||||
@ -343,7 +357,7 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
|
||||
|
||||
err := c.recalculateRouteAndUpdatePeerAndSystem()
|
||||
if err != nil {
|
||||
log.Errorf("Couldn't recalculate route and update peer and system for network %s: %v", c.network, err)
|
||||
log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err)
|
||||
}
|
||||
|
||||
c.startPeersStatusChangeWatcher()
|
||||
@ -351,14 +365,9 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientNetwork) getAsInterface() *net.Interface {
|
||||
intf, err := net.InterfaceByName(c.wgInterface.Name())
|
||||
if err != nil {
|
||||
log.Warnf("Couldn't get interface by name %s: %v", c.wgInterface.Name(), err)
|
||||
intf = &net.Interface{
|
||||
Name: c.wgInterface.Name(),
|
||||
}
|
||||
func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status) RouteHandler {
|
||||
if rt.IsDynamic() {
|
||||
return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder)
|
||||
}
|
||||
|
||||
return intf
|
||||
return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter)
|
||||
}
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/static"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
@ -340,9 +341,9 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
||||
|
||||
// create new clientNetwork
|
||||
client := &clientNetwork{
|
||||
network: netip.MustParsePrefix("192.168.0.0/24"),
|
||||
routes: tc.existingRoutes,
|
||||
chosenRoute: currentRoute,
|
||||
handler: static.NewRoute(&route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")}, nil, nil),
|
||||
routes: tc.existingRoutes,
|
||||
currentChosen: currentRoute,
|
||||
}
|
||||
|
||||
chosenRoute := client.getBestRouteFromStatuses(tc.statuses)
|
||||
|
361
client/internal/routemanager/dynamic/route.go
Normal file
361
client/internal/routemanager/dynamic/route.go
Normal file
@ -0,0 +1,361 @@
|
||||
package dynamic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultInterval = time.Minute
|
||||
|
||||
minInterval = 2 * time.Second
|
||||
|
||||
addAllowedIP = "add allowed IP %s: %w"
|
||||
)
|
||||
|
||||
type domainMap map[domain.Domain][]netip.Prefix
|
||||
|
||||
type resolveResult struct {
|
||||
domain domain.Domain
|
||||
prefix netip.Prefix
|
||||
err error
|
||||
}
|
||||
|
||||
type Route struct {
|
||||
route *route.Route
|
||||
routeRefCounter *refcounter.RouteRefCounter
|
||||
allowedIPsRefcounter *refcounter.AllowedIPsRefCounter
|
||||
interval time.Duration
|
||||
dynamicDomains domainMap
|
||||
mu sync.Mutex
|
||||
currentPeerKey string
|
||||
cancel context.CancelFunc
|
||||
statusRecorder *peer.Status
|
||||
}
|
||||
|
||||
func NewRoute(
|
||||
rt *route.Route,
|
||||
routeRefCounter *refcounter.RouteRefCounter,
|
||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
|
||||
interval time.Duration,
|
||||
statusRecorder *peer.Status,
|
||||
) *Route {
|
||||
return &Route{
|
||||
route: rt,
|
||||
routeRefCounter: routeRefCounter,
|
||||
allowedIPsRefcounter: allowedIPsRefCounter,
|
||||
interval: interval,
|
||||
dynamicDomains: domainMap{},
|
||||
statusRecorder: statusRecorder,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Route) String() string {
|
||||
s, err := r.route.Domains.String()
|
||||
if err != nil {
|
||||
return r.route.Domains.PunycodeString()
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (r *Route) AddRoute(ctx context.Context) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if r.cancel != nil {
|
||||
r.cancel()
|
||||
}
|
||||
|
||||
ctx, r.cancel = context.WithCancel(ctx)
|
||||
|
||||
go r.startResolver(ctx)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveRoute will stop the dynamic resolver and remove all dynamic routes.
|
||||
// It doesn't touch allowed IPs, these should be removed separately and before calling this method.
|
||||
func (r *Route) RemoveRoute() error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if r.cancel != nil {
|
||||
r.cancel()
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
for domain, prefixes := range r.dynamicDomains {
|
||||
for _, prefix := range prefixes {
|
||||
if _, err := r.routeRefCounter.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %w", prefix, err))
|
||||
}
|
||||
}
|
||||
log.Debugf("Removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", "))
|
||||
|
||||
r.statusRecorder.DeleteResolvedDomainsStates(domain)
|
||||
}
|
||||
|
||||
r.dynamicDomains = domainMap{}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *Route) AddAllowedIPs(peerKey string) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
var merr *multierror.Error
|
||||
for domain, domainPrefixes := range r.dynamicDomains {
|
||||
for _, prefix := range domainPrefixes {
|
||||
if err := r.incrementAllowedIP(domain, prefix, peerKey); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf(addAllowedIP, prefix, err))
|
||||
}
|
||||
}
|
||||
}
|
||||
r.currentPeerKey = peerKey
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *Route) RemoveAllowedIPs() error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
var merr *multierror.Error
|
||||
for _, domainPrefixes := range r.dynamicDomains {
|
||||
for _, prefix := range domainPrefixes {
|
||||
if _, err := r.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %w", prefix, err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
r.currentPeerKey = ""
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *Route) startResolver(ctx context.Context) {
|
||||
log.Debugf("Starting dynamic route resolver for domains [%v]", r)
|
||||
|
||||
interval := r.interval
|
||||
if interval < minInterval {
|
||||
interval = minInterval
|
||||
log.Warnf("Dynamic route resolver interval %s is too low, setting to minimum value %s", r.interval, minInterval)
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
r.update(ctx)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Debugf("Stopping dynamic route resolver for domains [%v]", r)
|
||||
return
|
||||
case <-ticker.C:
|
||||
r.update(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Route) update(ctx context.Context) {
|
||||
if resolved, err := r.resolveDomains(); err != nil {
|
||||
log.Errorf("Failed to resolve domains for route [%v]: %v", r, err)
|
||||
} else if err := r.updateDynamicRoutes(ctx, resolved); err != nil {
|
||||
log.Errorf("Failed to update dynamic routes for [%v]: %v", r, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Route) resolveDomains() (domainMap, error) {
|
||||
results := make(chan resolveResult)
|
||||
go r.resolve(results)
|
||||
|
||||
resolved := domainMap{}
|
||||
var merr *multierror.Error
|
||||
|
||||
for result := range results {
|
||||
if result.err != nil {
|
||||
merr = multierror.Append(merr, result.err)
|
||||
} else {
|
||||
resolved[result.domain] = append(resolved[result.domain], result.prefix)
|
||||
}
|
||||
}
|
||||
|
||||
return resolved, nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *Route) resolve(results chan resolveResult) {
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for _, d := range r.route.Domains {
|
||||
wg.Add(1)
|
||||
go func(domain domain.Domain) {
|
||||
defer wg.Done()
|
||||
ips, err := net.LookupIP(string(domain))
|
||||
if err != nil {
|
||||
results <- resolveResult{domain: domain, err: fmt.Errorf("resolve d %s: %w", domain.SafeString(), err)}
|
||||
return
|
||||
}
|
||||
for _, ip := range ips {
|
||||
prefix, err := util.GetPrefixFromIP(ip)
|
||||
if err != nil {
|
||||
results <- resolveResult{domain: domain, err: fmt.Errorf("get prefix from IP %s: %w", ip.String(), err)}
|
||||
return
|
||||
}
|
||||
results <- resolveResult{domain: domain, prefix: prefix}
|
||||
}
|
||||
}(d)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(results)
|
||||
}
|
||||
|
||||
func (r *Route) updateDynamicRoutes(ctx context.Context, newDomains domainMap) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
for domain, newPrefixes := range newDomains {
|
||||
oldPrefixes := r.dynamicDomains[domain]
|
||||
toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes)
|
||||
|
||||
addedPrefixes, err := r.addRoutes(domain, toAdd)
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
} else if len(addedPrefixes) > 0 {
|
||||
log.Debugf("Added dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", addedPrefixes), " ", ", "))
|
||||
}
|
||||
|
||||
removedPrefixes, err := r.removeRoutes(toRemove)
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
} else if len(removedPrefixes) > 0 {
|
||||
log.Debugf("Removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", removedPrefixes), " ", ", "))
|
||||
}
|
||||
|
||||
updatedPrefixes := combinePrefixes(oldPrefixes, removedPrefixes, addedPrefixes)
|
||||
r.dynamicDomains[domain] = updatedPrefixes
|
||||
|
||||
r.statusRecorder.UpdateResolvedDomainsStates(domain, updatedPrefixes)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *Route) addRoutes(domain domain.Domain, prefixes []netip.Prefix) ([]netip.Prefix, error) {
|
||||
var addedPrefixes []netip.Prefix
|
||||
var merr *multierror.Error
|
||||
|
||||
for _, prefix := range prefixes {
|
||||
if _, err := r.routeRefCounter.Increment(prefix, nil); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add dynamic route for IP %s: %w", prefix, err))
|
||||
continue
|
||||
}
|
||||
if r.currentPeerKey != "" {
|
||||
if err := r.incrementAllowedIP(domain, prefix, r.currentPeerKey); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf(addAllowedIP, prefix, err))
|
||||
}
|
||||
}
|
||||
addedPrefixes = append(addedPrefixes, prefix)
|
||||
}
|
||||
|
||||
return addedPrefixes, merr.ErrorOrNil()
|
||||
}
|
||||
|
||||
func (r *Route) removeRoutes(prefixes []netip.Prefix) ([]netip.Prefix, error) {
|
||||
if r.route.KeepRoute {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var removedPrefixes []netip.Prefix
|
||||
var merr *multierror.Error
|
||||
|
||||
for _, prefix := range prefixes {
|
||||
if _, err := r.routeRefCounter.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %w", prefix, err))
|
||||
}
|
||||
if r.currentPeerKey != "" {
|
||||
if _, err := r.allowedIPsRefcounter.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %w", prefix, err))
|
||||
}
|
||||
}
|
||||
removedPrefixes = append(removedPrefixes, prefix)
|
||||
}
|
||||
|
||||
return removedPrefixes, merr.ErrorOrNil()
|
||||
}
|
||||
|
||||
func (r *Route) incrementAllowedIP(domain domain.Domain, prefix netip.Prefix, peerKey string) error {
|
||||
if ref, err := r.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil {
|
||||
return fmt.Errorf(addAllowedIP, prefix, err)
|
||||
} else if ref.Count > 1 && ref.Out != peerKey {
|
||||
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
|
||||
prefix.Addr(),
|
||||
domain.SafeString(),
|
||||
ref.Out,
|
||||
)
|
||||
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toRemove []netip.Prefix) {
|
||||
prefixSet := make(map[netip.Prefix]bool)
|
||||
for _, prefix := range oldPrefixes {
|
||||
prefixSet[prefix] = false
|
||||
}
|
||||
for _, prefix := range newPrefixes {
|
||||
if _, exists := prefixSet[prefix]; exists {
|
||||
prefixSet[prefix] = true
|
||||
} else {
|
||||
toAdd = append(toAdd, prefix)
|
||||
}
|
||||
}
|
||||
for prefix, inUse := range prefixSet {
|
||||
if !inUse {
|
||||
toRemove = append(toRemove, prefix)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func combinePrefixes(oldPrefixes, removedPrefixes, addedPrefixes []netip.Prefix) []netip.Prefix {
|
||||
prefixSet := make(map[netip.Prefix]struct{})
|
||||
for _, prefix := range oldPrefixes {
|
||||
prefixSet[prefix] = struct{}{}
|
||||
}
|
||||
for _, prefix := range removedPrefixes {
|
||||
delete(prefixSet, prefix)
|
||||
}
|
||||
for _, prefix := range addedPrefixes {
|
||||
prefixSet[prefix] = struct{}{}
|
||||
}
|
||||
|
||||
var combinedPrefixes []netip.Prefix
|
||||
for prefix := range prefixSet {
|
||||
combinedPrefixes = append(combinedPrefixes, prefix)
|
||||
}
|
||||
|
||||
return combinedPrefixes
|
||||
}
|
@ -2,18 +2,23 @@ package routemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
@ -21,11 +26,6 @@ import (
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
var defaultv4 = netip.PrefixFrom(netip.IPv4Unspecified(), 0)
|
||||
|
||||
// nolint:unused
|
||||
var defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
|
||||
|
||||
// Manager is a route manager interface
|
||||
type Manager interface {
|
||||
Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error)
|
||||
@ -40,31 +40,71 @@ type Manager interface {
|
||||
|
||||
// DefaultManager is the default instance of a route manager
|
||||
type DefaultManager struct {
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
mux sync.Mutex
|
||||
clientNetworks map[route.HAUniqueID]*clientNetwork
|
||||
routeSelector *routeselector.RouteSelector
|
||||
serverRouter serverRouter
|
||||
statusRecorder *peer.Status
|
||||
wgInterface *iface.WGIface
|
||||
pubKey string
|
||||
notifier *notifier
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
mux sync.Mutex
|
||||
clientNetworks map[route.HAUniqueID]*clientNetwork
|
||||
routeSelector *routeselector.RouteSelector
|
||||
serverRouter serverRouter
|
||||
sysOps *systemops.SysOps
|
||||
statusRecorder *peer.Status
|
||||
wgInterface *iface.WGIface
|
||||
pubKey string
|
||||
notifier *notifier
|
||||
routeRefCounter *refcounter.RouteRefCounter
|
||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter
|
||||
dnsRouteInterval time.Duration
|
||||
}
|
||||
|
||||
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status, initialRoutes []*route.Route) *DefaultManager {
|
||||
func NewManager(
|
||||
ctx context.Context,
|
||||
pubKey string,
|
||||
dnsRouteInterval time.Duration,
|
||||
wgInterface *iface.WGIface,
|
||||
statusRecorder *peer.Status,
|
||||
initialRoutes []*route.Route,
|
||||
) *DefaultManager {
|
||||
mCTX, cancel := context.WithCancel(ctx)
|
||||
sysOps := systemops.NewSysOps(wgInterface)
|
||||
|
||||
dm := &DefaultManager{
|
||||
ctx: mCTX,
|
||||
stop: cancel,
|
||||
clientNetworks: make(map[route.HAUniqueID]*clientNetwork),
|
||||
routeSelector: routeselector.NewRouteSelector(),
|
||||
statusRecorder: statusRecorder,
|
||||
wgInterface: wgInterface,
|
||||
pubKey: pubKey,
|
||||
notifier: newNotifier(),
|
||||
ctx: mCTX,
|
||||
stop: cancel,
|
||||
dnsRouteInterval: dnsRouteInterval,
|
||||
clientNetworks: make(map[route.HAUniqueID]*clientNetwork),
|
||||
routeSelector: routeselector.NewRouteSelector(),
|
||||
sysOps: sysOps,
|
||||
statusRecorder: statusRecorder,
|
||||
wgInterface: wgInterface,
|
||||
pubKey: pubKey,
|
||||
notifier: newNotifier(),
|
||||
}
|
||||
|
||||
dm.routeRefCounter = refcounter.New(
|
||||
func(prefix netip.Prefix, _ any) (any, error) {
|
||||
return nil, sysOps.AddVPNRoute(prefix, wgInterface.ToInterface())
|
||||
},
|
||||
func(prefix netip.Prefix, _ any) error {
|
||||
return sysOps.RemoveVPNRoute(prefix, wgInterface.ToInterface())
|
||||
},
|
||||
)
|
||||
|
||||
dm.allowedIPsRefCounter = refcounter.New(
|
||||
func(prefix netip.Prefix, peerKey string) (string, error) {
|
||||
// save peerKey to use it in the remove function
|
||||
return peerKey, wgInterface.AddAllowedIP(peerKey, prefix.String())
|
||||
},
|
||||
func(prefix netip.Prefix, peerKey string) error {
|
||||
if err := wgInterface.RemoveAllowedIP(peerKey, prefix.String()); err != nil {
|
||||
if !errors.Is(err, iface.ErrPeerNotFound) && !errors.Is(err, iface.ErrAllowedIPNotFound) {
|
||||
return err
|
||||
}
|
||||
log.Tracef("Remove allowed IPs %s for %s: %v", prefix, peerKey, err)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
)
|
||||
|
||||
if runtime.GOOS == "android" {
|
||||
cr := dm.clientRoutes(initialRoutes)
|
||||
dm.notifier.setInitialClientRoutes(cr)
|
||||
@ -78,7 +118,7 @@ func (m *DefaultManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePee
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
if err := cleanupRouting(); err != nil {
|
||||
if err := m.sysOps.CleanupRouting(); err != nil {
|
||||
log.Warnf("Failed cleaning up routing: %v", err)
|
||||
}
|
||||
|
||||
@ -86,7 +126,7 @@ func (m *DefaultManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePee
|
||||
signalAddress := m.statusRecorder.GetSignalState().URL
|
||||
ips := resolveURLsToIPs([]string{mgmtAddress, signalAddress})
|
||||
|
||||
beforePeerHook, afterPeerHook, err := setupRouting(ips, m.wgInterface)
|
||||
beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("setup routing: %w", err)
|
||||
}
|
||||
@ -110,8 +150,19 @@ func (m *DefaultManager) Stop() {
|
||||
m.serverRouter.cleanUp()
|
||||
}
|
||||
|
||||
if m.routeRefCounter != nil {
|
||||
if err := m.routeRefCounter.Flush(); err != nil {
|
||||
log.Errorf("Error flushing route ref counter: %v", err)
|
||||
}
|
||||
}
|
||||
if m.allowedIPsRefCounter != nil {
|
||||
if err := m.allowedIPsRefCounter.Flush(); err != nil {
|
||||
log.Errorf("Error flushing allowed IPs ref counter: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if !nbnet.CustomRoutingDisabled() {
|
||||
if err := cleanupRouting(); err != nil {
|
||||
if err := m.sysOps.CleanupRouting(); err != nil {
|
||||
log.Errorf("Error cleaning up routing: %v", err)
|
||||
} else {
|
||||
log.Info("Routing cleanup complete")
|
||||
@ -185,7 +236,7 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
|
||||
continue
|
||||
}
|
||||
|
||||
clientNetworkWatcher := newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network)
|
||||
clientNetworkWatcher := newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter)
|
||||
m.clientNetworks[id] = clientNetworkWatcher
|
||||
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
||||
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes})
|
||||
@ -197,7 +248,7 @@ func (m *DefaultManager) stopObsoleteClients(networks route.HAMap) {
|
||||
for id, client := range m.clientNetworks {
|
||||
if _, ok := networks[id]; !ok {
|
||||
log.Debugf("Stopping client network watcher, %s", id)
|
||||
client.stop()
|
||||
client.cancel()
|
||||
delete(m.clientNetworks, id)
|
||||
}
|
||||
}
|
||||
@ -210,7 +261,7 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout
|
||||
for id, routes := range networks {
|
||||
clientNetworkWatcher, found := m.clientNetworks[id]
|
||||
if !found {
|
||||
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network)
|
||||
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter)
|
||||
m.clientNetworks[id] = clientNetworkWatcher
|
||||
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
||||
}
|
||||
@ -228,7 +279,7 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]
|
||||
ownNetworkIDs := make(map[route.HAUniqueID]bool)
|
||||
|
||||
for _, newRoute := range newRoutes {
|
||||
haID := route.GetHAUniqueID(newRoute)
|
||||
haID := newRoute.GetHAUniqueID()
|
||||
if newRoute.Peer == m.pubKey {
|
||||
ownNetworkIDs[haID] = true
|
||||
// only linux is supported for now
|
||||
@ -241,9 +292,9 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]
|
||||
}
|
||||
|
||||
for _, newRoute := range newRoutes {
|
||||
haID := route.GetHAUniqueID(newRoute)
|
||||
haID := newRoute.GetHAUniqueID()
|
||||
if !ownNetworkIDs[haID] {
|
||||
if !isPrefixSupported(newRoute.Network) {
|
||||
if !isRouteSupported(newRoute) {
|
||||
continue
|
||||
}
|
||||
newClientRoutesIDMap[haID] = append(newClientRoutesIDMap[haID], newRoute)
|
||||
@ -255,23 +306,23 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]
|
||||
|
||||
func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Route {
|
||||
_, crMap := m.classifyRoutes(initialRoutes)
|
||||
rs := make([]*route.Route, 0)
|
||||
rs := make([]*route.Route, len(crMap))
|
||||
for _, routes := range crMap {
|
||||
rs = append(rs, routes...)
|
||||
}
|
||||
return rs
|
||||
}
|
||||
|
||||
func isPrefixSupported(prefix netip.Prefix) bool {
|
||||
if !nbnet.CustomRoutingDisabled() {
|
||||
func isRouteSupported(route *route.Route) bool {
|
||||
if !nbnet.CustomRoutingDisabled() || route.IsDynamic() {
|
||||
return true
|
||||
}
|
||||
|
||||
// If prefix is too small, lets assume it is a possible default prefix which is not yet supported
|
||||
// we skip this prefix management
|
||||
if prefix.Bits() <= minRangeBits {
|
||||
if route.Network.Bits() <= vars.MinRangeBits {
|
||||
log.Warnf("This agent version: %s, doesn't support default routes, received %s, skipping this prefix",
|
||||
version.NetbirdVersion(), prefix)
|
||||
version.NetbirdVersion(), route.Network)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
|
@ -416,7 +416,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
|
||||
statusRecorder := peer.NewRecorder("https://mgm")
|
||||
ctx := context.TODO()
|
||||
routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil)
|
||||
routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil)
|
||||
|
||||
_, _, err = routeManager.Init()
|
||||
|
||||
|
155
client/internal/routemanager/refcounter/refcounter.go
Normal file
155
client/internal/routemanager/refcounter/refcounter.go
Normal file
@ -0,0 +1,155 @@
|
||||
package refcounter
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
)
|
||||
|
||||
// ErrIgnore can be returned by AddFunc to indicate that the counter not be incremented for the given prefix.
|
||||
var ErrIgnore = errors.New("ignore")
|
||||
|
||||
type Ref[O any] struct {
|
||||
Count int
|
||||
Out O
|
||||
}
|
||||
|
||||
type AddFunc[I, O any] func(prefix netip.Prefix, in I) (out O, err error)
|
||||
type RemoveFunc[I, O any] func(prefix netip.Prefix, out O) error
|
||||
|
||||
type Counter[I, O any] struct {
|
||||
// refCountMap keeps track of the reference Ref for prefixes
|
||||
refCountMap map[netip.Prefix]Ref[O]
|
||||
refCountMu sync.Mutex
|
||||
// idMap keeps track of the prefixes associated with an ID for removal
|
||||
idMap map[string][]netip.Prefix
|
||||
idMu sync.Mutex
|
||||
add AddFunc[I, O]
|
||||
remove RemoveFunc[I, O]
|
||||
}
|
||||
|
||||
// New creates a new Counter instance
|
||||
func New[I, O any](add AddFunc[I, O], remove RemoveFunc[I, O]) *Counter[I, O] {
|
||||
return &Counter[I, O]{
|
||||
refCountMap: map[netip.Prefix]Ref[O]{},
|
||||
idMap: map[string][]netip.Prefix{},
|
||||
add: add,
|
||||
remove: remove,
|
||||
}
|
||||
}
|
||||
|
||||
// Increment increments the reference count for the given prefix.
|
||||
// If this is the first reference to the prefix, the AddFunc is called.
|
||||
func (rm *Counter[I, O]) Increment(prefix netip.Prefix, in I) (Ref[O], error) {
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
|
||||
ref := rm.refCountMap[prefix]
|
||||
log.Tracef("Increasing ref count %d for prefix %s with [%v]", ref.Count, prefix, ref.Out)
|
||||
|
||||
// Call AddFunc only if it's a new prefix
|
||||
if ref.Count == 0 {
|
||||
log.Tracef("Adding for prefix %s with [%v]", prefix, ref.Out)
|
||||
out, err := rm.add(prefix, in)
|
||||
|
||||
if errors.Is(err, ErrIgnore) {
|
||||
return ref, nil
|
||||
}
|
||||
if err != nil {
|
||||
return ref, fmt.Errorf("failed to add for prefix %s: %w", prefix, err)
|
||||
}
|
||||
ref.Out = out
|
||||
}
|
||||
|
||||
ref.Count++
|
||||
rm.refCountMap[prefix] = ref
|
||||
|
||||
return ref, nil
|
||||
}
|
||||
|
||||
// IncrementWithID increments the reference count for the given prefix and groups it under the given ID.
|
||||
// If this is the first reference to the prefix, the AddFunc is called.
|
||||
func (rm *Counter[I, O]) IncrementWithID(id string, prefix netip.Prefix, in I) (Ref[O], error) {
|
||||
rm.idMu.Lock()
|
||||
defer rm.idMu.Unlock()
|
||||
|
||||
ref, err := rm.Increment(prefix, in)
|
||||
if err != nil {
|
||||
return ref, fmt.Errorf("with ID: %w", err)
|
||||
}
|
||||
rm.idMap[id] = append(rm.idMap[id], prefix)
|
||||
|
||||
return ref, nil
|
||||
}
|
||||
|
||||
// Decrement decrements the reference count for the given prefix.
|
||||
// If the reference count reaches 0, the RemoveFunc is called.
|
||||
func (rm *Counter[I, O]) Decrement(prefix netip.Prefix) (Ref[O], error) {
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
|
||||
ref, ok := rm.refCountMap[prefix]
|
||||
if !ok {
|
||||
log.Tracef("No reference found for prefix %s", prefix)
|
||||
return ref, nil
|
||||
}
|
||||
|
||||
log.Tracef("Decreasing ref count %d for prefix %s with [%v]", ref.Count, prefix, ref.Out)
|
||||
if ref.Count == 1 {
|
||||
log.Tracef("Removing for prefix %s with [%v]", prefix, ref.Out)
|
||||
if err := rm.remove(prefix, ref.Out); err != nil {
|
||||
return ref, fmt.Errorf("remove for prefix %s: %w", prefix, err)
|
||||
}
|
||||
delete(rm.refCountMap, prefix)
|
||||
} else {
|
||||
ref.Count--
|
||||
rm.refCountMap[prefix] = ref
|
||||
}
|
||||
|
||||
return ref, nil
|
||||
}
|
||||
|
||||
// DecrementWithID decrements the reference count for all prefixes associated with the given ID.
|
||||
// If the reference count reaches 0, the RemoveFunc is called.
|
||||
func (rm *Counter[I, O]) DecrementWithID(id string) error {
|
||||
rm.idMu.Lock()
|
||||
defer rm.idMu.Unlock()
|
||||
|
||||
var merr *multierror.Error
|
||||
for _, prefix := range rm.idMap[id] {
|
||||
if _, err := rm.Decrement(prefix); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
}
|
||||
delete(rm.idMap, id)
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// Flush removes all references and calls RemoveFunc for each prefix.
|
||||
func (rm *Counter[I, O]) Flush() error {
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
rm.idMu.Lock()
|
||||
defer rm.idMu.Unlock()
|
||||
|
||||
var merr *multierror.Error
|
||||
for prefix := range rm.refCountMap {
|
||||
log.Tracef("Removing for prefix %s", prefix)
|
||||
ref := rm.refCountMap[prefix]
|
||||
if err := rm.remove(prefix, ref.Out); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove for prefix %s: %w", prefix, err))
|
||||
}
|
||||
}
|
||||
rm.refCountMap = map[netip.Prefix]Ref[O]{}
|
||||
|
||||
rm.idMap = map[string][]netip.Prefix{}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
7
client/internal/routemanager/refcounter/types.go
Normal file
7
client/internal/routemanager/refcounter/types.go
Normal file
@ -0,0 +1,7 @@
|
||||
package refcounter
|
||||
|
||||
// RouteRefCounter is a Counter for Route, it doesn't take any input on Increment and doesn't use any output on Decrement
|
||||
type RouteRefCounter = Counter[any, any]
|
||||
|
||||
// AllowedIPsRefCounter is a Counter for AllowedIPs, it takes a peer key on Increment and passes it back to Decrement
|
||||
type AllowedIPsRefCounter = Counter[string, string]
|
@ -1,127 +0,0 @@
|
||||
//go:build !android && !ios
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
type ref struct {
|
||||
count int
|
||||
nexthop netip.Addr
|
||||
intf *net.Interface
|
||||
}
|
||||
|
||||
type RouteManager struct {
|
||||
// refCountMap keeps track of the reference ref for prefixes
|
||||
refCountMap map[netip.Prefix]ref
|
||||
// prefixMap keeps track of the prefixes associated with a connection ID for removal
|
||||
prefixMap map[nbnet.ConnectionID][]netip.Prefix
|
||||
addRoute AddRouteFunc
|
||||
removeRoute RemoveRouteFunc
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
type AddRouteFunc func(prefix netip.Prefix) (nexthop netip.Addr, intf *net.Interface, err error)
|
||||
type RemoveRouteFunc func(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error
|
||||
|
||||
func NewRouteManager(addRoute AddRouteFunc, removeRoute RemoveRouteFunc) *RouteManager {
|
||||
// TODO: read initial routing table into refCountMap
|
||||
return &RouteManager{
|
||||
refCountMap: map[netip.Prefix]ref{},
|
||||
prefixMap: map[nbnet.ConnectionID][]netip.Prefix{},
|
||||
addRoute: addRoute,
|
||||
removeRoute: removeRoute,
|
||||
}
|
||||
}
|
||||
|
||||
func (rm *RouteManager) AddRouteRef(connID nbnet.ConnectionID, prefix netip.Prefix) error {
|
||||
rm.mutex.Lock()
|
||||
defer rm.mutex.Unlock()
|
||||
|
||||
ref := rm.refCountMap[prefix]
|
||||
log.Debugf("Increasing route ref count %d for prefix %s", ref.count, prefix)
|
||||
|
||||
// Add route to the system, only if it's a new prefix
|
||||
if ref.count == 0 {
|
||||
log.Debugf("Adding route for prefix %s", prefix)
|
||||
nexthop, intf, err := rm.addRoute(prefix)
|
||||
if errors.Is(err, ErrRouteNotFound) {
|
||||
return nil
|
||||
}
|
||||
if errors.Is(err, ErrRouteNotAllowed) {
|
||||
log.Debugf("Adding route for prefix %s: %s", prefix, err)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add route for prefix %s: %w", prefix, err)
|
||||
}
|
||||
ref.nexthop = nexthop
|
||||
ref.intf = intf
|
||||
}
|
||||
|
||||
ref.count++
|
||||
rm.refCountMap[prefix] = ref
|
||||
rm.prefixMap[connID] = append(rm.prefixMap[connID], prefix)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rm *RouteManager) RemoveRouteRef(connID nbnet.ConnectionID) error {
|
||||
rm.mutex.Lock()
|
||||
defer rm.mutex.Unlock()
|
||||
|
||||
prefixes, ok := rm.prefixMap[connID]
|
||||
if !ok {
|
||||
log.Debugf("No prefixes found for connection ID %s", connID)
|
||||
return nil
|
||||
}
|
||||
|
||||
var result *multierror.Error
|
||||
for _, prefix := range prefixes {
|
||||
ref := rm.refCountMap[prefix]
|
||||
log.Debugf("Decreasing route ref count %d for prefix %s", ref.count, prefix)
|
||||
if ref.count == 1 {
|
||||
log.Debugf("Removing route for prefix %s", prefix)
|
||||
// TODO: don't fail if the route is not found
|
||||
if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err))
|
||||
continue
|
||||
}
|
||||
delete(rm.refCountMap, prefix)
|
||||
} else {
|
||||
ref.count--
|
||||
rm.refCountMap[prefix] = ref
|
||||
}
|
||||
}
|
||||
delete(rm.prefixMap, connID)
|
||||
|
||||
return result.ErrorOrNil()
|
||||
}
|
||||
|
||||
// Flush removes all references and routes from the system
|
||||
func (rm *RouteManager) Flush() error {
|
||||
rm.mutex.Lock()
|
||||
defer rm.mutex.Unlock()
|
||||
|
||||
var result *multierror.Error
|
||||
for prefix := range rm.refCountMap {
|
||||
log.Debugf("Removing route for prefix %s", prefix)
|
||||
ref := rm.refCountMap[prefix]
|
||||
if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err))
|
||||
}
|
||||
}
|
||||
rm.refCountMap = map[netip.Prefix]ref{}
|
||||
rm.prefixMap = map[nbnet.ConnectionID][]netip.Prefix{}
|
||||
|
||||
return result.ErrorOrNil()
|
||||
}
|
@ -5,13 +5,14 @@ package routemanager
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
@ -70,7 +71,7 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[route.ID]*route.Route)
|
||||
}
|
||||
|
||||
if len(m.routes) > 0 {
|
||||
err := enableIPForwarding()
|
||||
err := systemops.EnableIPForwarding()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -88,7 +89,7 @@ func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), route)
|
||||
routerPair, err := routeToRouterPair(m.wgInterface.Address().Network, route)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse prefix: %w", err)
|
||||
}
|
||||
@ -117,7 +118,7 @@ func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), route)
|
||||
routerPair, err := routeToRouterPair(m.wgInterface.Address().Network, route)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse prefix: %w", err)
|
||||
}
|
||||
@ -133,7 +134,13 @@ func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error {
|
||||
if state.Routes == nil {
|
||||
state.Routes = map[string]struct{}{}
|
||||
}
|
||||
state.Routes[route.Network.String()] = struct{}{}
|
||||
|
||||
routeStr := route.Network.String()
|
||||
if route.IsDynamic() {
|
||||
routeStr = route.Domains.SafeString()
|
||||
}
|
||||
state.Routes[routeStr] = struct{}{}
|
||||
|
||||
m.statusRecorder.UpdateLocalPeerState(state)
|
||||
|
||||
return nil
|
||||
@ -144,7 +151,7 @@ func (m *defaultServerRouter) cleanUp() {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
for _, r := range m.routes {
|
||||
routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), r)
|
||||
routerPair, err := routeToRouterPair(m.wgInterface.Address().Network, r)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to convert route to router pair: %v", err)
|
||||
continue
|
||||
@ -162,15 +169,17 @@ func (m *defaultServerRouter) cleanUp() {
|
||||
m.statusRecorder.UpdateLocalPeerState(state)
|
||||
}
|
||||
|
||||
func routeToRouterPair(source string, route *route.Route) (firewall.RouterPair, error) {
|
||||
parsed, err := netip.ParsePrefix(source)
|
||||
if err != nil {
|
||||
return firewall.RouterPair{}, err
|
||||
func routeToRouterPair(source *net.IPNet, route *route.Route) (firewall.RouterPair, error) {
|
||||
destination := route.Network.Masked().String()
|
||||
if route.IsDynamic() {
|
||||
// TODO: add ipv6
|
||||
destination = "0.0.0.0/0"
|
||||
}
|
||||
|
||||
return firewall.RouterPair{
|
||||
ID: string(route.ID),
|
||||
Source: parsed.String(),
|
||||
Destination: route.Network.Masked().String(),
|
||||
Source: source.String(),
|
||||
Destination: destination,
|
||||
Masquerade: route.Masquerade,
|
||||
}, nil
|
||||
}
|
||||
|
57
client/internal/routemanager/static/route.go
Normal file
57
client/internal/routemanager/static/route.go
Normal file
@ -0,0 +1,57 @@
|
||||
package static
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type Route struct {
|
||||
route *route.Route
|
||||
routeRefCounter *refcounter.RouteRefCounter
|
||||
allowedIPsRefcounter *refcounter.AllowedIPsRefCounter
|
||||
}
|
||||
|
||||
func NewRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *Route {
|
||||
return &Route{
|
||||
route: rt,
|
||||
routeRefCounter: routeRefCounter,
|
||||
allowedIPsRefcounter: allowedIPsRefCounter,
|
||||
}
|
||||
}
|
||||
|
||||
// Route route methods
|
||||
func (r *Route) String() string {
|
||||
return r.route.Network.String()
|
||||
}
|
||||
|
||||
func (r *Route) AddRoute(context.Context) error {
|
||||
_, err := r.routeRefCounter.Increment(r.route.Network, nil)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Route) RemoveRoute() error {
|
||||
_, err := r.routeRefCounter.Decrement(r.route.Network)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Route) AddAllowedIPs(peerKey string) error {
|
||||
if ref, err := r.allowedIPsRefcounter.Increment(r.route.Network, peerKey); err != nil {
|
||||
return fmt.Errorf("add allowed IP %s: %w", r.route.Network, err)
|
||||
} else if ref.Count > 1 && ref.Out != peerKey {
|
||||
log.Warnf("Prefix [%s] is already routed by peer [%s]. HA routing disabled",
|
||||
r.route.Network,
|
||||
ref.Out,
|
||||
)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Route) RemoveAllowedIPs() error {
|
||||
_, err := r.allowedIPsRefcounter.Decrement(r.route.Network)
|
||||
return err
|
||||
}
|
103
client/internal/routemanager/sysctl/sysctl_linux.go
Normal file
103
client/internal/routemanager/sysctl/sysctl_linux.go
Normal file
@ -0,0 +1,103 @@
|
||||
// go:build !android
|
||||
package sysctl
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
const (
|
||||
rpFilterPath = "net.ipv4.conf.all.rp_filter"
|
||||
rpFilterInterfacePath = "net.ipv4.conf.%s.rp_filter"
|
||||
srcValidMarkPath = "net.ipv4.conf.all.src_valid_mark"
|
||||
)
|
||||
|
||||
// Setup configures sysctl settings for RP filtering and source validation.
|
||||
func Setup(wgIface *iface.WGIface) (map[string]int, error) {
|
||||
keys := map[string]int{}
|
||||
var result *multierror.Error
|
||||
|
||||
oldVal, err := Set(srcValidMarkPath, 1, false)
|
||||
if err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
} else {
|
||||
keys[srcValidMarkPath] = oldVal
|
||||
}
|
||||
|
||||
oldVal, err = Set(rpFilterPath, 2, true)
|
||||
if err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
} else {
|
||||
keys[rpFilterPath] = oldVal
|
||||
}
|
||||
|
||||
interfaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("list interfaces: %w", err))
|
||||
}
|
||||
|
||||
for _, intf := range interfaces {
|
||||
if intf.Name == "lo" || wgIface != nil && intf.Name == wgIface.Name() {
|
||||
continue
|
||||
}
|
||||
|
||||
i := fmt.Sprintf(rpFilterInterfacePath, intf.Name)
|
||||
oldVal, err := Set(i, 2, true)
|
||||
if err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
} else {
|
||||
keys[i] = oldVal
|
||||
}
|
||||
}
|
||||
|
||||
return keys, nberrors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
// Set sets a sysctl configuration, if onlyIfOne is true it will only set the new value if it's set to 1
|
||||
func Set(key string, desiredValue int, onlyIfOne bool) (int, error) {
|
||||
path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/"))
|
||||
currentValue, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return -1, fmt.Errorf("read sysctl %s: %w", key, err)
|
||||
}
|
||||
|
||||
currentV, err := strconv.Atoi(strings.TrimSpace(string(currentValue)))
|
||||
if err != nil && len(currentValue) > 0 {
|
||||
return -1, fmt.Errorf("convert current desiredValue to int: %w", err)
|
||||
}
|
||||
|
||||
if currentV == desiredValue || onlyIfOne && currentV != 1 {
|
||||
return currentV, nil
|
||||
}
|
||||
|
||||
//nolint:gosec
|
||||
if err := os.WriteFile(path, []byte(strconv.Itoa(desiredValue)), 0644); err != nil {
|
||||
return currentV, fmt.Errorf("write sysctl %s: %w", key, err)
|
||||
}
|
||||
log.Debugf("Set sysctl %s from %d to %d", key, currentV, desiredValue)
|
||||
|
||||
return currentV, nil
|
||||
}
|
||||
|
||||
// Cleanup resets sysctl settings to their original values.
|
||||
func Cleanup(originalSettings map[string]int) error {
|
||||
var result *multierror.Error
|
||||
|
||||
for key, value := range originalSettings {
|
||||
_, err := Set(key, value, false)
|
||||
if err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
}
|
18
client/internal/routemanager/systemops/routeflags_bsd.go
Normal file
18
client/internal/routemanager/systemops/routeflags_bsd.go
Normal file
@ -0,0 +1,18 @@
|
||||
//go:build darwin || dragonfly || netbsd || openbsd
|
||||
|
||||
package systemops
|
||||
|
||||
import "syscall"
|
||||
|
||||
// filterRoutesByFlags - return true if need to ignore such route message because it consists specific flags.
|
||||
func filterRoutesByFlags(routeMessageFlags int) bool {
|
||||
if routeMessageFlags&syscall.RTF_UP == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE|syscall.RTF_WASCLONED) != 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
19
client/internal/routemanager/systemops/routeflags_freebsd.go
Normal file
19
client/internal/routemanager/systemops/routeflags_freebsd.go
Normal file
@ -0,0 +1,19 @@
|
||||
//go:build: freebsd
|
||||
package systemops
|
||||
|
||||
import "syscall"
|
||||
|
||||
// filterRoutesByFlags - return true if need to ignore such route message because it consists specific flags.
|
||||
func filterRoutesByFlags(routeMessageFlags int) bool {
|
||||
if routeMessageFlags&syscall.RTF_UP == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// NOTE: syscall.RTF_WASCLONED deprecated in FreeBSD 8.0 (https://www.freebsd.org/releases/8.0R/relnotes-detailed/)
|
||||
// a concept of cloned route (a route generated by an entry with RTF_CLONING flag) is deprecated.
|
||||
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE) != 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
27
client/internal/routemanager/systemops/systemops.go
Normal file
27
client/internal/routemanager/systemops/systemops.go
Normal file
@ -0,0 +1,27 @@
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
type Nexthop struct {
|
||||
IP netip.Addr
|
||||
Intf *net.Interface
|
||||
}
|
||||
|
||||
type ExclusionCounter = refcounter.Counter[any, Nexthop]
|
||||
|
||||
type SysOps struct {
|
||||
refCounter *ExclusionCounter
|
||||
wgInterface *iface.WGIface
|
||||
}
|
||||
|
||||
func NewSysOps(wgInterface *iface.WGIface) *SysOps {
|
||||
return &SysOps{
|
||||
wgInterface: wgInterface,
|
||||
}
|
||||
}
|
@ -1,6 +1,6 @@
|
||||
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
|
||||
|
||||
package routemanager
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"errors"
|
||||
@ -43,8 +43,7 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
|
||||
return nil, fmt.Errorf("unexpected RIB message type: %d", m.Type)
|
||||
}
|
||||
|
||||
if m.Flags&syscall.RTF_UP == 0 ||
|
||||
m.Flags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE|syscall.RTF_WASCLONED) != 0 {
|
||||
if filterRoutesByFlags(m.Flags) {
|
||||
continue
|
||||
}
|
||||
|
||||
@ -101,6 +100,7 @@ func toNetIP(a route.Addr) netip.Addr {
|
||||
}
|
||||
}
|
||||
|
||||
// ones returns the number of leading ones in the mask.
|
||||
func ones(a route.Addr) (int, error) {
|
||||
switch t := a.(type) {
|
||||
case *route.Inet4Addr:
|
||||
@ -114,6 +114,7 @@ func ones(a route.Addr) (int, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// MsgToRoute converts a route message to a Route.
|
||||
func MsgToRoute(msg *route.RouteMessage) (*Route, error) {
|
||||
dstIP, nexthop, dstMask := msg.Addrs[0], msg.Addrs[1], msg.Addrs[2]
|
||||
|
@ -1,6 +1,6 @@
|
||||
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
|
||||
|
||||
package routemanager
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"testing"
|
@ -1,6 +1,6 @@
|
||||
//go:build !ios
|
||||
|
||||
package routemanager
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@ -35,13 +35,15 @@ func TestConcurrentRoutes(t *testing.T) {
|
||||
baseIP := netip.MustParseAddr("192.0.2.0")
|
||||
intf := &net.Interface{Name: "lo0"}
|
||||
|
||||
r := NewSysOps(nil)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 1024; i++ {
|
||||
wg.Add(1)
|
||||
go func(ip netip.Addr) {
|
||||
defer wg.Done()
|
||||
prefix := netip.PrefixFrom(ip, 32)
|
||||
if err := addToRouteTable(prefix, netip.Addr{}, intf); err != nil {
|
||||
if err := r.addToRouteTable(prefix, Nexthop{netip.Addr{}, intf}); err != nil {
|
||||
t.Errorf("Failed to add route for %s: %v", prefix, err)
|
||||
}
|
||||
}(baseIP)
|
||||
@ -57,7 +59,7 @@ func TestConcurrentRoutes(t *testing.T) {
|
||||
go func(ip netip.Addr) {
|
||||
defer wg.Done()
|
||||
prefix := netip.PrefixFrom(ip, 32)
|
||||
if err := removeFromRouteTable(prefix, netip.Addr{}, intf); err != nil {
|
||||
if err := r.removeFromRouteTable(prefix, Nexthop{netip.Addr{}, intf}); err != nil {
|
||||
t.Errorf("Failed to remove route for %s: %v", prefix, err)
|
||||
}
|
||||
}(baseIP)
|
@ -1,6 +1,6 @@
|
||||
//go:build !android && !ios
|
||||
|
||||
package routemanager
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"context"
|
||||
@ -15,7 +15,11 @@ import (
|
||||
"github.com/libp2p/go-netroute"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
@ -25,29 +29,75 @@ var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
|
||||
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
|
||||
var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
|
||||
|
||||
var ErrRouteNotFound = errors.New("route not found")
|
||||
var ErrRouteNotAllowed = errors.New("route not allowed")
|
||||
func (r *SysOps) setupRefCounter(initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
|
||||
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
|
||||
log.Errorf("Unable to get initial v4 default next hop: %v", err)
|
||||
}
|
||||
initialNextHopV6, err := GetNextHop(netip.IPv6Unspecified())
|
||||
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
|
||||
log.Errorf("Unable to get initial v6 default next hop: %v", err)
|
||||
}
|
||||
|
||||
refCounter := refcounter.New(
|
||||
func(prefix netip.Prefix, _ any) (Nexthop, error) {
|
||||
initialNexthop := initialNextHopV4
|
||||
if prefix.Addr().Is6() {
|
||||
initialNexthop = initialNextHopV6
|
||||
}
|
||||
|
||||
nexthop, err := r.addRouteToNonVPNIntf(prefix, r.wgInterface, initialNexthop)
|
||||
if errors.Is(err, vars.ErrRouteNotAllowed) || errors.Is(err, vars.ErrRouteNotFound) {
|
||||
log.Tracef("Adding for prefix %s: %v", prefix, err)
|
||||
// These errors are not critical but also we should not track and try to remove the routes either.
|
||||
return nexthop, refcounter.ErrIgnore
|
||||
}
|
||||
return nexthop, err
|
||||
},
|
||||
r.removeFromRouteTable,
|
||||
)
|
||||
|
||||
r.refCounter = refCounter
|
||||
|
||||
return r.setupHooks(initAddresses)
|
||||
}
|
||||
|
||||
func (r *SysOps) cleanupRefCounter() error {
|
||||
if r.refCounter == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: Remove hooks selectively
|
||||
nbnet.RemoveDialerHooks()
|
||||
nbnet.RemoveListenerHooks()
|
||||
|
||||
if err := r.refCounter.Flush(); err != nil {
|
||||
return fmt.Errorf("flush route manager: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: fix: for default our wg address now appears as the default gw
|
||||
func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
|
||||
func (r *SysOps) addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
|
||||
addr := netip.IPv4Unspecified()
|
||||
if prefix.Addr().Is6() {
|
||||
addr = netip.IPv6Unspecified()
|
||||
}
|
||||
|
||||
defaultGateway, _, err := GetNextHop(addr)
|
||||
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
||||
nexthop, err := GetNextHop(addr)
|
||||
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
|
||||
return fmt.Errorf("get existing route gateway: %s", err)
|
||||
}
|
||||
|
||||
if !prefix.Contains(defaultGateway) {
|
||||
log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", defaultGateway, prefix)
|
||||
if !prefix.Contains(nexthop.IP) {
|
||||
log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", nexthop.IP, prefix)
|
||||
return nil
|
||||
}
|
||||
|
||||
gatewayPrefix := netip.PrefixFrom(defaultGateway, 32)
|
||||
if defaultGateway.Is6() {
|
||||
gatewayPrefix = netip.PrefixFrom(defaultGateway, 128)
|
||||
gatewayPrefix := netip.PrefixFrom(nexthop.IP, 32)
|
||||
if nexthop.IP.Is6() {
|
||||
gatewayPrefix = netip.PrefixFrom(nexthop.IP, 128)
|
||||
}
|
||||
|
||||
ok, err := existsInRouteTable(gatewayPrefix)
|
||||
@ -60,46 +110,264 @@ func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
gatewayHop, intf, err := GetNextHop(defaultGateway)
|
||||
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
||||
nexthop, err = GetNextHop(nexthop.IP)
|
||||
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
|
||||
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
|
||||
}
|
||||
|
||||
log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop)
|
||||
return addToRouteTable(gatewayPrefix, gatewayHop, intf)
|
||||
log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, nexthop.IP)
|
||||
return r.addToRouteTable(gatewayPrefix, nexthop)
|
||||
}
|
||||
|
||||
func GetNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) {
|
||||
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
|
||||
// If the next hop or interface is pointing to the VPN interface, it will return the initial values.
|
||||
func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIface, initialNextHop Nexthop) (Nexthop, error) {
|
||||
addr := prefix.Addr()
|
||||
switch {
|
||||
case addr.IsLoopback(),
|
||||
addr.IsLinkLocalUnicast(),
|
||||
addr.IsLinkLocalMulticast(),
|
||||
addr.IsInterfaceLocalMulticast(),
|
||||
addr.IsUnspecified(),
|
||||
addr.IsMulticast():
|
||||
|
||||
return Nexthop{}, vars.ErrRouteNotAllowed
|
||||
}
|
||||
|
||||
// Determine the exit interface and next hop for the prefix, so we can add a specific route
|
||||
nexthop, err := GetNextHop(addr)
|
||||
if err != nil {
|
||||
return Nexthop{}, fmt.Errorf("get next hop: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop.IP, prefix, nexthop.IP)
|
||||
exitNextHop := Nexthop{
|
||||
IP: nexthop.IP,
|
||||
Intf: nexthop.Intf,
|
||||
}
|
||||
|
||||
vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP)
|
||||
if !ok {
|
||||
return Nexthop{}, fmt.Errorf("failed to convert vpn address to netip.Addr")
|
||||
}
|
||||
|
||||
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
|
||||
if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() {
|
||||
log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix)
|
||||
exitNextHop = initialNextHop
|
||||
}
|
||||
|
||||
log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop.IP)
|
||||
if err := r.addToRouteTable(prefix, exitNextHop); err != nil {
|
||||
return Nexthop{}, fmt.Errorf("add route to table: %w", err)
|
||||
}
|
||||
|
||||
return exitNextHop, nil
|
||||
}
|
||||
|
||||
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
|
||||
// in two /1 prefixes to avoid replacing the existing default route
|
||||
func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
nextHop := Nexthop{netip.Addr{}, intf}
|
||||
|
||||
if prefix == vars.Defaultv4 {
|
||||
if err := r.addToRouteTable(splitDefaultv4_1, nextHop); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := r.addToRouteTable(splitDefaultv4_2, nextHop); err != nil {
|
||||
if err2 := r.removeFromRouteTable(splitDefaultv4_1, nextHop); err2 != nil {
|
||||
log.Warnf("Failed to rollback route addition: %s", err2)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: remove once IPv6 is supported on the interface
|
||||
if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil {
|
||||
return fmt.Errorf("add unreachable route split 1: %w", err)
|
||||
}
|
||||
if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil {
|
||||
if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil {
|
||||
log.Warnf("Failed to rollback route addition: %s", err2)
|
||||
}
|
||||
return fmt.Errorf("add unreachable route split 2: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
} else if prefix == vars.Defaultv6 {
|
||||
if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil {
|
||||
return fmt.Errorf("add unreachable route split 1: %w", err)
|
||||
}
|
||||
if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil {
|
||||
if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil {
|
||||
log.Warnf("Failed to rollback route addition: %s", err2)
|
||||
}
|
||||
return fmt.Errorf("add unreachable route split 2: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return r.addNonExistingRoute(prefix, intf)
|
||||
}
|
||||
|
||||
// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table
|
||||
func (r *SysOps) addNonExistingRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
ok, err := existsInRouteTable(prefix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("exists in route table: %w", err)
|
||||
}
|
||||
if ok {
|
||||
log.Warnf("Skipping adding a new route for network %s because it already exists", prefix)
|
||||
return nil
|
||||
}
|
||||
|
||||
ok, err = isSubRange(prefix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sub range: %w", err)
|
||||
}
|
||||
|
||||
if ok {
|
||||
if err := r.addRouteForCurrentDefaultGateway(prefix); err != nil {
|
||||
log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return r.addToRouteTable(prefix, Nexthop{netip.Addr{}, intf})
|
||||
}
|
||||
|
||||
// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given,
|
||||
// it will remove the split /1 prefixes
|
||||
func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
nextHop := Nexthop{netip.Addr{}, intf}
|
||||
|
||||
if prefix == vars.Defaultv4 {
|
||||
var result *multierror.Error
|
||||
if err := r.removeFromRouteTable(splitDefaultv4_1, nextHop); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
if err := r.removeFromRouteTable(splitDefaultv4_2, nextHop); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
|
||||
// TODO: remove once IPv6 is supported on the interface
|
||||
if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
} else if prefix == vars.Defaultv6 {
|
||||
var result *multierror.Error
|
||||
if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
return r.removeFromRouteTable(prefix, nextHop)
|
||||
}
|
||||
|
||||
func (r *SysOps) setupHooks(initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
|
||||
prefix, err := util.GetPrefixFromIP(ip)
|
||||
if err != nil {
|
||||
return fmt.Errorf("convert ip to prefix: %w", err)
|
||||
}
|
||||
|
||||
if _, err := r.refCounter.IncrementWithID(string(connID), prefix, nil); err != nil {
|
||||
return fmt.Errorf("adding route reference: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
afterHook := func(connID nbnet.ConnectionID) error {
|
||||
if err := r.refCounter.DecrementWithID(string(connID)); err != nil {
|
||||
return fmt.Errorf("remove route reference: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, ip := range initAddresses {
|
||||
if err := beforeHook("init", ip); err != nil {
|
||||
log.Errorf("Failed to add route reference: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
var result *multierror.Error
|
||||
for _, ip := range resolvedIPs {
|
||||
result = multierror.Append(result, beforeHook(connID, ip.IP))
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
})
|
||||
|
||||
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
|
||||
return afterHook(connID)
|
||||
})
|
||||
|
||||
nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error {
|
||||
return beforeHook(connID, ip.IP)
|
||||
})
|
||||
|
||||
nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error {
|
||||
return afterHook(connID)
|
||||
})
|
||||
|
||||
return beforeHook, afterHook, nil
|
||||
}
|
||||
|
||||
func GetNextHop(ip netip.Addr) (Nexthop, error) {
|
||||
r, err := netroute.New()
|
||||
if err != nil {
|
||||
return netip.Addr{}, nil, fmt.Errorf("new netroute: %w", err)
|
||||
return Nexthop{}, fmt.Errorf("new netroute: %w", err)
|
||||
}
|
||||
intf, gateway, preferredSrc, err := r.Route(ip.AsSlice())
|
||||
if err != nil {
|
||||
log.Debugf("Failed to get route for %s: %v", ip, err)
|
||||
return netip.Addr{}, nil, ErrRouteNotFound
|
||||
return Nexthop{}, vars.ErrRouteNotFound
|
||||
}
|
||||
|
||||
log.Debugf("Route for %s: interface %v nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc)
|
||||
if gateway == nil {
|
||||
if preferredSrc == nil {
|
||||
return netip.Addr{}, nil, ErrRouteNotFound
|
||||
if runtime.GOOS == "freebsd" {
|
||||
return Nexthop{Intf: intf}, nil
|
||||
}
|
||||
log.Debugf("No next hop found for ip %s, using preferred source %s", ip, preferredSrc)
|
||||
|
||||
if preferredSrc == nil {
|
||||
return Nexthop{}, vars.ErrRouteNotFound
|
||||
}
|
||||
log.Debugf("No next hop found for IP %s, using preferred source %s", ip, preferredSrc)
|
||||
|
||||
addr, err := ipToAddr(preferredSrc, intf)
|
||||
if err != nil {
|
||||
return netip.Addr{}, nil, fmt.Errorf("convert preferred source to address: %w", err)
|
||||
return Nexthop{}, fmt.Errorf("convert preferred source to address: %w", err)
|
||||
}
|
||||
return addr.Unmap(), intf, nil
|
||||
return Nexthop{
|
||||
IP: addr.Unmap(),
|
||||
Intf: intf,
|
||||
}, nil
|
||||
}
|
||||
|
||||
addr, err := ipToAddr(gateway, intf)
|
||||
if err != nil {
|
||||
return netip.Addr{}, nil, fmt.Errorf("convert gateway to address: %w", err)
|
||||
return Nexthop{}, fmt.Errorf("convert gateway to address: %w", err)
|
||||
}
|
||||
|
||||
return addr, intf, nil
|
||||
return Nexthop{
|
||||
IP: addr,
|
||||
Intf: intf,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// converts a net.IP to a netip.Addr including the zone based on the passed interface
|
||||
@ -140,275 +408,9 @@ func isSubRange(prefix netip.Prefix) (bool, error) {
|
||||
return false, fmt.Errorf("get routes from table: %w", err)
|
||||
}
|
||||
for _, tableRoute := range routes {
|
||||
if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() {
|
||||
if tableRoute.Bits() > vars.MinRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
|
||||
// If the next hop or interface is pointing to the VPN interface, it will return the initial values.
|
||||
func addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIface, initialNextHop netip.Addr, initialIntf *net.Interface) (netip.Addr, *net.Interface, error) {
|
||||
addr := prefix.Addr()
|
||||
switch {
|
||||
case addr.IsLoopback(),
|
||||
addr.IsLinkLocalUnicast(),
|
||||
addr.IsLinkLocalMulticast(),
|
||||
addr.IsInterfaceLocalMulticast(),
|
||||
addr.IsUnspecified(),
|
||||
addr.IsMulticast():
|
||||
|
||||
return netip.Addr{}, nil, ErrRouteNotAllowed
|
||||
}
|
||||
|
||||
// Determine the exit interface and next hop for the prefix, so we can add a specific route
|
||||
nexthop, intf, err := GetNextHop(addr)
|
||||
if err != nil {
|
||||
return netip.Addr{}, nil, fmt.Errorf("get next hop: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop, prefix, intf)
|
||||
exitNextHop := nexthop
|
||||
exitIntf := intf
|
||||
|
||||
vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP)
|
||||
if !ok {
|
||||
return netip.Addr{}, nil, fmt.Errorf("failed to convert vpn address to netip.Addr")
|
||||
}
|
||||
|
||||
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
|
||||
if exitNextHop == vpnAddr || exitIntf != nil && exitIntf.Name == vpnIntf.Name() {
|
||||
log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix)
|
||||
exitNextHop = initialNextHop
|
||||
exitIntf = initialIntf
|
||||
}
|
||||
|
||||
log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop)
|
||||
if err := addToRouteTable(prefix, exitNextHop, exitIntf); err != nil {
|
||||
return netip.Addr{}, nil, fmt.Errorf("add route to table: %w", err)
|
||||
}
|
||||
|
||||
return exitNextHop, exitIntf, nil
|
||||
}
|
||||
|
||||
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
|
||||
// in two /1 prefixes to avoid replacing the existing default route
|
||||
func genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
if prefix == defaultv4 {
|
||||
if err := addToRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := addToRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil {
|
||||
if err2 := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err2 != nil {
|
||||
log.Warnf("Failed to rollback route addition: %s", err2)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: remove once IPv6 is supported on the interface
|
||||
if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
|
||||
return fmt.Errorf("add unreachable route split 1: %w", err)
|
||||
}
|
||||
if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
|
||||
if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil {
|
||||
log.Warnf("Failed to rollback route addition: %s", err2)
|
||||
}
|
||||
return fmt.Errorf("add unreachable route split 2: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
} else if prefix == defaultv6 {
|
||||
if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
|
||||
return fmt.Errorf("add unreachable route split 1: %w", err)
|
||||
}
|
||||
if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
|
||||
if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil {
|
||||
log.Warnf("Failed to rollback route addition: %s", err2)
|
||||
}
|
||||
return fmt.Errorf("add unreachable route split 2: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return addNonExistingRoute(prefix, intf)
|
||||
}
|
||||
|
||||
// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table
|
||||
func addNonExistingRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
ok, err := existsInRouteTable(prefix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("exists in route table: %w", err)
|
||||
}
|
||||
if ok {
|
||||
log.Warnf("Skipping adding a new route for network %s because it already exists", prefix)
|
||||
return nil
|
||||
}
|
||||
|
||||
ok, err = isSubRange(prefix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sub range: %w", err)
|
||||
}
|
||||
|
||||
if ok {
|
||||
err := addRouteForCurrentDefaultGateway(prefix)
|
||||
if err != nil {
|
||||
log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return addToRouteTable(prefix, netip.Addr{}, intf)
|
||||
}
|
||||
|
||||
// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given,
|
||||
// it will remove the split /1 prefixes
|
||||
func genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
if prefix == defaultv4 {
|
||||
var result *multierror.Error
|
||||
if err := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
if err := removeFromRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
|
||||
// TODO: remove once IPv6 is supported on the interface
|
||||
if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
|
||||
return result.ErrorOrNil()
|
||||
} else if prefix == defaultv6 {
|
||||
var result *multierror.Error
|
||||
if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
|
||||
return result.ErrorOrNil()
|
||||
}
|
||||
|
||||
return removeFromRouteTable(prefix, netip.Addr{}, intf)
|
||||
}
|
||||
|
||||
func getPrefixFromIP(ip net.IP) (*netip.Prefix, error) {
|
||||
addr, ok := netip.AddrFromSlice(ip)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("parse IP address: %s", ip)
|
||||
}
|
||||
addr = addr.Unmap()
|
||||
|
||||
var prefixLength int
|
||||
switch {
|
||||
case addr.Is4():
|
||||
prefixLength = 32
|
||||
case addr.Is6():
|
||||
prefixLength = 128
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid IP address: %s", addr)
|
||||
}
|
||||
|
||||
prefix := netip.PrefixFrom(addr, prefixLength)
|
||||
return &prefix, nil
|
||||
}
|
||||
|
||||
func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||
initialNextHopV4, initialIntfV4, err := GetNextHop(netip.IPv4Unspecified())
|
||||
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
||||
log.Errorf("Unable to get initial v4 default next hop: %v", err)
|
||||
}
|
||||
initialNextHopV6, initialIntfV6, err := GetNextHop(netip.IPv6Unspecified())
|
||||
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
||||
log.Errorf("Unable to get initial v6 default next hop: %v", err)
|
||||
}
|
||||
|
||||
*routeManager = NewRouteManager(
|
||||
func(prefix netip.Prefix) (netip.Addr, *net.Interface, error) {
|
||||
addr := prefix.Addr()
|
||||
nexthop, intf := initialNextHopV4, initialIntfV4
|
||||
if addr.Is6() {
|
||||
nexthop, intf = initialNextHopV6, initialIntfV6
|
||||
}
|
||||
return addRouteToNonVPNIntf(prefix, wgIface, nexthop, intf)
|
||||
},
|
||||
removeFromRouteTable,
|
||||
)
|
||||
|
||||
return setupHooks(*routeManager, initAddresses)
|
||||
}
|
||||
|
||||
func cleanupRoutingWithRouteManager(routeManager *RouteManager) error {
|
||||
if routeManager == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: Remove hooks selectively
|
||||
nbnet.RemoveDialerHooks()
|
||||
nbnet.RemoveListenerHooks()
|
||||
|
||||
if err := routeManager.Flush(); err != nil {
|
||||
return fmt.Errorf("flush route manager: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupHooks(routeManager *RouteManager, initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
|
||||
prefix, err := getPrefixFromIP(ip)
|
||||
if err != nil {
|
||||
return fmt.Errorf("convert ip to prefix: %w", err)
|
||||
}
|
||||
|
||||
if err := routeManager.AddRouteRef(connID, *prefix); err != nil {
|
||||
return fmt.Errorf("adding route reference: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
afterHook := func(connID nbnet.ConnectionID) error {
|
||||
if err := routeManager.RemoveRouteRef(connID); err != nil {
|
||||
return fmt.Errorf("remove route reference: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, ip := range initAddresses {
|
||||
if err := beforeHook("init", ip); err != nil {
|
||||
log.Errorf("Failed to add route reference: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
var result *multierror.Error
|
||||
for _, ip := range resolvedIPs {
|
||||
result = multierror.Append(result, beforeHook(connID, ip.IP))
|
||||
}
|
||||
return result.ErrorOrNil()
|
||||
})
|
||||
|
||||
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
|
||||
return afterHook(connID)
|
||||
})
|
||||
|
||||
nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error {
|
||||
return beforeHook(connID, ip.IP)
|
||||
})
|
||||
|
||||
nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error {
|
||||
return afterHook(connID)
|
||||
})
|
||||
|
||||
return beforeHook, afterHook, nil
|
||||
}
|
@ -1,6 +1,6 @@
|
||||
//go:build !android && !ios
|
||||
|
||||
package routemanager
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@ -63,17 +63,20 @@ func TestAddRemoveRoutes(t *testing.T) {
|
||||
|
||||
err = wgInterface.Create()
|
||||
require.NoError(t, err, "should create testing wireguard interface")
|
||||
_, _, err = setupRouting(nil, wgInterface)
|
||||
|
||||
r := NewSysOps(wgInterface)
|
||||
|
||||
_, _, err = r.SetupRouting(nil)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, cleanupRouting())
|
||||
assert.NoError(t, r.CleanupRouting())
|
||||
})
|
||||
|
||||
index, err := net.InterfaceByName(wgInterface.Name())
|
||||
require.NoError(t, err, "InterfaceByName should not return err")
|
||||
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
|
||||
|
||||
err = addVPNRoute(testCase.prefix, intf)
|
||||
err = r.AddVPNRoute(testCase.prefix, intf)
|
||||
require.NoError(t, err, "genericAddVPNRoute should not return err")
|
||||
|
||||
if testCase.shouldRouteToWireguard {
|
||||
@ -84,19 +87,19 @@ func TestAddRemoveRoutes(t *testing.T) {
|
||||
exists, err := existsInRouteTable(testCase.prefix)
|
||||
require.NoError(t, err, "existsInRouteTable should not return err")
|
||||
if exists && testCase.shouldRouteToWireguard {
|
||||
err = removeVPNRoute(testCase.prefix, intf)
|
||||
err = r.RemoveVPNRoute(testCase.prefix, intf)
|
||||
require.NoError(t, err, "genericRemoveVPNRoute should not return err")
|
||||
|
||||
prefixGateway, _, err := GetNextHop(testCase.prefix.Addr())
|
||||
prefixNexthop, err := GetNextHop(testCase.prefix.Addr())
|
||||
require.NoError(t, err, "GetNextHop should not return err")
|
||||
|
||||
internetGateway, _, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
|
||||
internetNexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
|
||||
require.NoError(t, err)
|
||||
|
||||
if testCase.shouldBeRemoved {
|
||||
require.Equal(t, internetGateway, prefixGateway, "route should be pointing to default internet gateway")
|
||||
require.Equal(t, internetNexthop.IP, prefixNexthop.IP, "route should be pointing to default internet gateway")
|
||||
} else {
|
||||
require.NotEqual(t, internetGateway, prefixGateway, "route should be pointing to a different gateway than the internet gateway")
|
||||
require.NotEqual(t, internetNexthop.IP, prefixNexthop.IP, "route should be pointing to a different gateway than the internet gateway")
|
||||
}
|
||||
}
|
||||
})
|
||||
@ -104,11 +107,11 @@ func TestAddRemoveRoutes(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGetNextHop(t *testing.T) {
|
||||
gateway, _, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
|
||||
nexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when fetching the gateway: ", err)
|
||||
}
|
||||
if !gateway.IsValid() {
|
||||
if !nexthop.IP.IsValid() {
|
||||
t.Fatal("should return a gateway")
|
||||
}
|
||||
addresses, err := net.InterfaceAddrs()
|
||||
@ -130,24 +133,24 @@ func TestGetNextHop(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
localIP, _, err := GetNextHop(testingPrefix.Addr())
|
||||
localIP, err := GetNextHop(testingPrefix.Addr())
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error: ", err)
|
||||
}
|
||||
if !localIP.IsValid() {
|
||||
if !localIP.IP.IsValid() {
|
||||
t.Fatal("should return a gateway for local network")
|
||||
}
|
||||
if localIP.String() == gateway.String() {
|
||||
t.Fatal("local ip should not match with gateway IP")
|
||||
if localIP.IP.String() == nexthop.IP.String() {
|
||||
t.Fatal("local IP should not match with gateway IP")
|
||||
}
|
||||
if localIP.String() != testingIP {
|
||||
t.Fatalf("local ip should match with testing IP: want %s got %s", testingIP, localIP.String())
|
||||
if localIP.IP.String() != testingIP {
|
||||
t.Fatalf("local IP should match with testing IP: want %s got %s", testingIP, localIP.IP.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddExistAndRemoveRoute(t *testing.T) {
|
||||
defaultGateway, _, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
|
||||
t.Log("defaultGateway: ", defaultGateway)
|
||||
defaultNexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
|
||||
t.Log("defaultNexthop: ", defaultNexthop)
|
||||
if err != nil {
|
||||
t.Fatal("shouldn't return error when fetching the gateway: ", err)
|
||||
}
|
||||
@ -164,7 +167,7 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "Should Not Add Route if overlaps with default gateway",
|
||||
prefix: netip.MustParsePrefix(defaultGateway.String() + "/31"),
|
||||
prefix: netip.MustParsePrefix(defaultNexthop.IP.String() + "/31"),
|
||||
shouldAddRoute: false,
|
||||
},
|
||||
{
|
||||
@ -214,14 +217,16 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
|
||||
require.NoError(t, err, "InterfaceByName should not return err")
|
||||
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
|
||||
|
||||
r := NewSysOps(wgInterface)
|
||||
|
||||
// Prepare the environment
|
||||
if testCase.preExistingPrefix.IsValid() {
|
||||
err := addVPNRoute(testCase.preExistingPrefix, intf)
|
||||
err := r.AddVPNRoute(testCase.preExistingPrefix, intf)
|
||||
require.NoError(t, err, "should not return err when adding pre-existing route")
|
||||
}
|
||||
|
||||
// Add the route
|
||||
err = addVPNRoute(testCase.prefix, intf)
|
||||
err = r.AddVPNRoute(testCase.prefix, intf)
|
||||
require.NoError(t, err, "should not return err when adding route")
|
||||
|
||||
if testCase.shouldAddRoute {
|
||||
@ -231,7 +236,7 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
|
||||
require.True(t, ok, "route should exist")
|
||||
|
||||
// remove route again if added
|
||||
err = removeVPNRoute(testCase.prefix, intf)
|
||||
err = r.RemoveVPNRoute(testCase.prefix, intf)
|
||||
require.NoError(t, err, "should not return err")
|
||||
}
|
||||
|
||||
@ -343,65 +348,52 @@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen
|
||||
return wgInterface
|
||||
}
|
||||
|
||||
func setupRouteAndCleanup(t *testing.T, r *SysOps, prefix netip.Prefix, intf *net.Interface) {
|
||||
t.Helper()
|
||||
|
||||
err := r.AddVPNRoute(prefix, intf)
|
||||
require.NoError(t, err, "addVPNRoute should not return err")
|
||||
t.Cleanup(func() {
|
||||
err = r.RemoveVPNRoute(prefix, intf)
|
||||
assert.NoError(t, err, "removeVPNRoute should not return err")
|
||||
})
|
||||
}
|
||||
|
||||
func setupTestEnv(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
setupDummyInterfacesAndRoutes(t)
|
||||
|
||||
wgIface := createWGInterface(t, expectedVPNint, "100.64.0.1/24", 51820)
|
||||
wgInterface := createWGInterface(t, expectedVPNint, "100.64.0.1/24", 51820)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, wgIface.Close())
|
||||
assert.NoError(t, wgInterface.Close())
|
||||
})
|
||||
|
||||
_, _, err := setupRouting(nil, wgIface)
|
||||
r := NewSysOps(wgInterface)
|
||||
_, _, err := r.SetupRouting(nil)
|
||||
require.NoError(t, err, "setupRouting should not return err")
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, cleanupRouting())
|
||||
assert.NoError(t, r.CleanupRouting())
|
||||
})
|
||||
|
||||
index, err := net.InterfaceByName(wgIface.Name())
|
||||
index, err := net.InterfaceByName(wgInterface.Name())
|
||||
require.NoError(t, err, "InterfaceByName should not return err")
|
||||
intf := &net.Interface{Index: index.Index, Name: wgIface.Name()}
|
||||
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
|
||||
|
||||
// default route exists in main table and vpn table
|
||||
err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), intf)
|
||||
require.NoError(t, err, "addVPNRoute should not return err")
|
||||
t.Cleanup(func() {
|
||||
err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), intf)
|
||||
assert.NoError(t, err, "removeVPNRoute should not return err")
|
||||
})
|
||||
setupRouteAndCleanup(t, r, netip.MustParsePrefix("0.0.0.0/0"), intf)
|
||||
|
||||
// 10.0.0.0/8 route exists in main table and vpn table
|
||||
err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), intf)
|
||||
require.NoError(t, err, "addVPNRoute should not return err")
|
||||
t.Cleanup(func() {
|
||||
err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), intf)
|
||||
assert.NoError(t, err, "removeVPNRoute should not return err")
|
||||
})
|
||||
setupRouteAndCleanup(t, r, netip.MustParsePrefix("10.0.0.0/8"), intf)
|
||||
|
||||
// 10.10.0.0/24 more specific route exists in vpn table
|
||||
err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), intf)
|
||||
require.NoError(t, err, "addVPNRoute should not return err")
|
||||
t.Cleanup(func() {
|
||||
err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), intf)
|
||||
assert.NoError(t, err, "removeVPNRoute should not return err")
|
||||
})
|
||||
setupRouteAndCleanup(t, r, netip.MustParsePrefix("10.10.0.0/24"), intf)
|
||||
|
||||
// 127.0.10.0/24 more specific route exists in vpn table
|
||||
err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), intf)
|
||||
require.NoError(t, err, "addVPNRoute should not return err")
|
||||
t.Cleanup(func() {
|
||||
err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), intf)
|
||||
assert.NoError(t, err, "removeVPNRoute should not return err")
|
||||
})
|
||||
setupRouteAndCleanup(t, r, netip.MustParsePrefix("127.0.10.0/24"), intf)
|
||||
|
||||
// unique route in vpn table
|
||||
err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), intf)
|
||||
require.NoError(t, err, "addVPNRoute should not return err")
|
||||
t.Cleanup(func() {
|
||||
err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), intf)
|
||||
assert.NoError(t, err, "removeVPNRoute should not return err")
|
||||
})
|
||||
setupRouteAndCleanup(t, r, netip.MustParsePrefix("172.16.0.0/12"), intf)
|
||||
}
|
||||
|
||||
func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) {
|
||||
@ -410,11 +402,11 @@ func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIf
|
||||
return
|
||||
}
|
||||
|
||||
prefixGateway, _, err := GetNextHop(prefix.Addr())
|
||||
prefixNexthop, err := GetNextHop(prefix.Addr())
|
||||
require.NoError(t, err, "GetNextHop should not return err")
|
||||
if invert {
|
||||
assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP")
|
||||
assert.NotEqual(t, wgIface.Address().IP.String(), prefixNexthop.IP.String(), "route should not point to wireguard interface IP")
|
||||
} else {
|
||||
assert.Equal(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP")
|
||||
assert.Equal(t, wgIface.Address().IP.String(), prefixNexthop.IP.String(), "route should point to wireguard interface IP")
|
||||
}
|
||||
}
|
@ -1,6 +1,6 @@
|
||||
//go:build !android
|
||||
|
||||
package routemanager
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
@ -9,16 +9,16 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/vishvananda/netlink"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/sysctl"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
@ -33,16 +33,10 @@ const (
|
||||
|
||||
// ipv4ForwardingPath is the path to the file containing the IP forwarding setting.
|
||||
ipv4ForwardingPath = "net.ipv4.ip_forward"
|
||||
|
||||
rpFilterPath = "net.ipv4.conf.all.rp_filter"
|
||||
rpFilterInterfacePath = "net.ipv4.conf.%s.rp_filter"
|
||||
srcValidMarkPath = "net.ipv4.conf.all.src_valid_mark"
|
||||
)
|
||||
|
||||
var ErrTableIDExists = errors.New("ID exists with different name")
|
||||
|
||||
var routeManager = &RouteManager{}
|
||||
|
||||
// originalSysctl stores the original sysctl values before they are modified
|
||||
var originalSysctl map[string]int
|
||||
|
||||
@ -82,7 +76,7 @@ func getSetupRules() []ruleParams {
|
||||
}
|
||||
}
|
||||
|
||||
// setupRouting establishes the routing configuration for the VPN, including essential rules
|
||||
// SetupRouting establishes the routing configuration for the VPN, including essential rules
|
||||
// to ensure proper traffic flow for management, locally configured routes, and VPN traffic.
|
||||
//
|
||||
// Rule 1 (Main Route Precedence): Safeguards locally installed routes by giving them precedence over
|
||||
@ -92,17 +86,17 @@ func getSetupRules() []ruleParams {
|
||||
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
|
||||
// This table is where a default route or other specific routes received from the management server are configured,
|
||||
// enabling VPN connectivity.
|
||||
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) {
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) {
|
||||
if isLegacy() {
|
||||
log.Infof("Using legacy routing setup")
|
||||
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
|
||||
return r.setupRefCounter(initAddresses)
|
||||
}
|
||||
|
||||
if err = addRoutingTableName(); err != nil {
|
||||
log.Errorf("Error adding routing table name: %v", err)
|
||||
}
|
||||
|
||||
originalValues, err := setupSysctl(wgIface)
|
||||
originalValues, err := sysctl.Setup(r.wgInterface)
|
||||
if err != nil {
|
||||
log.Errorf("Error setting up sysctl: %v", err)
|
||||
sysctlFailed = true
|
||||
@ -111,7 +105,7 @@ func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.Before
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
if cleanErr := cleanupRouting(); cleanErr != nil {
|
||||
if cleanErr := r.CleanupRouting(); cleanErr != nil {
|
||||
log.Errorf("Error cleaning up routing: %v", cleanErr)
|
||||
}
|
||||
}
|
||||
@ -123,7 +117,7 @@ func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.Before
|
||||
if errors.Is(err, syscall.EOPNOTSUPP) {
|
||||
log.Warnf("Rule operations are not supported, falling back to the legacy routing setup")
|
||||
setIsLegacy(true)
|
||||
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
|
||||
return r.setupRefCounter(initAddresses)
|
||||
}
|
||||
return nil, nil, fmt.Errorf("%s: %w", rule.description, err)
|
||||
}
|
||||
@ -132,12 +126,12 @@ func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.Before
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
// cleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.
|
||||
// CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.
|
||||
// It systematically removes the three rules and any associated routing table entries to ensure a clean state.
|
||||
// The function uses error aggregation to report any errors encountered during the cleanup process.
|
||||
func cleanupRouting() error {
|
||||
func (r *SysOps) CleanupRouting() error {
|
||||
if isLegacy() {
|
||||
return cleanupRoutingWithRouteManager(routeManager)
|
||||
return r.cleanupRefCounter()
|
||||
}
|
||||
|
||||
var result *multierror.Error
|
||||
@ -156,58 +150,58 @@ func cleanupRouting() error {
|
||||
}
|
||||
}
|
||||
|
||||
if err := cleanupSysctl(originalSysctl); err != nil {
|
||||
if err := sysctl.Cleanup(originalSysctl); err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("cleanup sysctl: %w", err))
|
||||
}
|
||||
originalSysctl = nil
|
||||
sysctlFailed = false
|
||||
|
||||
return result.ErrorOrNil()
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
|
||||
return addRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN)
|
||||
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
return addRoute(prefix, nexthop, syscall.RT_TABLE_MAIN)
|
||||
}
|
||||
|
||||
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
|
||||
return removeRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN)
|
||||
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
return removeRoute(prefix, nexthop, syscall.RT_TABLE_MAIN)
|
||||
}
|
||||
|
||||
func addVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
if isLegacy() {
|
||||
return genericAddVPNRoute(prefix, intf)
|
||||
return r.genericAddVPNRoute(prefix, intf)
|
||||
}
|
||||
|
||||
if sysctlFailed && (prefix == defaultv4 || prefix == defaultv6) {
|
||||
if sysctlFailed && (prefix == vars.Defaultv4 || prefix == vars.Defaultv6) {
|
||||
log.Warnf("Default route is configured but sysctl operations failed, VPN traffic may not be routed correctly, consider using NB_USE_LEGACY_ROUTING=true or setting net.ipv4.conf.*.rp_filter to 2 (loose) or 0 (off)")
|
||||
}
|
||||
|
||||
// No need to check if routes exist as main table takes precedence over the VPN table via Rule 1
|
||||
|
||||
// TODO remove this once we have ipv6 support
|
||||
if prefix == defaultv4 {
|
||||
if err := addUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil {
|
||||
if prefix == vars.Defaultv4 {
|
||||
if err := addUnreachableRoute(vars.Defaultv6, NetbirdVPNTableID); err != nil {
|
||||
return fmt.Errorf("add blackhole: %w", err)
|
||||
}
|
||||
}
|
||||
if err := addRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil {
|
||||
if err := addRoute(prefix, Nexthop{netip.Addr{}, intf}, NetbirdVPNTableID); err != nil {
|
||||
return fmt.Errorf("add route: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
if isLegacy() {
|
||||
return genericRemoveVPNRoute(prefix, intf)
|
||||
return r.genericRemoveVPNRoute(prefix, intf)
|
||||
}
|
||||
|
||||
// TODO remove this once we have ipv6 support
|
||||
if prefix == defaultv4 {
|
||||
if err := removeUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil {
|
||||
if prefix == vars.Defaultv4 {
|
||||
if err := removeUnreachableRoute(vars.Defaultv6, NetbirdVPNTableID); err != nil {
|
||||
return fmt.Errorf("remove unreachable route: %w", err)
|
||||
}
|
||||
}
|
||||
if err := removeRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil {
|
||||
if err := removeRoute(prefix, Nexthop{netip.Addr{}, intf}, NetbirdVPNTableID); err != nil {
|
||||
return fmt.Errorf("remove route: %w", err)
|
||||
}
|
||||
return nil
|
||||
@ -255,7 +249,7 @@ func getRoutes(tableID, family int) ([]netip.Prefix, error) {
|
||||
}
|
||||
|
||||
// addRoute adds a route to a specific routing table identified by tableID.
|
||||
func addRoute(prefix netip.Prefix, addr netip.Addr, intf *net.Interface, tableID int) error {
|
||||
func addRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error {
|
||||
route := &netlink.Route{
|
||||
Scope: netlink.SCOPE_UNIVERSE,
|
||||
Table: tableID,
|
||||
@ -268,7 +262,7 @@ func addRoute(prefix netip.Prefix, addr netip.Addr, intf *net.Interface, tableID
|
||||
}
|
||||
route.Dst = ipNet
|
||||
|
||||
if err := addNextHop(addr, intf, route); err != nil {
|
||||
if err := addNextHop(nexthop, route); err != nil {
|
||||
return fmt.Errorf("add gateway and device: %w", err)
|
||||
}
|
||||
|
||||
@ -327,7 +321,7 @@ func removeUnreachableRoute(prefix netip.Prefix, tableID int) error {
|
||||
}
|
||||
|
||||
// removeRoute removes a route from a specific routing table identified by tableID.
|
||||
func removeRoute(prefix netip.Prefix, addr netip.Addr, intf *net.Interface, tableID int) error {
|
||||
func removeRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error {
|
||||
_, ipNet, err := net.ParseCIDR(prefix.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse prefix %s: %w", prefix, err)
|
||||
@ -340,7 +334,7 @@ func removeRoute(prefix netip.Prefix, addr netip.Addr, intf *net.Interface, tabl
|
||||
Dst: ipNet,
|
||||
}
|
||||
|
||||
if err := addNextHop(addr, intf, route); err != nil {
|
||||
if err := addNextHop(nexthop, route); err != nil {
|
||||
return fmt.Errorf("add gateway and device: %w", err)
|
||||
}
|
||||
|
||||
@ -373,11 +367,11 @@ func flushRoutes(tableID, family int) error {
|
||||
}
|
||||
}
|
||||
|
||||
return result.ErrorOrNil()
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
func enableIPForwarding() error {
|
||||
_, err := setSysctl(ipv4ForwardingPath, 1, false)
|
||||
func EnableIPForwarding() error {
|
||||
_, err := sysctl.Set(ipv4ForwardingPath, 1, false)
|
||||
return err
|
||||
}
|
||||
|
||||
@ -481,19 +475,19 @@ func removeRule(params ruleParams) error {
|
||||
}
|
||||
|
||||
// addNextHop adds the gateway and device to the route.
|
||||
func addNextHop(addr netip.Addr, intf *net.Interface, route *netlink.Route) error {
|
||||
if intf != nil {
|
||||
route.LinkIndex = intf.Index
|
||||
func addNextHop(nexthop Nexthop, route *netlink.Route) error {
|
||||
if nexthop.Intf != nil {
|
||||
route.LinkIndex = nexthop.Intf.Index
|
||||
}
|
||||
|
||||
if addr.IsValid() {
|
||||
route.Gw = addr.AsSlice()
|
||||
if nexthop.IP.IsValid() {
|
||||
route.Gw = nexthop.IP.AsSlice()
|
||||
|
||||
// if zone is set, it means the gateway is a link-local address, so we set the link index
|
||||
if addr.Zone() != "" && intf == nil {
|
||||
link, err := netlink.LinkByName(addr.Zone())
|
||||
if nexthop.IP.Zone() != "" && nexthop.Intf == nil {
|
||||
link, err := netlink.LinkByName(nexthop.IP.Zone())
|
||||
if err != nil {
|
||||
return fmt.Errorf("get link by name for zone %s: %w", addr.Zone(), err)
|
||||
return fmt.Errorf("get link by name for zone %s: %w", nexthop.IP.Zone(), err)
|
||||
}
|
||||
route.LinkIndex = link.Attrs().Index
|
||||
}
|
||||
@ -508,83 +502,3 @@ func getAddressFamily(prefix netip.Prefix) int {
|
||||
}
|
||||
return netlink.FAMILY_V6
|
||||
}
|
||||
|
||||
// setupSysctl configures sysctl settings for RP filtering and source validation.
|
||||
func setupSysctl(wgIface *iface.WGIface) (map[string]int, error) {
|
||||
keys := map[string]int{}
|
||||
var result *multierror.Error
|
||||
|
||||
oldVal, err := setSysctl(srcValidMarkPath, 1, false)
|
||||
if err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
} else {
|
||||
keys[srcValidMarkPath] = oldVal
|
||||
}
|
||||
|
||||
oldVal, err = setSysctl(rpFilterPath, 2, true)
|
||||
if err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
} else {
|
||||
keys[rpFilterPath] = oldVal
|
||||
}
|
||||
|
||||
interfaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("list interfaces: %w", err))
|
||||
}
|
||||
|
||||
for _, intf := range interfaces {
|
||||
if intf.Name == "lo" || wgIface != nil && intf.Name == wgIface.Name() {
|
||||
continue
|
||||
}
|
||||
|
||||
i := fmt.Sprintf(rpFilterInterfacePath, intf.Name)
|
||||
oldVal, err := setSysctl(i, 2, true)
|
||||
if err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
} else {
|
||||
keys[i] = oldVal
|
||||
}
|
||||
}
|
||||
|
||||
return keys, result.ErrorOrNil()
|
||||
}
|
||||
|
||||
// setSysctl sets a sysctl configuration, if onlyIfOne is true it will only set the new value if it's set to 1
|
||||
func setSysctl(key string, desiredValue int, onlyIfOne bool) (int, error) {
|
||||
path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/"))
|
||||
currentValue, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return -1, fmt.Errorf("read sysctl %s: %w", key, err)
|
||||
}
|
||||
|
||||
currentV, err := strconv.Atoi(strings.TrimSpace(string(currentValue)))
|
||||
if err != nil && len(currentValue) > 0 {
|
||||
return -1, fmt.Errorf("convert current desiredValue to int: %w", err)
|
||||
}
|
||||
|
||||
if currentV == desiredValue || onlyIfOne && currentV != 1 {
|
||||
return currentV, nil
|
||||
}
|
||||
|
||||
//nolint:gosec
|
||||
if err := os.WriteFile(path, []byte(strconv.Itoa(desiredValue)), 0644); err != nil {
|
||||
return currentV, fmt.Errorf("write sysctl %s: %w", key, err)
|
||||
}
|
||||
log.Debugf("Set sysctl %s from %d to %d", key, currentV, desiredValue)
|
||||
|
||||
return currentV, nil
|
||||
}
|
||||
|
||||
func cleanupSysctl(originalSettings map[string]int) error {
|
||||
var result *multierror.Error
|
||||
|
||||
for key, value := range originalSettings {
|
||||
_, err := setSysctl(key, value, false)
|
||||
if err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
|
||||
return result.ErrorOrNil()
|
||||
}
|
@ -1,6 +1,6 @@
|
||||
//go:build !android
|
||||
|
||||
package routemanager
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"errors"
|
||||
@ -14,6 +14,8 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/vishvananda/netlink"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
)
|
||||
|
||||
var expectedVPNint = "wgtest0"
|
||||
@ -138,7 +140,7 @@ func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, intf string) {
|
||||
if dstIPNet.String() == "0.0.0.0/0" {
|
||||
var err error
|
||||
originalNexthop, originalLinkIndex, err = fetchOriginalGateway(netlink.FAMILY_V4)
|
||||
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
||||
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
|
||||
t.Logf("Failed to fetch original gateway: %v", err)
|
||||
}
|
||||
|
||||
@ -193,7 +195,7 @@ func fetchOriginalGateway(family int) (net.IP, int, error) {
|
||||
}
|
||||
}
|
||||
|
||||
return nil, 0, ErrRouteNotFound
|
||||
return nil, 0, vars.ErrRouteNotFound
|
||||
}
|
||||
|
||||
func setupDummyInterfacesAndRoutes(t *testing.T) {
|
34
client/internal/routemanager/systemops/systemops_mobile.go
Normal file
34
client/internal/routemanager/systemops/systemops_mobile.go
Normal file
@ -0,0 +1,34 @@
|
||||
//go:build ios || android
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
)
|
||||
|
||||
func (r *SysOps) SetupRouting([]net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
func (r *SysOps) CleanupRouting() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SysOps) AddVPNRoute(netip.Prefix, *net.Interface) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SysOps) RemoveVPNRoute(netip.Prefix, *net.Interface) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func EnableIPForwarding() error {
|
||||
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
|
||||
return nil
|
||||
}
|
24
client/internal/routemanager/systemops/systemops_nonlinux.go
Normal file
24
client/internal/routemanager/systemops/systemops_nonlinux.go
Normal file
@ -0,0 +1,24 @@
|
||||
//go:build !linux && !ios
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
return r.genericAddVPNRoute(prefix, intf)
|
||||
}
|
||||
|
||||
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
return r.genericRemoveVPNRoute(prefix, intf)
|
||||
}
|
||||
|
||||
func EnableIPForwarding() error {
|
||||
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
|
||||
return nil
|
||||
}
|
@ -1,6 +1,6 @@
|
||||
//go:build darwin && !ios
|
||||
//go:build (darwin && !ios) || dragonfly || freebsd || netbsd || openbsd
|
||||
|
||||
package routemanager
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@ -14,42 +14,40 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
var routeManager *RouteManager
|
||||
|
||||
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||
return r.setupRefCounter(initAddresses)
|
||||
}
|
||||
|
||||
func cleanupRouting() error {
|
||||
return cleanupRoutingWithRouteManager(routeManager)
|
||||
func (r *SysOps) CleanupRouting() error {
|
||||
return r.cleanupRefCounter()
|
||||
}
|
||||
|
||||
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
|
||||
return routeCmd("add", prefix, nexthop, intf)
|
||||
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
return r.routeCmd("add", prefix, nexthop)
|
||||
}
|
||||
|
||||
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
|
||||
return routeCmd("delete", prefix, nexthop, intf)
|
||||
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
return r.routeCmd("delete", prefix, nexthop)
|
||||
}
|
||||
|
||||
func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
|
||||
func (r *SysOps) routeCmd(action string, prefix netip.Prefix, nexthop Nexthop) error {
|
||||
inet := "-inet"
|
||||
network := prefix.String()
|
||||
if prefix.IsSingleIP() {
|
||||
network = prefix.Addr().String()
|
||||
}
|
||||
if prefix.Addr().Is6() {
|
||||
inet = "-inet6"
|
||||
}
|
||||
|
||||
network := prefix.String()
|
||||
if prefix.IsSingleIP() {
|
||||
network = prefix.Addr().String()
|
||||
}
|
||||
|
||||
args := []string{"-n", action, inet, network}
|
||||
if nexthop.IsValid() {
|
||||
args = append(args, nexthop.Unmap().String())
|
||||
} else if intf != nil {
|
||||
args = append(args, "-interface", intf.Name)
|
||||
if nexthop.IP.IsValid() {
|
||||
args = append(args, nexthop.IP.Unmap().String())
|
||||
} else if nexthop.Intf != nil {
|
||||
args = append(args, "-interface", nexthop.Intf.Name)
|
||||
}
|
||||
|
||||
if err := retryRouteCmd(args); err != nil {
|
@ -1,6 +1,6 @@
|
||||
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly
|
||||
|
||||
package routemanager
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"fmt"
|
@ -1,6 +1,6 @@
|
||||
//go:build windows
|
||||
|
||||
package routemanager
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@ -18,7 +18,6 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
type MSFT_NetRoute struct {
|
||||
@ -57,14 +56,43 @@ var prefixList []netip.Prefix
|
||||
var lastUpdate time.Time
|
||||
var mux = sync.Mutex{}
|
||||
|
||||
var routeManager *RouteManager
|
||||
|
||||
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||
return r.setupRefCounter(initAddresses)
|
||||
}
|
||||
|
||||
func cleanupRouting() error {
|
||||
return cleanupRoutingWithRouteManager(routeManager)
|
||||
func (r *SysOps) CleanupRouting() error {
|
||||
return r.cleanupRefCounter()
|
||||
}
|
||||
|
||||
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
if nexthop.IP.Zone() != "" && nexthop.Intf == nil {
|
||||
zone, err := strconv.Atoi(nexthop.IP.Zone())
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid zone: %w", err)
|
||||
}
|
||||
nexthop.Intf = &net.Interface{Index: zone}
|
||||
nexthop.IP.WithZone("")
|
||||
}
|
||||
|
||||
return addRouteCmd(prefix, nexthop)
|
||||
}
|
||||
|
||||
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
args := []string{"delete", prefix.String()}
|
||||
if nexthop.IP.IsValid() {
|
||||
nexthop.IP.WithZone("")
|
||||
args = append(args, nexthop.IP.Unmap().String())
|
||||
}
|
||||
|
||||
routeCmd := uspfilter.GetSystem32Command("route")
|
||||
|
||||
out, err := exec.Command(routeCmd, args...).CombinedOutput()
|
||||
log.Tracef("route %s: %s", strings.Join(args, " "), out)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("remove route: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getRoutesFromTable() ([]netip.Prefix, error) {
|
||||
@ -93,7 +121,7 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
|
||||
func GetRoutes() ([]Route, error) {
|
||||
var entries []MSFT_NetRoute
|
||||
|
||||
query := `SELECT DestinationPrefix, NextHop, InterfaceIndex, InterfaceAlias, AddressFamily FROM MSFT_NetRoute`
|
||||
query := `SELECT DestinationPrefix, Nexthop, InterfaceIndex, InterfaceAlias, AddressFamily FROM MSFT_NetRoute`
|
||||
if err := wmi.QueryNamespace(query, &entries, `ROOT\StandardCimv2`); err != nil {
|
||||
return nil, fmt.Errorf("get routes: %w", err)
|
||||
}
|
||||
@ -157,11 +185,11 @@ func GetNeighbors() ([]Neighbor, error) {
|
||||
return neighbors, nil
|
||||
}
|
||||
|
||||
func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
|
||||
func addRouteCmd(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
args := []string{"add", prefix.String()}
|
||||
|
||||
if nexthop.IsValid() {
|
||||
args = append(args, nexthop.Unmap().String())
|
||||
if nexthop.IP.IsValid() {
|
||||
args = append(args, nexthop.IP.Unmap().String())
|
||||
} else {
|
||||
addr := "0.0.0.0"
|
||||
if prefix.Addr().Is6() {
|
||||
@ -170,8 +198,8 @@ func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) e
|
||||
args = append(args, addr)
|
||||
}
|
||||
|
||||
if intf != nil {
|
||||
args = append(args, "if", strconv.Itoa(intf.Index))
|
||||
if nexthop.Intf != nil {
|
||||
args = append(args, "if", strconv.Itoa(nexthop.Intf.Index))
|
||||
}
|
||||
|
||||
routeCmd := uspfilter.GetSystem32Command("route")
|
||||
@ -185,37 +213,6 @@ func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) e
|
||||
return nil
|
||||
}
|
||||
|
||||
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf *net.Interface) error {
|
||||
if nexthop.Zone() != "" && intf == nil {
|
||||
zone, err := strconv.Atoi(nexthop.Zone())
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid zone: %w", err)
|
||||
}
|
||||
intf = &net.Interface{Index: zone}
|
||||
nexthop.WithZone("")
|
||||
}
|
||||
|
||||
return addRouteCmd(prefix, nexthop, intf)
|
||||
}
|
||||
|
||||
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ *net.Interface) error {
|
||||
args := []string{"delete", prefix.String()}
|
||||
if nexthop.IsValid() {
|
||||
nexthop.WithZone("")
|
||||
args = append(args, nexthop.Unmap().String())
|
||||
}
|
||||
|
||||
routeCmd := uspfilter.GetSystem32Command("route")
|
||||
|
||||
out, err := exec.Command(routeCmd, args...).CombinedOutput()
|
||||
log.Tracef("route %s: %s", strings.Join(args, " "), out)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("remove route: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isCacheDisabled() bool {
|
||||
return os.Getenv("NB_DISABLE_ROUTE_CACHE") == "true"
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package routemanager
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"context"
|
||||
@ -29,7 +29,7 @@ type FindNetRouteOutput struct {
|
||||
InterfaceIndex int `json:"InterfaceIndex"`
|
||||
InterfaceAlias string `json:"InterfaceAlias"`
|
||||
AddressFamily int `json:"AddressFamily"`
|
||||
NextHop string `json:"NextHop"`
|
||||
NextHop string `json:"Nexthop"`
|
||||
DestinationPrefix string `json:"DestinationPrefix"`
|
||||
}
|
||||
|
||||
@ -166,7 +166,7 @@ func testRoute(t *testing.T, destination string, dialer dialer) *FindNetRouteOut
|
||||
host, _, err := net.SplitHostPort(destination)
|
||||
require.NoError(t, err)
|
||||
|
||||
script := fmt.Sprintf(`Find-NetRoute -RemoteIPAddress "%s" | Select-Object -Property IPAddress, InterfaceIndex, InterfaceAlias, AddressFamily, NextHop, DestinationPrefix | ConvertTo-Json`, host)
|
||||
script := fmt.Sprintf(`Find-NetRoute -RemoteIPAddress "%s" | Select-Object -Property IPAddress, InterfaceIndex, InterfaceAlias, AddressFamily, Nexthop, DestinationPrefix | ConvertTo-Json`, host)
|
||||
|
||||
out, err := exec.Command("powershell", "-Command", script).Output()
|
||||
require.NoError(t, err, "Failed to execute Find-NetRoute")
|
||||
@ -207,7 +207,7 @@ func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR str
|
||||
}
|
||||
|
||||
func fetchOriginalGateway() (*RouteInfo, error) {
|
||||
cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object NextHop, RouteMetric, InterfaceAlias | ConvertTo-Json")
|
||||
cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object Nexthop, RouteMetric, InterfaceAlias | ConvertTo-Json")
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to execute Get-NetRoute: %w", err)
|
@ -1,33 +0,0 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
func cleanupRouting() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func enableIPForwarding() error {
|
||||
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
|
||||
return nil
|
||||
}
|
||||
|
||||
func addVPNRoute(netip.Prefix, *net.Interface) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeVPNRoute(netip.Prefix, *net.Interface) error {
|
||||
return nil
|
||||
}
|
@ -1,33 +0,0 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
func cleanupRouting() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func enableIPForwarding() error {
|
||||
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
|
||||
return nil
|
||||
}
|
||||
|
||||
func addVPNRoute(netip.Prefix, *net.Interface) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeVPNRoute(netip.Prefix, *net.Interface) error {
|
||||
return nil
|
||||
}
|
@ -1,24 +0,0 @@
|
||||
//go:build !linux && !ios
|
||||
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func enableIPForwarding() error {
|
||||
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
|
||||
return nil
|
||||
}
|
||||
|
||||
func addVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
return genericAddVPNRoute(prefix, intf)
|
||||
}
|
||||
|
||||
func removeVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
return genericRemoveVPNRoute(prefix, intf)
|
||||
}
|
29
client/internal/routemanager/util/ip.go
Normal file
29
client/internal/routemanager/util/ip.go
Normal file
@ -0,0 +1,29 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
// GetPrefixFromIP returns a netip.Prefix from a net.IP address.
|
||||
func GetPrefixFromIP(ip net.IP) (netip.Prefix, error) {
|
||||
addr, ok := netip.AddrFromSlice(ip)
|
||||
if !ok {
|
||||
return netip.Prefix{}, fmt.Errorf("parse IP address: %s", ip)
|
||||
}
|
||||
addr = addr.Unmap()
|
||||
|
||||
var prefixLength int
|
||||
switch {
|
||||
case addr.Is4():
|
||||
prefixLength = 32
|
||||
case addr.Is6():
|
||||
prefixLength = 128
|
||||
default:
|
||||
return netip.Prefix{}, fmt.Errorf("invalid IP address: %s", addr)
|
||||
}
|
||||
|
||||
prefix := netip.PrefixFrom(addr, prefixLength)
|
||||
return prefix, nil
|
||||
}
|
16
client/internal/routemanager/vars/vars.go
Normal file
16
client/internal/routemanager/vars/vars.go
Normal file
@ -0,0 +1,16 @@
|
||||
package vars
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
const MinRangeBits = 7
|
||||
|
||||
var (
|
||||
ErrRouteNotFound = errors.New("route not found")
|
||||
ErrRouteNotAllowed = errors.New("route not allowed")
|
||||
|
||||
Defaultv4 = netip.PrefixFrom(netip.IPv4Unspecified(), 0)
|
||||
Defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
|
||||
)
|
@ -3,11 +3,11 @@ package routeselector
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/client/errors"
|
||||
route "github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
@ -30,10 +30,10 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al
|
||||
rs.selectedRoutes = map[route.NetID]struct{}{}
|
||||
}
|
||||
|
||||
var multiErr *multierror.Error
|
||||
var err *multierror.Error
|
||||
for _, route := range routes {
|
||||
if !slices.Contains(allRoutes, route) {
|
||||
multiErr = multierror.Append(multiErr, fmt.Errorf("route '%s' is not available", route))
|
||||
err = multierror.Append(err, fmt.Errorf("route '%s' is not available", route))
|
||||
continue
|
||||
}
|
||||
|
||||
@ -41,11 +41,7 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al
|
||||
}
|
||||
rs.selectAll = false
|
||||
|
||||
if multiErr != nil {
|
||||
multiErr.ErrorFormat = formatError
|
||||
}
|
||||
|
||||
return multiErr.ErrorOrNil()
|
||||
return errors.FormatErrorOrNil(err)
|
||||
}
|
||||
|
||||
// SelectAllRoutes sets the selector to select all routes.
|
||||
@ -65,21 +61,17 @@ func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.
|
||||
}
|
||||
}
|
||||
|
||||
var multiErr *multierror.Error
|
||||
var err *multierror.Error
|
||||
|
||||
for _, route := range routes {
|
||||
if !slices.Contains(allRoutes, route) {
|
||||
multiErr = multierror.Append(multiErr, fmt.Errorf("route '%s' is not available", route))
|
||||
err = multierror.Append(err, fmt.Errorf("route '%s' is not available", route))
|
||||
continue
|
||||
}
|
||||
delete(rs.selectedRoutes, route)
|
||||
}
|
||||
|
||||
if multiErr != nil {
|
||||
multiErr.ErrorFormat = formatError
|
||||
}
|
||||
|
||||
return multiErr.ErrorOrNil()
|
||||
return errors.FormatErrorOrNil(err)
|
||||
}
|
||||
|
||||
// DeselectAllRoutes deselects all routes, effectively disabling route selection.
|
||||
@ -111,18 +103,3 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func formatError(es []error) string {
|
||||
if len(es) == 1 {
|
||||
return fmt.Sprintf("1 error occurred:\n\t* %s", es[0])
|
||||
}
|
||||
|
||||
points := make([]string, len(es))
|
||||
for i, err := range es {
|
||||
points[i] = fmt.Sprintf("* %s", err)
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"%d errors occurred:\n\t%s",
|
||||
len(es), strings.Join(points, "\n\t"))
|
||||
}
|
||||
|
@ -261,15 +261,15 @@ func TestRouteSelector_FilterSelected(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
routes := route.HAMap{
|
||||
"route1-10.0.0.0/8": {},
|
||||
"route2-192.168.0.0/16": {},
|
||||
"route3-172.16.0.0/12": {},
|
||||
"route1|10.0.0.0/8": {},
|
||||
"route2|192.168.0.0/16": {},
|
||||
"route3|172.16.0.0/12": {},
|
||||
}
|
||||
|
||||
filtered := rs.FilterSelected(routes)
|
||||
|
||||
assert.Equal(t, route.HAMap{
|
||||
"route1-10.0.0.0/8": {},
|
||||
"route2-192.168.0.0/16": {},
|
||||
"route1|10.0.0.0/8": {},
|
||||
"route2|192.168.0.0/16": {},
|
||||
}, filtered)
|
||||
}
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -92,6 +92,8 @@ message LoginRequest {
|
||||
repeated string extraIFaceBlacklist = 17;
|
||||
|
||||
optional bool networkMonitor = 18;
|
||||
|
||||
optional google.protobuf.Duration dnsRouteInterval = 19;
|
||||
}
|
||||
|
||||
message LoginResponse {
|
||||
@ -233,10 +235,17 @@ message SelectRoutesRequest {
|
||||
message SelectRoutesResponse {
|
||||
}
|
||||
|
||||
message IPList {
|
||||
repeated string ips = 1;
|
||||
}
|
||||
|
||||
|
||||
message Route {
|
||||
string ID = 1;
|
||||
string network = 2;
|
||||
bool selected = 3;
|
||||
repeated string domains = 4;
|
||||
map<string, IPList> resolvedIPs = 5;
|
||||
}
|
||||
|
||||
message DebugBundleRequest {
|
||||
|
@ -9,17 +9,19 @@ import (
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type selectRoute struct {
|
||||
NetID route.NetID
|
||||
Network netip.Prefix
|
||||
Domains domain.List
|
||||
Selected bool
|
||||
}
|
||||
|
||||
// ListRoutes returns a list of all available routes.
|
||||
func (s *Server) ListRoutes(ctx context.Context, req *proto.ListRoutesRequest) (*proto.ListRoutesResponse, error) {
|
||||
func (s *Server) ListRoutes(context.Context, *proto.ListRoutesRequest) (*proto.ListRoutesResponse, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
@ -43,6 +45,7 @@ func (s *Server) ListRoutes(ctx context.Context, req *proto.ListRoutesRequest) (
|
||||
route := &selectRoute{
|
||||
NetID: id,
|
||||
Network: rt[0].Network,
|
||||
Domains: rt[0].Domains,
|
||||
Selected: routeSelector.IsSelected(id),
|
||||
}
|
||||
routes = append(routes, route)
|
||||
@ -63,13 +66,29 @@ func (s *Server) ListRoutes(ctx context.Context, req *proto.ListRoutesRequest) (
|
||||
return iPrefix < jPrefix
|
||||
})
|
||||
|
||||
resolvedDomains := s.statusRecorder.GetResolvedDomainsStates()
|
||||
var pbRoutes []*proto.Route
|
||||
for _, route := range routes {
|
||||
pbRoutes = append(pbRoutes, &proto.Route{
|
||||
ID: string(route.NetID),
|
||||
Network: route.Network.String(),
|
||||
Selected: route.Selected,
|
||||
})
|
||||
pbRoute := &proto.Route{
|
||||
ID: string(route.NetID),
|
||||
Network: route.Network.String(),
|
||||
Domains: route.Domains.ToSafeStringList(),
|
||||
ResolvedIPs: map[string]*proto.IPList{},
|
||||
Selected: route.Selected,
|
||||
}
|
||||
|
||||
for _, domain := range route.Domains {
|
||||
if prefixes, exists := resolvedDomains[domain]; exists {
|
||||
var ipStrings []string
|
||||
for _, prefix := range prefixes {
|
||||
ipStrings = append(ipStrings, prefix.Addr().String())
|
||||
}
|
||||
pbRoute.ResolvedIPs[string(domain)] = &proto.IPList{
|
||||
Ips: ipStrings,
|
||||
}
|
||||
}
|
||||
}
|
||||
pbRoutes = append(pbRoutes, pbRoute)
|
||||
}
|
||||
|
||||
return &proto.ListRoutesResponse{
|
||||
|
@ -365,6 +365,12 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
||||
s.latestConfigInput.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist
|
||||
}
|
||||
|
||||
if msg.DnsRouteInterval != nil {
|
||||
duration := msg.DnsRouteInterval.AsDuration()
|
||||
inputConfig.DNSRouteInterval = &duration
|
||||
s.latestConfigInput.DNSRouteInterval = &duration
|
||||
}
|
||||
|
||||
s.mutex.Unlock()
|
||||
|
||||
if msg.OptionalPreSharedKey != nil {
|
||||
|
10
client/ssh/window_freebsd.go
Normal file
10
client/ssh/window_freebsd.go
Normal file
@ -0,0 +1,10 @@
|
||||
//go:build freebsd
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"os"
|
||||
)
|
||||
|
||||
func setWinSize(file *os.File, width, height int) {
|
||||
}
|
@ -8,6 +8,7 @@ import (
|
||||
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
@ -33,6 +34,12 @@ type Environment struct {
|
||||
Platform string
|
||||
}
|
||||
|
||||
type File struct {
|
||||
Path string
|
||||
Exist bool
|
||||
ProcessIsRunning bool
|
||||
}
|
||||
|
||||
// Info is an object that contains machine information
|
||||
// Most of the code is taken from https://github.com/matishsiao/goInfo
|
||||
type Info struct {
|
||||
@ -51,6 +58,7 @@ type Info struct {
|
||||
SystemProductName string
|
||||
SystemManufacturer string
|
||||
Environment Environment
|
||||
Files []File
|
||||
}
|
||||
|
||||
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
|
||||
@ -132,3 +140,21 @@ func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetInfoWithChecks retrieves and parses the system information with applied checks.
|
||||
func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) {
|
||||
processCheckPaths := make([]string, 0)
|
||||
for _, check := range checks {
|
||||
processCheckPaths = append(processCheckPaths, check.GetFiles()...)
|
||||
}
|
||||
|
||||
files, err := checkFileAndProcess(processCheckPaths)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
info := GetInfo(ctx)
|
||||
info.Files = files
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
@ -44,6 +44,11 @@ func GetInfo(ctx context.Context) *Info {
|
||||
return gio
|
||||
}
|
||||
|
||||
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
|
||||
func checkFileAndProcess(paths []string) ([]File, error) {
|
||||
return []File{}, nil
|
||||
}
|
||||
|
||||
func uname() []string {
|
||||
res := run("/system/bin/uname", "-a")
|
||||
return strings.Split(res, " ")
|
||||
|
@ -1,15 +1,18 @@
|
||||
//go:build freebsd
|
||||
|
||||
package system
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/system/detect_cloud"
|
||||
"github.com/netbirdio/netbird/client/system/detect_platform"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
@ -22,8 +25,8 @@ func GetInfo(ctx context.Context) *Info {
|
||||
out = _getInfo()
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
osStr := strings.Replace(out, "\n", "", -1)
|
||||
osStr = strings.Replace(osStr, "\r\n", "", -1)
|
||||
osStr := strings.ReplaceAll(out, "\n", "")
|
||||
osStr = strings.ReplaceAll(osStr, "\r\n", "")
|
||||
osInfo := strings.Split(osStr, " ")
|
||||
|
||||
env := Environment{
|
||||
@ -31,14 +34,23 @@ func GetInfo(ctx context.Context) *Info {
|
||||
Platform: detect_platform.Detect(ctx),
|
||||
}
|
||||
|
||||
gio := &Info{Kernel: osInfo[0], Platform: runtime.GOARCH, OS: osInfo[2], GoOS: runtime.GOOS, CPUs: runtime.NumCPU(), KernelVersion: osInfo[1], Environment: env}
|
||||
osName, osVersion := readOsReleaseFile()
|
||||
|
||||
systemHostname, _ := os.Hostname()
|
||||
gio.Hostname = extractDeviceName(ctx, systemHostname)
|
||||
gio.WiretrusteeVersion = version.NetbirdVersion()
|
||||
gio.UIVersion = extractUserAgent(ctx)
|
||||
|
||||
return gio
|
||||
return &Info{
|
||||
GoOS: runtime.GOOS,
|
||||
Kernel: osInfo[0],
|
||||
Platform: runtime.GOARCH,
|
||||
OS: osName,
|
||||
OSVersion: osVersion,
|
||||
Hostname: extractDeviceName(ctx, systemHostname),
|
||||
CPUs: runtime.NumCPU(),
|
||||
WiretrusteeVersion: version.NetbirdVersion(),
|
||||
UIVersion: extractUserAgent(ctx),
|
||||
KernelVersion: osInfo[1],
|
||||
Environment: env,
|
||||
}
|
||||
}
|
||||
|
||||
func _getInfo() string {
|
||||
@ -50,7 +62,8 @@ func _getInfo() string {
|
||||
cmd.Stderr = &stderr
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
fmt.Println("getInfo:", err)
|
||||
log.Warnf("getInfo: %s", err)
|
||||
}
|
||||
|
||||
return out.String()
|
||||
}
|
||||
|
@ -25,6 +25,11 @@ func GetInfo(ctx context.Context) *Info {
|
||||
return gio
|
||||
}
|
||||
|
||||
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
|
||||
func checkFileAndProcess(paths []string) ([]File, error) {
|
||||
return []File{}, nil
|
||||
}
|
||||
|
||||
// extractOsVersion extracts operating system version from context or returns the default
|
||||
func extractOsVersion(ctx context.Context, defaultName string) string {
|
||||
v, ok := ctx.Value(OsVersionCtxKey).(string)
|
||||
|
@ -28,28 +28,11 @@ func GetInfo(ctx context.Context) *Info {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
|
||||
releaseInfo := _getReleaseInfo()
|
||||
for strings.Contains(info, "broken pipe") {
|
||||
releaseInfo = _getReleaseInfo()
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
|
||||
osRelease := strings.Split(releaseInfo, "\n")
|
||||
var osName string
|
||||
var osVer string
|
||||
for _, s := range osRelease {
|
||||
if strings.HasPrefix(s, "NAME=") {
|
||||
osName = strings.Split(s, "=")[1]
|
||||
osName = strings.ReplaceAll(osName, "\"", "")
|
||||
} else if strings.HasPrefix(s, "VERSION_ID=") {
|
||||
osVer = strings.Split(s, "=")[1]
|
||||
osVer = strings.ReplaceAll(osVer, "\"", "")
|
||||
}
|
||||
}
|
||||
|
||||
osStr := strings.ReplaceAll(info, "\n", "")
|
||||
osStr = strings.ReplaceAll(osStr, "\r\n", "")
|
||||
osInfo := strings.Split(osStr, " ")
|
||||
|
||||
osName, osVersion := readOsReleaseFile()
|
||||
if osName == "" {
|
||||
osName = osInfo[3]
|
||||
}
|
||||
@ -72,7 +55,7 @@ func GetInfo(ctx context.Context) *Info {
|
||||
Kernel: osInfo[0],
|
||||
Platform: osInfo[2],
|
||||
OS: osName,
|
||||
OSVersion: osVer,
|
||||
OSVersion: osVersion,
|
||||
Hostname: extractDeviceName(ctx, systemHostname),
|
||||
GoOS: runtime.GOOS,
|
||||
CPUs: runtime.NumCPU(),
|
||||
@ -103,20 +86,6 @@ func _getInfo() string {
|
||||
return out.String()
|
||||
}
|
||||
|
||||
func _getReleaseInfo() string {
|
||||
cmd := exec.Command("cat", "/etc/os-release")
|
||||
cmd.Stdin = strings.NewReader("some")
|
||||
var out bytes.Buffer
|
||||
var stderr bytes.Buffer
|
||||
cmd.Stdout = &out
|
||||
cmd.Stderr = &stderr
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
log.Warnf("geucwReleaseInfo: %s", err)
|
||||
}
|
||||
return out.String()
|
||||
}
|
||||
|
||||
func sysInfo() (serialNumber string, productName string, manufacturer string) {
|
||||
var si sysinfo.SysInfo
|
||||
si.GetSysInfo()
|
||||
|
38
client/system/osrelease_unix.go
Normal file
38
client/system/osrelease_unix.go
Normal file
@ -0,0 +1,38 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package system
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func readOsReleaseFile() (osName string, osVer string) {
|
||||
file, err := os.Open("/etc/os-release")
|
||||
if err != nil {
|
||||
log.Warnf("failed to open file /etc/os-release: %s", err)
|
||||
return "", ""
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.HasPrefix(line, "NAME=") {
|
||||
osName = strings.ReplaceAll(strings.Split(line, "=")[1], "\"", "")
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(line, "VERSION_ID=") {
|
||||
osVer = strings.ReplaceAll(strings.Split(line, "=")[1], "\"", "")
|
||||
continue
|
||||
}
|
||||
|
||||
if osName != "" && osVer != "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
58
client/system/process.go
Normal file
58
client/system/process.go
Normal file
@ -0,0 +1,58 @@
|
||||
//go:build windows || (linux && !android) || (darwin && !ios)
|
||||
|
||||
package system
|
||||
|
||||
import (
|
||||
"os"
|
||||
"slices"
|
||||
|
||||
"github.com/shirou/gopsutil/v3/process"
|
||||
)
|
||||
|
||||
// getRunningProcesses returns a list of running process paths.
|
||||
func getRunningProcesses() ([]string, error) {
|
||||
processes, err := process.Processes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
processMap := make(map[string]bool)
|
||||
for _, p := range processes {
|
||||
path, _ := p.Exe()
|
||||
if path != "" {
|
||||
processMap[path] = true
|
||||
}
|
||||
}
|
||||
|
||||
uniqueProcesses := make([]string, 0, len(processMap))
|
||||
for p := range processMap {
|
||||
uniqueProcesses = append(uniqueProcesses, p)
|
||||
}
|
||||
|
||||
return uniqueProcesses, nil
|
||||
}
|
||||
|
||||
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
|
||||
func checkFileAndProcess(paths []string) ([]File, error) {
|
||||
files := make([]File, len(paths))
|
||||
if len(paths) == 0 {
|
||||
return files, nil
|
||||
}
|
||||
|
||||
runningProcesses, err := getRunningProcesses()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i, path := range paths {
|
||||
file := File{Path: path}
|
||||
|
||||
_, err := os.Stat(path)
|
||||
file.Exist = !os.IsNotExist(err)
|
||||
|
||||
file.ProcessIsRunning = slices.Contains(runningProcesses, path)
|
||||
files[i] = file
|
||||
}
|
||||
|
||||
return files, nil
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
//go:build !(linux && 386)
|
||||
//go:build !(linux && 386) && !freebsd
|
||||
|
||||
package main
|
||||
|
||||
|
@ -20,7 +20,7 @@ import (
|
||||
func (s *serviceClient) showRoutesUI() {
|
||||
s.wRoutes = s.app.NewWindow("NetBird Routes")
|
||||
|
||||
grid := container.New(layout.NewGridLayout(2))
|
||||
grid := container.New(layout.NewGridLayout(3))
|
||||
go s.updateRoutes(grid)
|
||||
routeCheckContainer := container.NewVBox()
|
||||
routeCheckContainer.Add(grid)
|
||||
@ -61,14 +61,16 @@ func (s *serviceClient) updateRoutes(grid *fyne.Container) {
|
||||
|
||||
grid.Objects = nil
|
||||
idHeader := widget.NewLabelWithStyle(" ID", fyne.TextAlignLeading, fyne.TextStyle{Bold: true})
|
||||
networkHeader := widget.NewLabelWithStyle("Network", fyne.TextAlignLeading, fyne.TextStyle{Bold: true})
|
||||
networkHeader := widget.NewLabelWithStyle("Network/Domains", fyne.TextAlignLeading, fyne.TextStyle{Bold: true})
|
||||
resolvedIPsHeader := widget.NewLabelWithStyle("Resolved IPs", fyne.TextAlignLeading, fyne.TextStyle{Bold: true})
|
||||
|
||||
grid.Add(idHeader)
|
||||
grid.Add(networkHeader)
|
||||
grid.Add(resolvedIPsHeader)
|
||||
for _, route := range routes {
|
||||
r := route
|
||||
|
||||
checkBox := widget.NewCheck(r.ID, func(checked bool) {
|
||||
checkBox := widget.NewCheck(r.GetID(), func(checked bool) {
|
||||
s.selectRoute(r.ID, checked)
|
||||
})
|
||||
checkBox.Checked = route.Selected
|
||||
@ -76,10 +78,31 @@ func (s *serviceClient) updateRoutes(grid *fyne.Container) {
|
||||
checkBox.Refresh()
|
||||
|
||||
grid.Add(checkBox)
|
||||
grid.Add(widget.NewLabel(r.Network))
|
||||
network := r.GetNetwork()
|
||||
domains := r.GetDomains()
|
||||
if len(domains) > 0 {
|
||||
network = strings.Join(domains, ", ")
|
||||
}
|
||||
grid.Add(widget.NewLabel(network))
|
||||
|
||||
if len(domains) > 0 {
|
||||
var resolvedIPsList []string
|
||||
for _, domain := range r.GetDomains() {
|
||||
if ipList, exists := r.GetResolvedIPs()[domain]; exists {
|
||||
resolvedIPsList = append(resolvedIPsList, fmt.Sprintf("%s: %s", domain, strings.Join(ipList.GetIps(), ", ")))
|
||||
}
|
||||
}
|
||||
// TODO: limit width
|
||||
resolvedIPsLabel := widget.NewLabel(strings.Join(resolvedIPsList, ", "))
|
||||
grid.Add(resolvedIPsLabel)
|
||||
} else {
|
||||
grid.Add(widget.NewLabel(""))
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
s.wRoutes.Content().Refresh()
|
||||
grid.Refresh()
|
||||
}
|
||||
|
||||
func (s *serviceClient) fetchRoutes() ([]*proto.Route, error) {
|
||||
|
2
go.mod
2
go.mod
@ -68,6 +68,7 @@ require (
|
||||
github.com/pion/turn/v3 v3.0.1
|
||||
github.com/prometheus/client_golang v1.19.1
|
||||
github.com/rs/xid v1.3.0
|
||||
github.com/shirou/gopsutil/v3 v3.24.4
|
||||
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/testcontainers/testcontainers-go v0.31.0
|
||||
@ -176,7 +177,6 @@ require (
|
||||
github.com/prometheus/client_model v0.6.1 // indirect
|
||||
github.com/prometheus/common v0.53.0 // indirect
|
||||
github.com/prometheus/procfs v0.15.0 // indirect
|
||||
github.com/shirou/gopsutil/v3 v3.24.4 // indirect
|
||||
github.com/shoenig/go-m1cpu v0.1.6 // indirect
|
||||
github.com/spf13/cast v1.5.0 // indirect
|
||||
github.com/srwiley/oksvg v0.0.0-20200311192757-870daf9aa564 // indirect
|
||||
|
@ -23,24 +23,6 @@ func parseWGAddress(address string) (WGAddress, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Masked returns the WGAddress with the IP address part masked according to its network mask.
|
||||
func (addr WGAddress) Masked() WGAddress {
|
||||
ip := addr.IP.To4()
|
||||
if ip == nil {
|
||||
ip = addr.IP.To16()
|
||||
}
|
||||
|
||||
maskedIP := make(net.IP, len(ip))
|
||||
for i := range ip {
|
||||
maskedIP[i] = ip[i] & addr.Network.Mask[i]
|
||||
}
|
||||
|
||||
return WGAddress{
|
||||
IP: maskedIP,
|
||||
Network: addr.Network,
|
||||
}
|
||||
}
|
||||
|
||||
func (addr WGAddress) String() string {
|
||||
maskSize, _ := addr.Network.Mask.Size()
|
||||
return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize)
|
||||
|
8
iface/freebsd/errors.go
Normal file
8
iface/freebsd/errors.go
Normal file
@ -0,0 +1,8 @@
|
||||
package freebsd
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrDoesNotExist = errors.New("does not exist")
|
||||
ErrNameDoesNotMatch = errors.New("name does not match")
|
||||
)
|
108
iface/freebsd/iface.go
Normal file
108
iface/freebsd/iface.go
Normal file
@ -0,0 +1,108 @@
|
||||
package freebsd
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type iface struct {
|
||||
Name string
|
||||
MTU int
|
||||
Group string
|
||||
IPAddrs []string
|
||||
}
|
||||
|
||||
func parseError(output []byte) error {
|
||||
// TODO: implement without allocations
|
||||
lines := string(output)
|
||||
|
||||
if strings.Contains(lines, "does not exist") {
|
||||
return ErrDoesNotExist
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseIfconfigOutput(output []byte) (*iface, error) {
|
||||
// TODO: implement without allocations
|
||||
lines := string(output)
|
||||
|
||||
scanner := bufio.NewScanner(strings.NewReader(lines))
|
||||
|
||||
var name, mtu, group string
|
||||
var ips []string
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
|
||||
// If line contains ": flags", it's a line with interface information
|
||||
if strings.Contains(line, ": flags") {
|
||||
parts := strings.Fields(line)
|
||||
if len(parts) < 4 {
|
||||
return nil, fmt.Errorf("failed to parse line: %s", line)
|
||||
}
|
||||
name = strings.TrimSuffix(parts[0], ":")
|
||||
if strings.Contains(line, "mtu") {
|
||||
mtuIndex := 0
|
||||
for i, part := range parts {
|
||||
if part == "mtu" {
|
||||
mtuIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
mtu = parts[mtuIndex+1]
|
||||
}
|
||||
}
|
||||
|
||||
// If line contains "groups:", it's a line with interface group
|
||||
if strings.Contains(line, "groups:") {
|
||||
parts := strings.Fields(line)
|
||||
if len(parts) < 2 {
|
||||
return nil, fmt.Errorf("failed to parse line: %s", line)
|
||||
}
|
||||
group = parts[1]
|
||||
}
|
||||
|
||||
// If line contains "inet ", it's a line with IP address
|
||||
if strings.Contains(line, "inet ") {
|
||||
parts := strings.Fields(line)
|
||||
if len(parts) < 2 {
|
||||
return nil, fmt.Errorf("failed to parse line: %s", line)
|
||||
}
|
||||
ips = append(ips, parts[1])
|
||||
}
|
||||
}
|
||||
|
||||
if name == "" {
|
||||
return nil, fmt.Errorf("interface name not found in ifconfig output")
|
||||
}
|
||||
|
||||
mtuInt, err := strconv.Atoi(mtu)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse MTU: %w", err)
|
||||
}
|
||||
|
||||
return &iface{
|
||||
Name: name,
|
||||
MTU: mtuInt,
|
||||
Group: group,
|
||||
IPAddrs: ips,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parseIFName(output []byte) (string, error) {
|
||||
// TODO: implement without allocations
|
||||
lines := strings.Split(string(output), "\n")
|
||||
if len(lines) == 0 || lines[0] == "" {
|
||||
return "", fmt.Errorf("no output returned")
|
||||
}
|
||||
|
||||
fields := strings.Fields(lines[0])
|
||||
if len(fields) > 1 {
|
||||
return "", fmt.Errorf("invalid output")
|
||||
}
|
||||
|
||||
return fields[0], nil
|
||||
}
|
76
iface/freebsd/iface_internal_test.go
Normal file
76
iface/freebsd/iface_internal_test.go
Normal file
@ -0,0 +1,76 @@
|
||||
package freebsd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestParseIfconfigOutput(t *testing.T) {
|
||||
testOutput := `wg1: flags=8080<NOARP,MULTICAST> metric 0 mtu 1420
|
||||
options=80000<LINKSTATE>
|
||||
groups: wg
|
||||
nd6 options=109<PERFORMNUD,IFDISABLED,NO_DAD>`
|
||||
|
||||
expected := &iface{
|
||||
Name: "wg1",
|
||||
MTU: 1420,
|
||||
Group: "wg",
|
||||
}
|
||||
|
||||
result, err := parseIfconfigOutput(([]byte)(testOutput))
|
||||
if err != nil {
|
||||
t.Errorf("Error parsing ifconfig output: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, expected.Name, result.Name, "Name should match")
|
||||
assert.Equal(t, expected.MTU, result.MTU, "MTU should match")
|
||||
assert.Equal(t, expected.Group, result.Group, "Group should match")
|
||||
}
|
||||
|
||||
func TestParseIFName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
output string
|
||||
expected string
|
||||
expectedErr error
|
||||
}{
|
||||
{
|
||||
name: "ValidOutput",
|
||||
output: "eth0\n",
|
||||
expected: "eth0",
|
||||
},
|
||||
{
|
||||
name: "ValidOutputOneLine",
|
||||
output: "eth0",
|
||||
expected: "eth0",
|
||||
},
|
||||
{
|
||||
name: "EmptyOutput",
|
||||
output: "",
|
||||
expectedErr: fmt.Errorf("no output returned"),
|
||||
},
|
||||
{
|
||||
name: "InvalidOutput",
|
||||
output: "This is an invalid output\n",
|
||||
expectedErr: fmt.Errorf("invalid output"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
result, err := parseIFName(([]byte)(test.output))
|
||||
|
||||
assert.Equal(t, test.expected, result, "Interface names should match")
|
||||
|
||||
if test.expectedErr != nil {
|
||||
assert.NotNil(t, err, "Error should not be nil")
|
||||
assert.EqualError(t, err, test.expectedErr.Error(), "Error messages should match")
|
||||
} else {
|
||||
assert.Nil(t, err, "Error should be nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
239
iface/freebsd/link.go
Normal file
239
iface/freebsd/link.go
Normal file
@ -0,0 +1,239 @@
|
||||
package freebsd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const wgIFGroup = "wg"
|
||||
|
||||
// Link represents a network interface.
|
||||
type Link struct {
|
||||
name string
|
||||
}
|
||||
|
||||
func NewLink(name string) *Link {
|
||||
return &Link{
|
||||
name: name,
|
||||
}
|
||||
}
|
||||
|
||||
// LinkByName retrieves a network interface by its name.
|
||||
func LinkByName(name string) (*Link, error) {
|
||||
out, err := exec.Command("ifconfig", name).CombinedOutput()
|
||||
if err != nil {
|
||||
if pErr := parseError(out); pErr != nil {
|
||||
return nil, pErr
|
||||
}
|
||||
|
||||
log.Debugf("ifconfig out: %s", out)
|
||||
|
||||
return nil, fmt.Errorf("command run: %w", err)
|
||||
}
|
||||
|
||||
i, err := parseIfconfigOutput(out)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse ifconfig output: %w", err)
|
||||
}
|
||||
|
||||
if i.Name != name {
|
||||
return nil, ErrNameDoesNotMatch
|
||||
}
|
||||
|
||||
return &Link{name: i.Name}, nil
|
||||
}
|
||||
|
||||
// Recreate - create new interface, remove current before create if it exists
|
||||
func (l *Link) Recreate() error {
|
||||
ok, err := l.isExist()
|
||||
if err != nil {
|
||||
return fmt.Errorf("is exist: %w", err)
|
||||
}
|
||||
|
||||
if ok {
|
||||
if err := l.del(l.name); err != nil {
|
||||
return fmt.Errorf("del: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return l.Add()
|
||||
}
|
||||
|
||||
// Add creates a new network interface.
|
||||
func (l *Link) Add() error {
|
||||
parsedName, err := l.create(wgIFGroup)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create link: %w", err)
|
||||
}
|
||||
|
||||
if parsedName == l.name {
|
||||
return nil
|
||||
}
|
||||
|
||||
parsedName, err = l.rename(parsedName, l.name)
|
||||
if err != nil {
|
||||
errDel := l.del(parsedName)
|
||||
if errDel != nil {
|
||||
return fmt.Errorf("del on rename link: %w: %w", err, errDel)
|
||||
}
|
||||
|
||||
return fmt.Errorf("rename link: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Del removes an existing network interface.
|
||||
func (l *Link) Del() error {
|
||||
return l.del(l.name)
|
||||
}
|
||||
|
||||
// SetMTU sets the MTU of the network interface.
|
||||
func (l *Link) SetMTU(mtu int) error {
|
||||
return l.setMTU(mtu)
|
||||
}
|
||||
|
||||
// AssignAddr assigns an IP address and netmask to the network interface.
|
||||
func (l *Link) AssignAddr(ip, netmask string) error {
|
||||
return l.setAddr(ip, netmask)
|
||||
}
|
||||
|
||||
func (l *Link) Up() error {
|
||||
return l.up(l.name)
|
||||
}
|
||||
|
||||
func (l *Link) Down() error {
|
||||
return l.down(l.name)
|
||||
}
|
||||
|
||||
func (l *Link) isExist() (bool, error) {
|
||||
_, err := LinkByName(l.name)
|
||||
if errors.Is(err, ErrDoesNotExist) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("link by name: %w", err)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (l *Link) create(groupName string) (string, error) {
|
||||
cmd := exec.Command("ifconfig", groupName, "create")
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
log.Debugf("ifconfig out: %s", output)
|
||||
|
||||
return "", fmt.Errorf("create %s interface: %w", groupName, err)
|
||||
}
|
||||
|
||||
interfaceName, err := parseIFName(output)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parse interface name: %w", err)
|
||||
}
|
||||
|
||||
return interfaceName, nil
|
||||
}
|
||||
|
||||
func (l *Link) rename(oldName, newName string) (string, error) {
|
||||
cmd := exec.Command("ifconfig", oldName, "name", newName)
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
log.Debugf("ifconfig out: %s", output)
|
||||
|
||||
return "", fmt.Errorf("change name %q -> %q: %w", oldName, newName, err)
|
||||
}
|
||||
|
||||
interfaceName, err := parseIFName(output)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parse new name: %w", err)
|
||||
}
|
||||
|
||||
return interfaceName, nil
|
||||
}
|
||||
|
||||
func (l *Link) del(name string) error {
|
||||
var stderr bytes.Buffer
|
||||
|
||||
cmd := exec.Command("ifconfig", name, "destroy")
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
log.Debugf("ifconfig out: %s", stderr.String())
|
||||
|
||||
return fmt.Errorf("destroy %s interface: %w", name, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *Link) setMTU(mtu int) error {
|
||||
var stderr bytes.Buffer
|
||||
|
||||
cmd := exec.Command("ifconfig", l.name, "mtu", strconv.Itoa(mtu))
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
log.Debugf("ifconfig out: %s", stderr.String())
|
||||
|
||||
return fmt.Errorf("set interface mtu: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *Link) setAddr(ip, netmask string) error {
|
||||
var stderr bytes.Buffer
|
||||
|
||||
cmd := exec.Command("ifconfig", l.name, "inet", ip, "netmask", netmask)
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
log.Debugf("ifconfig out: %s", stderr.String())
|
||||
|
||||
return fmt.Errorf("set interface addr: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *Link) up(name string) error {
|
||||
var stderr bytes.Buffer
|
||||
|
||||
cmd := exec.Command("ifconfig", name, "up")
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
log.Debugf("ifconfig out: %s", stderr.String())
|
||||
|
||||
return fmt.Errorf("up %s interface: %w", name, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *Link) down(name string) error {
|
||||
var stderr bytes.Buffer
|
||||
|
||||
cmd := exec.Command("ifconfig", name, "down")
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
log.Debugf("ifconfig out: %s", stderr.String())
|
||||
|
||||
return fmt.Errorf("down %s interface: %w", name, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@ -48,6 +48,19 @@ func (w *WGIface) Address() WGAddress {
|
||||
return w.tun.WgAddress()
|
||||
}
|
||||
|
||||
// ToInterface returns the net.Interface for the Wireguard interface
|
||||
func (r *WGIface) ToInterface() *net.Interface {
|
||||
name := r.tun.DeviceName()
|
||||
intf, err := net.InterfaceByName(name)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to get interface by name %s: %v", name, err)
|
||||
intf = &net.Interface{
|
||||
Name: name,
|
||||
}
|
||||
}
|
||||
return intf
|
||||
}
|
||||
|
||||
// Up configures a Wireguard interface
|
||||
// The interface must exist before calling this method (e.g. call interface.Create() before)
|
||||
func (w *WGIface) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
@ -94,7 +107,7 @@ func (w *WGIface) AddAllowedIP(peerKey string, allowedIP string) error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
log.Debugf("adding allowed IP to interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
|
||||
log.Debugf("Adding allowed IP to interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
|
||||
return w.configurer.addAllowedIP(peerKey, allowedIP)
|
||||
}
|
||||
|
||||
@ -103,7 +116,7 @@ func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP string) error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
log.Debugf("removing allowed IP from interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
|
||||
log.Debugf("Removing allowed IP from interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
|
||||
return w.configurer.removeAllowedIP(peerKey, allowedIP)
|
||||
}
|
||||
|
||||
|
@ -1,5 +1,4 @@
|
||||
//go:build !android
|
||||
// +build !android
|
||||
|
||||
package iface
|
||||
|
||||
|
@ -1,5 +1,4 @@
|
||||
//go:build !ios
|
||||
// +build !ios
|
||||
|
||||
package iface
|
||||
|
||||
|
@ -1,5 +1,4 @@
|
||||
//go:build ios
|
||||
// +build ios
|
||||
|
||||
package iface
|
||||
|
||||
|
@ -1,10 +1,10 @@
|
||||
//go:build !android
|
||||
// +build !android
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package iface
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
|
||||
"github.com/pion/transport/v3"
|
||||
|
||||
@ -43,5 +43,5 @@ func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string,
|
||||
|
||||
// CreateOnAndroid this function make sense on mobile only
|
||||
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
|
||||
return fmt.Errorf("this function has not implemented on this platform")
|
||||
return fmt.Errorf("CreateOnAndroid function has not implemented on %s platform", runtime.GOOS)
|
||||
}
|
@ -1,5 +1,4 @@
|
||||
//go:build !linux || android
|
||||
// +build !linux android
|
||||
//go:build (!linux && !freebsd) || android
|
||||
|
||||
package iface
|
||||
|
||||
|
18
iface/module_freebsd.go
Normal file
18
iface/module_freebsd.go
Normal file
@ -0,0 +1,18 @@
|
||||
package iface
|
||||
|
||||
// WireGuardModuleIsLoaded check if kernel support wireguard
|
||||
func WireGuardModuleIsLoaded() bool {
|
||||
// Despite the fact FreeBSD natively support Wireguard (https://github.com/WireGuard/wireguard-freebsd)
|
||||
// we are currently do not use it, since it is required to add wireguard kernel support to
|
||||
// - https://github.com/netbirdio/netbird/tree/main/sharedsock
|
||||
// - https://github.com/mdlayher/socket
|
||||
// TODO: implement kernel space
|
||||
return false
|
||||
}
|
||||
|
||||
// tunModuleIsLoaded check if tun module exist, if is not attempt to load it
|
||||
func tunModuleIsLoaded() bool {
|
||||
// Assume tun supported by freebsd kernel by default
|
||||
// TODO: implement check for module loaded in kernel or build-it
|
||||
return true
|
||||
}
|
@ -1,5 +1,4 @@
|
||||
//go:build linux || windows
|
||||
// +build linux windows
|
||||
//go:build linux || windows || freebsd
|
||||
|
||||
package iface
|
||||
|
||||
|
@ -1,5 +1,4 @@
|
||||
//go:build darwin
|
||||
// +build darwin
|
||||
|
||||
package iface
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
//go:build linux && !android
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package iface
|
||||
|
||||
@ -6,11 +6,9 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"github.com/pion/transport/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/vishvananda/netlink"
|
||||
|
||||
"github.com/netbirdio/netbird/iface/bind"
|
||||
"github.com/netbirdio/netbird/sharedsock"
|
||||
@ -32,6 +30,8 @@ type tunKernelDevice struct {
|
||||
}
|
||||
|
||||
func newTunDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) wgTunDevice {
|
||||
checkUser()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &tunKernelDevice{
|
||||
ctx: ctx,
|
||||
@ -48,53 +48,29 @@ func newTunDevice(name string, address WGAddress, wgPort int, key string, mtu in
|
||||
func (t *tunKernelDevice) Create() (wgConfigurer, error) {
|
||||
link := newWGLink(t.name)
|
||||
|
||||
// check if interface exists
|
||||
l, err := netlink.LinkByName(t.name)
|
||||
if err != nil {
|
||||
switch err.(type) {
|
||||
case netlink.LinkNotFoundError:
|
||||
break
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// remove if interface exists
|
||||
if l != nil {
|
||||
err = netlink.LinkDel(link)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("adding device: %s", t.name)
|
||||
err = netlink.LinkAdd(link)
|
||||
if os.IsExist(err) {
|
||||
log.Infof("interface %s already exists. Will reuse.", t.name)
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
if err := link.recreate(); err != nil {
|
||||
return nil, fmt.Errorf("recreate: %w", err)
|
||||
}
|
||||
|
||||
t.link = link
|
||||
|
||||
err = t.assignAddr()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if err := t.assignAddr(); err != nil {
|
||||
return nil, fmt.Errorf("assign addr: %w", err)
|
||||
}
|
||||
|
||||
// todo do a discovery
|
||||
// TODO: do a MTU discovery
|
||||
log.Debugf("setting MTU: %d interface: %s", t.mtu, t.name)
|
||||
err = netlink.LinkSetMTU(link, t.mtu)
|
||||
if err != nil {
|
||||
log.Errorf("error setting MTU on interface: %s", t.name)
|
||||
return nil, err
|
||||
|
||||
if err := link.setMTU(t.mtu); err != nil {
|
||||
return nil, fmt.Errorf("set mtu: %w", err)
|
||||
}
|
||||
|
||||
configurer := newWGConfigurer(t.name)
|
||||
err = configurer.configureInterface(t.key, t.wgPort)
|
||||
if err != nil {
|
||||
|
||||
if err := configurer.configureInterface(t.key, t.wgPort); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return configurer, nil
|
||||
}
|
||||
|
||||
@ -108,9 +84,10 @@ func (t *tunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
||||
}
|
||||
|
||||
log.Debugf("bringing up interface: %s", t.name)
|
||||
err := netlink.LinkSetUp(t.link)
|
||||
if err != nil {
|
||||
|
||||
if err := t.link.up(); err != nil {
|
||||
log.Errorf("error bringing up interface: %s", t.name)
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -178,32 +155,5 @@ func (t *tunKernelDevice) Wrapper() *DeviceWrapper {
|
||||
|
||||
// assignAddr Adds IP address to the tunnel interface
|
||||
func (t *tunKernelDevice) assignAddr() error {
|
||||
link := newWGLink(t.name)
|
||||
|
||||
//delete existing addresses
|
||||
list, err := netlink.AddrList(link, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(list) > 0 {
|
||||
for _, a := range list {
|
||||
addr := a
|
||||
err = netlink.AddrDel(link, &addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("adding address %s to interface: %s", t.address.String(), t.name)
|
||||
addr, _ := netlink.ParseAddr(t.address.String())
|
||||
err = netlink.AddrAdd(link, addr)
|
||||
if os.IsExist(err) {
|
||||
log.Infof("interface %s already has the address: %s", t.name, t.address.String())
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
// On linux, the link must be brought up
|
||||
err = netlink.LinkSetUp(link)
|
||||
return err
|
||||
return t.link.assignAddr(t.address)
|
||||
}
|
80
iface/tun_link_freebsd.go
Normal file
80
iface/tun_link_freebsd.go
Normal file
@ -0,0 +1,80 @@
|
||||
package iface
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/netbirdio/netbird/iface/freebsd"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type wgLink struct {
|
||||
name string
|
||||
link *freebsd.Link
|
||||
}
|
||||
|
||||
func newWGLink(name string) *wgLink {
|
||||
link := freebsd.NewLink(name)
|
||||
|
||||
return &wgLink{
|
||||
name: name,
|
||||
link: link,
|
||||
}
|
||||
}
|
||||
|
||||
// Type returns the interface type
|
||||
func (l *wgLink) Type() string {
|
||||
return "wireguard"
|
||||
}
|
||||
|
||||
// Close deletes the link interface
|
||||
func (l *wgLink) Close() error {
|
||||
return l.link.Del()
|
||||
}
|
||||
|
||||
func (l *wgLink) recreate() error {
|
||||
if err := l.link.Recreate(); err != nil {
|
||||
return fmt.Errorf("recreate: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *wgLink) setMTU(mtu int) error {
|
||||
if err := l.link.SetMTU(mtu); err != nil {
|
||||
return fmt.Errorf("set mtu: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *wgLink) up() error {
|
||||
if err := l.link.Up(); err != nil {
|
||||
return fmt.Errorf("up: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *wgLink) assignAddr(address WGAddress) error {
|
||||
link, err := freebsd.LinkByName(l.name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("link by name: %w", err)
|
||||
}
|
||||
|
||||
ip := address.IP.String()
|
||||
mask := "0x" + address.Network.Mask.String()
|
||||
|
||||
log.Infof("assign addr %s mask %s to %s interface", ip, mask, l.name)
|
||||
|
||||
err = link.AssignAddr(ip, mask)
|
||||
if err != nil {
|
||||
return fmt.Errorf("assign addr: %w", err)
|
||||
}
|
||||
|
||||
err = link.Up()
|
||||
if err != nil {
|
||||
return fmt.Errorf("up: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@ -2,7 +2,13 @@
|
||||
|
||||
package iface
|
||||
|
||||
import "github.com/vishvananda/netlink"
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/vishvananda/netlink"
|
||||
)
|
||||
|
||||
type wgLink struct {
|
||||
attrs *netlink.LinkAttrs
|
||||
@ -31,3 +37,97 @@ func (l *wgLink) Type() string {
|
||||
func (l *wgLink) Close() error {
|
||||
return netlink.LinkDel(l)
|
||||
}
|
||||
|
||||
func (l *wgLink) recreate() error {
|
||||
name := l.attrs.Name
|
||||
|
||||
// check if interface exists
|
||||
link, err := netlink.LinkByName(name)
|
||||
if err != nil {
|
||||
switch err.(type) {
|
||||
case netlink.LinkNotFoundError:
|
||||
break
|
||||
default:
|
||||
return fmt.Errorf("link by name: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// remove if interface exists
|
||||
if link != nil {
|
||||
err = netlink.LinkDel(l)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("adding device: %s", name)
|
||||
err = netlink.LinkAdd(l)
|
||||
if os.IsExist(err) {
|
||||
log.Infof("interface %s already exists. Will reuse.", name)
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("link add: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *wgLink) setMTU(mtu int) error {
|
||||
if err := netlink.LinkSetMTU(l, mtu); err != nil {
|
||||
log.Errorf("error setting MTU on interface: %s", l.attrs.Name)
|
||||
|
||||
return fmt.Errorf("link set mtu: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *wgLink) up() error {
|
||||
if err := netlink.LinkSetUp(l); err != nil {
|
||||
log.Errorf("error bringing up interface: %s", l.attrs.Name)
|
||||
return fmt.Errorf("link setup: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *wgLink) assignAddr(address WGAddress) error {
|
||||
//delete existing addresses
|
||||
list, err := netlink.AddrList(l, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list addr: %w", err)
|
||||
}
|
||||
|
||||
if len(list) > 0 {
|
||||
for _, a := range list {
|
||||
addr := a
|
||||
err = netlink.AddrDel(l, &addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("del addr: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
name := l.attrs.Name
|
||||
addrStr := address.String()
|
||||
|
||||
log.Debugf("adding address %s to interface: %s", addrStr, name)
|
||||
|
||||
addr, err := netlink.ParseAddr(addrStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse addr: %w", err)
|
||||
}
|
||||
|
||||
err = netlink.AddrAdd(l, addr)
|
||||
if os.IsExist(err) {
|
||||
log.Infof("interface %s already has the address: %s", name, addrStr)
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("add addr: %w", err)
|
||||
}
|
||||
|
||||
// On linux, the link must be brought up
|
||||
if err := netlink.LinkSetUp(l); err != nil {
|
||||
return fmt.Errorf("link setup: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -1,14 +1,14 @@
|
||||
//go:build linux && !android
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package iface
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
|
||||
"github.com/pion/transport/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
|
||||
@ -31,6 +31,9 @@ type tunUSPDevice struct {
|
||||
|
||||
func newTunUSPDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net) wgTunDevice {
|
||||
log.Infof("using userspace bind mode")
|
||||
|
||||
checkUser()
|
||||
|
||||
return &tunUSPDevice{
|
||||
name: name,
|
||||
address: address,
|
||||
@ -129,30 +132,14 @@ func (t *tunUSPDevice) Wrapper() *DeviceWrapper {
|
||||
func (t *tunUSPDevice) assignAddr() error {
|
||||
link := newWGLink(t.name)
|
||||
|
||||
//delete existing addresses
|
||||
list, err := netlink.AddrList(link, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(list) > 0 {
|
||||
for _, a := range list {
|
||||
addr := a
|
||||
err = netlink.AddrDel(link, &addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return link.assignAddr(t.address)
|
||||
}
|
||||
|
||||
func checkUser() {
|
||||
if runtime.GOOS == "freebsd" {
|
||||
euid := os.Geteuid()
|
||||
if euid != 0 {
|
||||
log.Warn("newTunUSPDevice: on netbird must run as root to be able to assign address to the tun interface with ifconfig")
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("adding address %s to interface: %s", t.address.String(), t.name)
|
||||
addr, _ := netlink.ParseAddr(t.address.String())
|
||||
err = netlink.AddrAdd(link, addr)
|
||||
if os.IsExist(err) {
|
||||
log.Infof("interface %s already has the address: %s", t.name, t.address.String())
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
// On linux, the link must be brought up
|
||||
err = netlink.LinkSetUp(link)
|
||||
return err
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user