Merge branch 'main' into fix/delete-route-only-after-adding

This commit is contained in:
Maycon Santos 2024-04-12 18:45:02 +02:00
commit bb08f3677f
2 changed files with 115 additions and 18 deletions

View File

@ -9,6 +9,8 @@ import (
"net" "net"
"net/netip" "net/netip"
"os" "os"
"strconv"
"strings"
"syscall" "syscall"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
@ -30,14 +32,26 @@ const (
rtTablesPath = "/etc/iproute2/rt_tables" rtTablesPath = "/etc/iproute2/rt_tables"
// ipv4ForwardingPath is the path to the file containing the IP forwarding setting. // ipv4ForwardingPath is the path to the file containing the IP forwarding setting.
ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward" ipv4ForwardingPath = "net.ipv4.ip_forward"
rpFilterPath = "net.ipv4.conf.all.rp_filter"
rpFilterInterfacePath = "net.ipv4.conf.%s.rp_filter"
srcValidMarkPath = "net.ipv4.conf.all.src_valid_mark"
) )
var ErrTableIDExists = errors.New("ID exists with different name") var ErrTableIDExists = errors.New("ID exists with different name")
var routeManager = &RouteManager{} var routeManager = &RouteManager{}
// originalSysctl stores the original sysctl values before they are modified
var originalSysctl map[string]int
// determines whether to use the legacy routing setup
var isLegacy = os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() var isLegacy = os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled()
// sysctlFailed is used as an indicator to emit a warning when default routes are configured
var sysctlFailed bool
type ruleParams struct { type ruleParams struct {
priority int priority int
fwmark int fwmark int
@ -77,6 +91,13 @@ func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.Before
log.Errorf("Error adding routing table name: %v", err) log.Errorf("Error adding routing table name: %v", err)
} }
originalValues, err := setupSysctl(wgIface)
if err != nil {
log.Errorf("Error setting up sysctl: %v", err)
sysctlFailed = true
}
originalSysctl = originalValues
defer func() { defer func() {
if err != nil { if err != nil {
if cleanErr := cleanupRouting(); cleanErr != nil { if cleanErr := cleanupRouting(); cleanErr != nil {
@ -124,6 +145,12 @@ func cleanupRouting() error {
} }
} }
if err := cleanupSysctl(originalSysctl); err != nil {
result = multierror.Append(result, fmt.Errorf("cleanup sysctl: %w", err))
}
originalSysctl = nil
sysctlFailed = false
return result.ErrorOrNil() return result.ErrorOrNil()
} }
@ -140,6 +167,10 @@ func addVPNRoute(prefix netip.Prefix, intf string) error {
return genericAddVPNRoute(prefix, intf) return genericAddVPNRoute(prefix, intf)
} }
if sysctlFailed && (prefix == defaultv4 || prefix == defaultv6) {
log.Warnf("Default route is configured but sysctl operations failed, VPN traffic may not be routed correctly, consider using NB_USE_LEGACY_ROUTING=true or setting net.ipv4.conf.*.rp_filter to 2 (loose) or 0 (off)")
}
// No need to check if routes exist as main table takes precedence over the VPN table via Rule 1 // No need to check if routes exist as main table takes precedence over the VPN table via Rule 1
// TODO remove this once we have ipv6 support // TODO remove this once we have ipv6 support
@ -332,22 +363,8 @@ func flushRoutes(tableID, family int) error {
} }
func enableIPForwarding() error { func enableIPForwarding() error {
bytes, err := os.ReadFile(ipv4ForwardingPath) _, err := setSysctl(ipv4ForwardingPath, 1, false)
if err != nil { return err
return fmt.Errorf("read file %s: %w", ipv4ForwardingPath, err)
}
// check if it is already enabled
// see more: https://github.com/netbirdio/netbird/issues/872
if len(bytes) > 0 && bytes[0] == 49 {
return nil
}
//nolint:gosec
if err := os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644); err != nil {
return fmt.Errorf("write file %s: %w", ipv4ForwardingPath, err)
}
return nil
} }
// entryExists checks if the specified ID or name already exists in the rt_tables file // entryExists checks if the specified ID or name already exists in the rt_tables file
@ -475,3 +492,83 @@ func getAddressFamily(prefix netip.Prefix) int {
} }
return netlink.FAMILY_V6 return netlink.FAMILY_V6
} }
// setupSysctl configures sysctl settings for RP filtering and source validation.
func setupSysctl(wgIface *iface.WGIface) (map[string]int, error) {
keys := map[string]int{}
var result *multierror.Error
oldVal, err := setSysctl(srcValidMarkPath, 1, false)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[srcValidMarkPath] = oldVal
}
oldVal, err = setSysctl(rpFilterPath, 2, true)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[rpFilterPath] = oldVal
}
interfaces, err := net.Interfaces()
if err != nil {
result = multierror.Append(result, fmt.Errorf("list interfaces: %w", err))
}
for _, intf := range interfaces {
if intf.Name == "lo" || wgIface != nil && intf.Name == wgIface.Name() {
continue
}
i := fmt.Sprintf(rpFilterInterfacePath, intf.Name)
oldVal, err := setSysctl(i, 2, true)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[i] = oldVal
}
}
return keys, result.ErrorOrNil()
}
// setSysctl sets a sysctl configuration, if onlyIfOne is true it will only set the new value if it's set to 1
func setSysctl(key string, desiredValue int, onlyIfOne bool) (int, error) {
path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/"))
currentValue, err := os.ReadFile(path)
if err != nil {
return -1, fmt.Errorf("read sysctl %s: %w", key, err)
}
currentV, err := strconv.Atoi(strings.TrimSpace(string(currentValue)))
if err != nil && len(currentValue) > 0 {
return -1, fmt.Errorf("convert current desiredValue to int: %w", err)
}
if currentV == desiredValue || onlyIfOne && currentV != 1 {
return currentV, nil
}
//nolint:gosec
if err := os.WriteFile(path, []byte(strconv.Itoa(desiredValue)), 0644); err != nil {
return currentV, fmt.Errorf("write sysctl %s: %w", key, err)
}
log.Debugf("Set sysctl %s from %d to %d", key, currentV, desiredValue)
return currentV, nil
}
func cleanupSysctl(originalSettings map[string]int) error {
var result *multierror.Error
for key, value := range originalSettings {
_, err := setSysctl(key, value, false)
if err != nil {
result = multierror.Append(result, err)
}
}
return result.ErrorOrNil()
}

View File

@ -61,7 +61,7 @@ func TestAddRemoveRoutes(t *testing.T) {
err = wgInterface.Create() err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface") require.NoError(t, err, "should create testing wireguard interface")
_, _, err = setupRouting(nil, nil) _, _, err = setupRouting(nil, wgInterface)
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { t.Cleanup(func() {
assert.NoError(t, cleanupRouting()) assert.NoError(t, cleanupRouting())