mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-19 17:31:39 +02:00
[client] Use platform-native routing APIs for freeBSD, macOS and Windows
This commit is contained in:
parent
87148c503f
commit
ea4d13e96d
@ -17,7 +17,7 @@ var traceCmd = &cobra.Command{
|
|||||||
Example: `
|
Example: `
|
||||||
netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack
|
netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack
|
||||||
netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53
|
netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53
|
||||||
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --type 8 --code 0
|
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --icmp-type 8 --icmp-code 0
|
||||||
netbird debug trace in 100.64.1.1 self -p tcp --dport 80`,
|
netbird debug trace in 100.64.1.1 self -p tcp --dport 80`,
|
||||||
Args: cobra.ExactArgs(3),
|
Args: cobra.ExactArgs(3),
|
||||||
RunE: tracePacket,
|
RunE: tracePacket,
|
||||||
|
@ -2,7 +2,7 @@ package iptables
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -19,11 +19,8 @@ var ifaceMock = &iFaceMock{
|
|||||||
},
|
},
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("10.20.0.1"),
|
IP: netip.MustParseAddr("10.20.0.1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("10.20.0.0/24"),
|
||||||
IP: net.ParseIP("10.20.0.0"),
|
|
||||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -70,12 +67,12 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
|
|
||||||
var rule2 []fw.Rule
|
var rule2 []fw.Rule
|
||||||
t.Run("add second rule", func(t *testing.T) {
|
t.Run("add second rule", func(t *testing.T) {
|
||||||
ip := net.ParseIP("10.20.0.3")
|
ip := netip.MustParseAddr("10.20.0.3")
|
||||||
port := &fw.Port{
|
port := &fw.Port{
|
||||||
IsRange: true,
|
IsRange: true,
|
||||||
Values: []uint16{8043, 8046},
|
Values: []uint16{8043, 8046},
|
||||||
}
|
}
|
||||||
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "")
|
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
@ -95,9 +92,9 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("reset check", func(t *testing.T) {
|
t.Run("reset check", func(t *testing.T) {
|
||||||
// add second rule
|
// add second rule
|
||||||
ip := net.ParseIP("10.20.0.3")
|
ip := netip.MustParseAddr("10.20.0.3")
|
||||||
port := &fw.Port{Values: []uint16{5353}}
|
port := &fw.Port{Values: []uint16{5353}}
|
||||||
_, err = manager.AddPeerFiltering(nil, ip, "udp", nil, port, fw.ActionAccept, "")
|
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "udp", nil, port, fw.ActionAccept, "")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
err = manager.Close(nil)
|
err = manager.Close(nil)
|
||||||
@ -119,11 +116,8 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
|||||||
},
|
},
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("10.20.0.1"),
|
IP: netip.MustParseAddr("10.20.0.1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("10.20.0.0/24"),
|
||||||
IP: net.ParseIP("10.20.0.0"),
|
|
||||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -144,11 +138,11 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
|||||||
|
|
||||||
var rule2 []fw.Rule
|
var rule2 []fw.Rule
|
||||||
t.Run("add second rule", func(t *testing.T) {
|
t.Run("add second rule", func(t *testing.T) {
|
||||||
ip := net.ParseIP("10.20.0.3")
|
ip := netip.MustParseAddr("10.20.0.3")
|
||||||
port := &fw.Port{
|
port := &fw.Port{
|
||||||
Values: []uint16{443},
|
Values: []uint16{443},
|
||||||
}
|
}
|
||||||
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "default")
|
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "default")
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
||||||
@ -186,11 +180,8 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
|||||||
},
|
},
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("10.20.0.1"),
|
IP: netip.MustParseAddr("10.20.0.1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("10.20.0.0/24"),
|
||||||
IP: net.ParseIP("10.20.0.0"),
|
|
||||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -212,11 +203,11 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
|||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
ip := net.ParseIP("10.20.0.100")
|
ip := netip.MustParseAddr("10.20.0.100")
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||||
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
|
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
|
||||||
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
}
|
}
|
||||||
|
@ -3,7 +3,6 @@ package nftables
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"testing"
|
"testing"
|
||||||
@ -25,11 +24,8 @@ var ifaceMock = &iFaceMock{
|
|||||||
},
|
},
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("100.96.0.1"),
|
IP: netip.MustParseAddr("100.96.0.1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("100.96.0.0/16"),
|
||||||
IP: net.ParseIP("100.96.0.0"),
|
|
||||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -70,11 +66,11 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
ip := net.ParseIP("100.96.0.1")
|
ip := netip.MustParseAddr("100.96.0.1").Unmap()
|
||||||
|
|
||||||
testClient := &nftables.Conn{}
|
testClient := &nftables.Conn{}
|
||||||
|
|
||||||
rule, err := manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
|
rule, err := manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
err = manager.Flush()
|
err = manager.Flush()
|
||||||
@ -109,8 +105,6 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
}
|
}
|
||||||
compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1)
|
compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1)
|
||||||
|
|
||||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
|
||||||
add := ipToAdd.Unmap()
|
|
||||||
expectedExprs2 := []expr.Any{
|
expectedExprs2 := []expr.Any{
|
||||||
&expr.Payload{
|
&expr.Payload{
|
||||||
DestRegister: 1,
|
DestRegister: 1,
|
||||||
@ -132,7 +126,7 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
&expr.Cmp{
|
&expr.Cmp{
|
||||||
Op: expr.CmpOpEq,
|
Op: expr.CmpOpEq,
|
||||||
Register: 1,
|
Register: 1,
|
||||||
Data: add.AsSlice(),
|
Data: ip.AsSlice(),
|
||||||
},
|
},
|
||||||
&expr.Payload{
|
&expr.Payload{
|
||||||
DestRegister: 1,
|
DestRegister: 1,
|
||||||
@ -173,11 +167,8 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
|||||||
},
|
},
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("100.96.0.1"),
|
IP: netip.MustParseAddr("100.96.0.1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("100.96.0.0/16"),
|
||||||
IP: net.ParseIP("100.96.0.0"),
|
|
||||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -197,11 +188,11 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
|||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
ip := net.ParseIP("10.20.0.100")
|
ip := netip.MustParseAddr("10.20.0.100")
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||||
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
|
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
if i%100 == 0 {
|
if i%100 == 0 {
|
||||||
@ -282,8 +273,8 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
|||||||
verifyIptablesOutput(t, stdout, stderr)
|
verifyIptablesOutput(t, stdout, stderr)
|
||||||
})
|
})
|
||||||
|
|
||||||
ip := net.ParseIP("100.96.0.1")
|
ip := netip.MustParseAddr("100.96.0.1")
|
||||||
_, err = manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||||
require.NoError(t, err, "failed to add peer filtering rule")
|
require.NoError(t, err, "failed to add peer filtering rule")
|
||||||
|
|
||||||
_, err = manager.AddRouteFiltering(
|
_, err = manager.AddRouteFiltering(
|
||||||
|
@ -41,7 +41,7 @@ type Forwarder struct {
|
|||||||
udpForwarder *udpForwarder
|
udpForwarder *udpForwarder
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
ip net.IP
|
ip tcpip.Address
|
||||||
netstack bool
|
netstack bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -71,12 +71,11 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
|||||||
return nil, fmt.Errorf("failed to create NIC: %v", err)
|
return nil, fmt.Errorf("failed to create NIC: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ones, _ := iface.Address().Network.Mask.Size()
|
|
||||||
protoAddr := tcpip.ProtocolAddress{
|
protoAddr := tcpip.ProtocolAddress{
|
||||||
Protocol: ipv4.ProtocolNumber,
|
Protocol: ipv4.ProtocolNumber,
|
||||||
AddressWithPrefix: tcpip.AddressWithPrefix{
|
AddressWithPrefix: tcpip.AddressWithPrefix{
|
||||||
Address: tcpip.AddrFromSlice(iface.Address().IP.To4()),
|
Address: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
|
||||||
PrefixLen: ones,
|
PrefixLen: iface.Address().Network.Bits(),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -116,7 +115,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
|||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
netstack: netstack,
|
netstack: netstack,
|
||||||
ip: iface.Address().IP,
|
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
|
||||||
}
|
}
|
||||||
|
|
||||||
receiveWindow := defaultReceiveWindow
|
receiveWindow := defaultReceiveWindow
|
||||||
@ -167,7 +166,7 @@ func (f *Forwarder) Stop() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
|
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
|
||||||
if f.netstack && f.ip.Equal(addr.AsSlice()) {
|
if f.netstack && f.ip.Equal(addr) {
|
||||||
return net.IPv4(127, 0, 0, 1)
|
return net.IPv4(127, 0, 0, 1)
|
||||||
}
|
}
|
||||||
return addr.AsSlice()
|
return addr.AsSlice()
|
||||||
@ -179,7 +178,6 @@ func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uin
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (f *Forwarder) getRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) ([]byte, bool) {
|
func (f *Forwarder) getRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) ([]byte, bool) {
|
||||||
|
|
||||||
if value, ok := f.ruleIdMap.Load(buildKey(srcIP, dstIP, srcPort, dstPort)); ok {
|
if value, ok := f.ruleIdMap.Load(buildKey(srcIP, dstIP, srcPort, dstPort)); ok {
|
||||||
return value.([]byte), true
|
return value.([]byte), true
|
||||||
} else if value, ok := f.ruleIdMap.Load(buildKey(dstIP, srcIP, dstPort, srcPort)); ok {
|
} else if value, ok := f.ruleIdMap.Load(buildKey(dstIP, srcIP, dstPort, srcPort)); ok {
|
||||||
|
@ -45,24 +45,26 @@ func (m *localIPManager) setBitmapBit(ip net.IP) {
|
|||||||
m.ipv4Bitmap[high].bitmap[index] |= 1 << bit
|
m.ipv4Bitmap[high].bitmap[index] |= 1 << bit
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *localIPManager) setBitInBitmap(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
|
func (m *localIPManager) setBitInBitmap(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
|
||||||
if ipv4 := ip.To4(); ipv4 != nil {
|
if !ip.Is4() {
|
||||||
high := uint16(ipv4[0])
|
return
|
||||||
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
|
}
|
||||||
|
ipv4 := ip.AsSlice()
|
||||||
|
|
||||||
if bitmap[high] == nil {
|
high := uint16(ipv4[0])
|
||||||
bitmap[high] = &ipv4LowBitmap{}
|
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
|
||||||
}
|
|
||||||
|
|
||||||
index := low / 32
|
if bitmap[high] == nil {
|
||||||
bit := low % 32
|
bitmap[high] = &ipv4LowBitmap{}
|
||||||
bitmap[high].bitmap[index] |= 1 << bit
|
}
|
||||||
|
|
||||||
ipStr := ipv4.String()
|
index := low / 32
|
||||||
if _, exists := ipv4Set[ipStr]; !exists {
|
bit := low % 32
|
||||||
ipv4Set[ipStr] = struct{}{}
|
bitmap[high].bitmap[index] |= 1 << bit
|
||||||
*ipv4Addresses = append(*ipv4Addresses, ipStr)
|
|
||||||
}
|
if _, exists := ipv4Set[ip]; !exists {
|
||||||
|
ipv4Set[ip] = struct{}{}
|
||||||
|
*ipv4Addresses = append(*ipv4Addresses, ip)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -79,12 +81,12 @@ func (m *localIPManager) checkBitmapBit(ip []byte) bool {
|
|||||||
return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0
|
return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *localIPManager) processIP(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
|
func (m *localIPManager) processIP(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) error {
|
||||||
m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses)
|
m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
|
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
|
||||||
addrs, err := iface.Addrs()
|
addrs, err := iface.Addrs()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
|
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
|
||||||
@ -102,7 +104,13 @@ func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.processIP(ip, bitmap, ipv4Set, ipv4Addresses); err != nil {
|
addr, ok := netip.AddrFromSlice(ip)
|
||||||
|
if !ok {
|
||||||
|
log.Warnf("invalid IP address %s in interface %s", ip.String(), iface.Name)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.processIP(addr.Unmap(), bitmap, ipv4Set, ipv4Addresses); err != nil {
|
||||||
log.Debugf("process IP failed: %v", err)
|
log.Debugf("process IP failed: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -116,8 +124,8 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
var newIPv4Bitmap [256]*ipv4LowBitmap
|
var newIPv4Bitmap [256]*ipv4LowBitmap
|
||||||
ipv4Set := make(map[string]struct{})
|
ipv4Set := make(map[netip.Addr]struct{})
|
||||||
var ipv4Addresses []string
|
var ipv4Addresses []netip.Addr
|
||||||
|
|
||||||
// 127.0.0.0/8
|
// 127.0.0.0/8
|
||||||
newIPv4Bitmap[127] = &ipv4LowBitmap{}
|
newIPv4Bitmap[127] = &ipv4LowBitmap{}
|
||||||
|
@ -20,11 +20,8 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Localhost range",
|
name: "Localhost range",
|
||||||
setupAddr: wgaddr.Address{
|
setupAddr: wgaddr.Address{
|
||||||
IP: net.ParseIP("192.168.1.1"),
|
IP: netip.MustParseAddr("192.168.1.1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
IP: net.ParseIP("192.168.1.0"),
|
|
||||||
Mask: net.CIDRMask(24, 32),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
testIP: netip.MustParseAddr("127.0.0.2"),
|
testIP: netip.MustParseAddr("127.0.0.2"),
|
||||||
expected: true,
|
expected: true,
|
||||||
@ -32,11 +29,8 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Localhost standard address",
|
name: "Localhost standard address",
|
||||||
setupAddr: wgaddr.Address{
|
setupAddr: wgaddr.Address{
|
||||||
IP: net.ParseIP("192.168.1.1"),
|
IP: netip.MustParseAddr("192.168.1.1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
IP: net.ParseIP("192.168.1.0"),
|
|
||||||
Mask: net.CIDRMask(24, 32),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
testIP: netip.MustParseAddr("127.0.0.1"),
|
testIP: netip.MustParseAddr("127.0.0.1"),
|
||||||
expected: true,
|
expected: true,
|
||||||
@ -44,11 +38,8 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Localhost range edge",
|
name: "Localhost range edge",
|
||||||
setupAddr: wgaddr.Address{
|
setupAddr: wgaddr.Address{
|
||||||
IP: net.ParseIP("192.168.1.1"),
|
IP: netip.MustParseAddr("192.168.1.1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
IP: net.ParseIP("192.168.1.0"),
|
|
||||||
Mask: net.CIDRMask(24, 32),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
testIP: netip.MustParseAddr("127.255.255.255"),
|
testIP: netip.MustParseAddr("127.255.255.255"),
|
||||||
expected: true,
|
expected: true,
|
||||||
@ -56,11 +47,8 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Local IP matches",
|
name: "Local IP matches",
|
||||||
setupAddr: wgaddr.Address{
|
setupAddr: wgaddr.Address{
|
||||||
IP: net.ParseIP("192.168.1.1"),
|
IP: netip.MustParseAddr("192.168.1.1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
IP: net.ParseIP("192.168.1.0"),
|
|
||||||
Mask: net.CIDRMask(24, 32),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
testIP: netip.MustParseAddr("192.168.1.1"),
|
testIP: netip.MustParseAddr("192.168.1.1"),
|
||||||
expected: true,
|
expected: true,
|
||||||
@ -68,11 +56,8 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Local IP doesn't match",
|
name: "Local IP doesn't match",
|
||||||
setupAddr: wgaddr.Address{
|
setupAddr: wgaddr.Address{
|
||||||
IP: net.ParseIP("192.168.1.1"),
|
IP: netip.MustParseAddr("192.168.1.1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
IP: net.ParseIP("192.168.1.0"),
|
|
||||||
Mask: net.CIDRMask(24, 32),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
testIP: netip.MustParseAddr("192.168.1.2"),
|
testIP: netip.MustParseAddr("192.168.1.2"),
|
||||||
expected: false,
|
expected: false,
|
||||||
@ -80,11 +65,8 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Local IP doesn't match - addresses 32 apart",
|
name: "Local IP doesn't match - addresses 32 apart",
|
||||||
setupAddr: wgaddr.Address{
|
setupAddr: wgaddr.Address{
|
||||||
IP: net.ParseIP("192.168.1.1"),
|
IP: netip.MustParseAddr("192.168.1.1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
IP: net.ParseIP("192.168.1.0"),
|
|
||||||
Mask: net.CIDRMask(24, 32),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
testIP: netip.MustParseAddr("192.168.1.33"),
|
testIP: netip.MustParseAddr("192.168.1.33"),
|
||||||
expected: false,
|
expected: false,
|
||||||
@ -92,11 +74,8 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "IPv6 address",
|
name: "IPv6 address",
|
||||||
setupAddr: wgaddr.Address{
|
setupAddr: wgaddr.Address{
|
||||||
IP: net.ParseIP("fe80::1"),
|
IP: netip.MustParseAddr("fe80::1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
IP: net.ParseIP("fe80::"),
|
|
||||||
Mask: net.CIDRMask(64, 128),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
testIP: netip.MustParseAddr("fe80::1"),
|
testIP: netip.MustParseAddr("fe80::1"),
|
||||||
expected: false,
|
expected: false,
|
||||||
|
@ -38,11 +38,8 @@ func TestTracePacket(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("100.10.0.100"),
|
IP: netip.MustParseAddr("100.10.0.100"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("100.10.0.0/16"),
|
||||||
IP: net.ParseIP("100.10.0.0"),
|
|
||||||
Mask: net.CIDRMask(16, 32),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -71,7 +71,6 @@ type Manager struct {
|
|||||||
// incomingRules is used for filtering and hooks
|
// incomingRules is used for filtering and hooks
|
||||||
incomingRules map[netip.Addr]RuleSet
|
incomingRules map[netip.Addr]RuleSet
|
||||||
routeRules RouteRules
|
routeRules RouteRules
|
||||||
wgNetwork *net.IPNet
|
|
||||||
decoders sync.Pool
|
decoders sync.Pool
|
||||||
wgIface common.IFaceMapper
|
wgIface common.IFaceMapper
|
||||||
nativeFirewall firewall.Manager
|
nativeFirewall firewall.Manager
|
||||||
@ -1091,11 +1090,6 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetNetwork of the wireguard interface to which filtering applied
|
|
||||||
func (m *Manager) SetNetwork(network *net.IPNet) {
|
|
||||||
m.wgNetwork = network
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||||
//
|
//
|
||||||
// Hook function returns flag which indicates should be the matched package dropped or not
|
// Hook function returns flag which indicates should be the matched package dropped or not
|
||||||
|
@ -174,11 +174,6 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
manager.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.64.0.0"),
|
|
||||||
Mask: net.CIDRMask(10, 32),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Apply scenario-specific setup
|
// Apply scenario-specific setup
|
||||||
sc.setupFunc(manager)
|
sc.setupFunc(manager)
|
||||||
|
|
||||||
@ -219,11 +214,6 @@ func BenchmarkStateScaling(b *testing.B) {
|
|||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
manager.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.64.0.0"),
|
|
||||||
Mask: net.CIDRMask(10, 32),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pre-populate connection table
|
// Pre-populate connection table
|
||||||
srcIPs := generateRandomIPs(count)
|
srcIPs := generateRandomIPs(count)
|
||||||
dstIPs := generateRandomIPs(count)
|
dstIPs := generateRandomIPs(count)
|
||||||
@ -267,11 +257,6 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
|||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
manager.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.64.0.0"),
|
|
||||||
Mask: net.CIDRMask(10, 32),
|
|
||||||
}
|
|
||||||
|
|
||||||
srcIP := generateRandomIPs(1)[0]
|
srcIP := generateRandomIPs(1)[0]
|
||||||
dstIP := generateRandomIPs(1)[0]
|
dstIP := generateRandomIPs(1)[0]
|
||||||
outbound := generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP)
|
outbound := generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP)
|
||||||
@ -304,10 +289,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
proto: layers.IPProtocolTCP,
|
proto: layers.IPProtocolTCP,
|
||||||
state: "new",
|
state: "new",
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
m.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.64.0.0"),
|
|
||||||
Mask: net.CIDRMask(10, 32),
|
|
||||||
}
|
|
||||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||||
},
|
},
|
||||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
@ -321,10 +302,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
proto: layers.IPProtocolTCP,
|
proto: layers.IPProtocolTCP,
|
||||||
state: "established",
|
state: "established",
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
m.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.64.0.0"),
|
|
||||||
Mask: net.CIDRMask(10, 32),
|
|
||||||
}
|
|
||||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||||
},
|
},
|
||||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
@ -339,10 +316,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
proto: layers.IPProtocolUDP,
|
proto: layers.IPProtocolUDP,
|
||||||
state: "new",
|
state: "new",
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
m.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.64.0.0"),
|
|
||||||
Mask: net.CIDRMask(10, 32),
|
|
||||||
}
|
|
||||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||||
},
|
},
|
||||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
@ -356,10 +329,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
proto: layers.IPProtocolUDP,
|
proto: layers.IPProtocolUDP,
|
||||||
state: "established",
|
state: "established",
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
m.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.64.0.0"),
|
|
||||||
Mask: net.CIDRMask(10, 32),
|
|
||||||
}
|
|
||||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||||
},
|
},
|
||||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
@ -373,10 +342,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
proto: layers.IPProtocolTCP,
|
proto: layers.IPProtocolTCP,
|
||||||
state: "new",
|
state: "new",
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
m.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("0.0.0.0"),
|
|
||||||
Mask: net.CIDRMask(0, 32),
|
|
||||||
}
|
|
||||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
},
|
},
|
||||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
@ -390,10 +355,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
proto: layers.IPProtocolTCP,
|
proto: layers.IPProtocolTCP,
|
||||||
state: "established",
|
state: "established",
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
m.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("0.0.0.0"),
|
|
||||||
Mask: net.CIDRMask(0, 32),
|
|
||||||
}
|
|
||||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
},
|
},
|
||||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
@ -408,10 +369,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
proto: layers.IPProtocolTCP,
|
proto: layers.IPProtocolTCP,
|
||||||
state: "post_handshake",
|
state: "post_handshake",
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
m.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("0.0.0.0"),
|
|
||||||
Mask: net.CIDRMask(0, 32),
|
|
||||||
}
|
|
||||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
},
|
},
|
||||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
@ -426,10 +383,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
proto: layers.IPProtocolUDP,
|
proto: layers.IPProtocolUDP,
|
||||||
state: "new",
|
state: "new",
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
m.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("0.0.0.0"),
|
|
||||||
Mask: net.CIDRMask(0, 32),
|
|
||||||
}
|
|
||||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
},
|
},
|
||||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
@ -443,10 +396,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
proto: layers.IPProtocolUDP,
|
proto: layers.IPProtocolUDP,
|
||||||
state: "established",
|
state: "established",
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
m.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("0.0.0.0"),
|
|
||||||
Mask: net.CIDRMask(0, 32),
|
|
||||||
}
|
|
||||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
},
|
},
|
||||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
@ -593,11 +542,6 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
manager.SetNetwork(&net.IPNet{
|
|
||||||
IP: net.ParseIP("100.64.0.0"),
|
|
||||||
Mask: net.CIDRMask(10, 32),
|
|
||||||
})
|
|
||||||
|
|
||||||
// Setup initial state based on scenario
|
// Setup initial state based on scenario
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
// Single rule to allow all return traffic from port 80
|
// Single rule to allow all return traffic from port 80
|
||||||
@ -681,11 +625,6 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
|||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
manager.SetNetwork(&net.IPNet{
|
|
||||||
IP: net.ParseIP("100.64.0.0"),
|
|
||||||
Mask: net.CIDRMask(10, 32),
|
|
||||||
})
|
|
||||||
|
|
||||||
// Setup initial state based on scenario
|
// Setup initial state based on scenario
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
// Single rule to allow all return traffic from port 80
|
// Single rule to allow all return traffic from port 80
|
||||||
@ -797,11 +736,6 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
|||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
manager.SetNetwork(&net.IPNet{
|
|
||||||
IP: net.ParseIP("100.64.0.0"),
|
|
||||||
Mask: net.CIDRMask(10, 32),
|
|
||||||
})
|
|
||||||
|
|
||||||
// Setup initial state based on scenario
|
// Setup initial state based on scenario
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||||
@ -882,11 +816,6 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
|||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
manager.SetNetwork(&net.IPNet{
|
|
||||||
IP: net.ParseIP("100.64.0.0"),
|
|
||||||
Mask: net.CIDRMask(10, 32),
|
|
||||||
})
|
|
||||||
|
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
@ -1032,7 +961,8 @@ func BenchmarkRouteACLs(b *testing.B) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, r := range rules {
|
for _, r := range rules {
|
||||||
_, err := manager.AddRouteFiltering(nil, r.sources, r.dest, r.proto, nil, r.port, fw.ActionAccept)
|
dst := fw.Network{Prefix: r.dest}
|
||||||
|
_, err := manager.AddRouteFiltering(nil, r.sources, dst, r.proto, nil, r.port, fw.ActionAccept)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -19,12 +19,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestPeerACLFiltering(t *testing.T) {
|
func TestPeerACLFiltering(t *testing.T) {
|
||||||
localIP := net.ParseIP("100.10.0.100")
|
localIP := netip.MustParseAddr("100.10.0.100")
|
||||||
wgNet := &net.IPNet{
|
wgNet := netip.MustParsePrefix("100.10.0.0/16")
|
||||||
IP: net.ParseIP("100.10.0.0"),
|
|
||||||
Mask: net.CIDRMask(16, 32),
|
|
||||||
}
|
|
||||||
|
|
||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
@ -43,8 +39,6 @@ func TestPeerACLFiltering(t *testing.T) {
|
|||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
manager.wgNetwork = wgNet
|
|
||||||
|
|
||||||
err = manager.UpdateLocalIPs()
|
err = manager.UpdateLocalIPs()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@ -581,14 +575,13 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
|
|||||||
dev := mocks.NewMockDevice(ctrl)
|
dev := mocks.NewMockDevice(ctrl)
|
||||||
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
||||||
|
|
||||||
localIP, wgNet, err := net.ParseCIDR(network)
|
wgNet := netip.MustParsePrefix(network)
|
||||||
require.NoError(tb, err)
|
|
||||||
|
|
||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: localIP,
|
IP: wgNet.Addr(),
|
||||||
Network: wgNet,
|
Network: wgNet,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@ -1440,11 +1433,8 @@ func TestRouteACLSet(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("100.10.0.100"),
|
IP: netip.MustParseAddr("100.10.0.100"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("100.10.0.0/16"),
|
||||||
IP: net.ParseIP("100.10.0.0"),
|
|
||||||
Mask: net.CIDRMask(16, 32),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -271,11 +271,8 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("100.10.0.100"),
|
IP: netip.MustParseAddr("100.10.0.100"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("100.10.0.0/16"),
|
||||||
IP: net.ParseIP("100.10.0.0"),
|
|
||||||
Mask: net.CIDRMask(16, 32),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -285,10 +282,6 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
m.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.10.0.0"),
|
|
||||||
Mask: net.CIDRMask(16, 32),
|
|
||||||
}
|
|
||||||
|
|
||||||
ip := net.ParseIP("0.0.0.0")
|
ip := net.ParseIP("0.0.0.0")
|
||||||
proto := fw.ProtocolUDP
|
proto := fw.ProtocolUDP
|
||||||
@ -396,10 +389,6 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
|||||||
}, false, flowLogger)
|
}, false, flowLogger)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
manager.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.10.0.0"),
|
|
||||||
Mask: net.CIDRMask(16, 32),
|
|
||||||
}
|
|
||||||
manager.udpTracker.Close()
|
manager.udpTracker.Close()
|
||||||
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger)
|
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger)
|
||||||
defer func() {
|
defer func() {
|
||||||
@ -509,11 +498,6 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
}, false, flowLogger)
|
}, false, flowLogger)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
manager.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.10.0.0"),
|
|
||||||
Mask: net.CIDRMask(16, 32),
|
|
||||||
}
|
|
||||||
|
|
||||||
manager.udpTracker.Close() // Close the existing tracker
|
manager.udpTracker.Close() // Close the existing tracker
|
||||||
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger)
|
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger)
|
||||||
manager.decoders = sync.Pool{
|
manager.decoders = sync.Pool{
|
||||||
|
@ -164,7 +164,7 @@ func (u *udpConn) performFilterCheck(addr net.Addr) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if u.address.Network.Contains(a.AsSlice()) {
|
if u.address.Network.Contains(a) {
|
||||||
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||||
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package device
|
package device
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@ -24,9 +23,6 @@ type PacketFilter interface {
|
|||||||
|
|
||||||
// RemovePacketHook removes hook by ID
|
// RemovePacketHook removes hook by ID
|
||||||
RemovePacketHook(hookID string) error
|
RemovePacketHook(hookID string) error
|
||||||
|
|
||||||
// SetNetwork of the wireguard interface to which filtering applied
|
|
||||||
SetNetwork(*net.IPNet)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// FilteredDevice to override Read or Write of packets
|
// FilteredDevice to override Read or Write of packets
|
||||||
|
@ -51,7 +51,11 @@ func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
|
|||||||
log.Info("create nbnetstack tun interface")
|
log.Info("create nbnetstack tun interface")
|
||||||
|
|
||||||
// TODO: get from service listener runtime IP
|
// TODO: get from service listener runtime IP
|
||||||
dnsAddr := nbnet.GetLastIPFromNetwork(t.address.Network, 1)
|
dnsAddr, err := nbnet.GetLastIPFromNetwork(t.address.Network, 1)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("last ip: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
log.Debugf("netstack using address: %s", t.address.IP)
|
log.Debugf("netstack using address: %s", t.address.IP)
|
||||||
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu)
|
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu)
|
||||||
log.Debugf("netstack using dns address: %s", dnsAddr)
|
log.Debugf("netstack using dns address: %s", dnsAddr)
|
||||||
|
@ -64,7 +64,15 @@ func (l *wgLink) assignAddr(address wgaddr.Address) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ip := address.IP.String()
|
ip := address.IP.String()
|
||||||
mask := "0x" + address.Network.Mask.String()
|
|
||||||
|
// Convert prefix length to hex netmask
|
||||||
|
prefixLen := address.Network.Bits()
|
||||||
|
if !address.IP.Is4() {
|
||||||
|
return fmt.Errorf("IPv6 not supported for interface assignment")
|
||||||
|
}
|
||||||
|
|
||||||
|
maskBits := uint32(0xffffffff) << (32 - prefixLen)
|
||||||
|
mask := fmt.Sprintf("0x%08x", maskBits)
|
||||||
|
|
||||||
log.Infof("assign addr %s mask %s to %s interface", ip, mask, l.name)
|
log.Infof("assign addr %s mask %s to %s interface", ip, mask, l.name)
|
||||||
|
|
||||||
|
@ -185,7 +185,6 @@ func (w *WGIface) SetFilter(filter device.PacketFilter) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
w.filter = filter
|
w.filter = filter
|
||||||
w.filter.SetNetwork(w.tun.WgAddress().Network)
|
|
||||||
|
|
||||||
w.tun.FilteredDevice().SetFilter(filter)
|
w.tun.FilteredDevice().SetFilter(filter)
|
||||||
return nil
|
return nil
|
||||||
|
@ -5,7 +5,6 @@
|
|||||||
package mocks
|
package mocks
|
||||||
|
|
||||||
import (
|
import (
|
||||||
net "net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
reflect "reflect"
|
reflect "reflect"
|
||||||
|
|
||||||
@ -90,15 +89,3 @@ func (mr *MockPacketFilterMockRecorder) RemovePacketHook(arg0 interface{}) *gomo
|
|||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetNetwork mocks base method.
|
|
||||||
func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
m.ctrl.Call(m, "SetNetwork", arg0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNetwork indicates an expected call of SetNetwork.
|
|
||||||
func (mr *MockPacketFilterMockRecorder) SetNetwork(arg0 interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetwork", reflect.TypeOf((*MockPacketFilter)(nil).SetNetwork), arg0)
|
|
||||||
}
|
|
||||||
|
@ -1,8 +1,6 @@
|
|||||||
package netstack
|
package netstack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
@ -15,8 +13,8 @@ import (
|
|||||||
const EnvSkipProxy = "NB_NETSTACK_SKIP_PROXY"
|
const EnvSkipProxy = "NB_NETSTACK_SKIP_PROXY"
|
||||||
|
|
||||||
type NetStackTun struct { //nolint:revive
|
type NetStackTun struct { //nolint:revive
|
||||||
address net.IP
|
address netip.Addr
|
||||||
dnsAddress net.IP
|
dnsAddress netip.Addr
|
||||||
mtu int
|
mtu int
|
||||||
listenAddress string
|
listenAddress string
|
||||||
|
|
||||||
@ -24,7 +22,7 @@ type NetStackTun struct { //nolint:revive
|
|||||||
tundev tun.Device
|
tundev tun.Device
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewNetStackTun(listenAddress string, address net.IP, dnsAddress net.IP, mtu int) *NetStackTun {
|
func NewNetStackTun(listenAddress string, address netip.Addr, dnsAddress netip.Addr, mtu int) *NetStackTun {
|
||||||
return &NetStackTun{
|
return &NetStackTun{
|
||||||
address: address,
|
address: address,
|
||||||
dnsAddress: dnsAddress,
|
dnsAddress: dnsAddress,
|
||||||
@ -34,19 +32,9 @@ func NewNetStackTun(listenAddress string, address net.IP, dnsAddress net.IP, mtu
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
|
func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
|
||||||
addr, ok := netip.AddrFromSlice(t.address)
|
|
||||||
if !ok {
|
|
||||||
return nil, nil, fmt.Errorf("convert address to netip.Addr: %v", t.address)
|
|
||||||
}
|
|
||||||
|
|
||||||
dnsAddr, ok := netip.AddrFromSlice(t.dnsAddress)
|
|
||||||
if !ok {
|
|
||||||
return nil, nil, fmt.Errorf("convert dns address to netip.Addr: %v", t.dnsAddress)
|
|
||||||
}
|
|
||||||
|
|
||||||
nsTunDev, tunNet, err := netstack.CreateNetTUN(
|
nsTunDev, tunNet, err := netstack.CreateNetTUN(
|
||||||
[]netip.Addr{addr.Unmap()},
|
[]netip.Addr{t.address},
|
||||||
[]netip.Addr{dnsAddr.Unmap()},
|
[]netip.Addr{t.dnsAddress},
|
||||||
t.mtu)
|
t.mtu)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
|
@ -2,28 +2,27 @@ package wgaddr
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net/netip"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Address WireGuard parsed address
|
// Address WireGuard parsed address
|
||||||
type Address struct {
|
type Address struct {
|
||||||
IP net.IP
|
IP netip.Addr
|
||||||
Network *net.IPNet
|
Network netip.Prefix
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
|
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
|
||||||
func ParseWGAddress(address string) (Address, error) {
|
func ParseWGAddress(address string) (Address, error) {
|
||||||
ip, network, err := net.ParseCIDR(address)
|
prefix, err := netip.ParsePrefix(address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Address{}, err
|
return Address{}, err
|
||||||
}
|
}
|
||||||
return Address{
|
return Address{
|
||||||
IP: ip,
|
IP: prefix.Addr().Unmap(),
|
||||||
Network: network,
|
Network: prefix.Masked(),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (addr Address) String() string {
|
func (addr Address) String() string {
|
||||||
maskSize, _ := addr.Network.Mask.Size()
|
return fmt.Sprintf("%s/%d", addr.IP.String(), addr.Network.Bits())
|
||||||
return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize)
|
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
package acl
|
package acl
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
@ -43,12 +43,11 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||||
ip, network, err := net.ParseCIDR("172.0.0.1/32")
|
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||||
IP: ip,
|
IP: network.Addr(),
|
||||||
Network: network,
|
Network: network,
|
||||||
}).AnyTimes()
|
}).AnyTimes()
|
||||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
@ -162,12 +161,11 @@ func TestDefaultManagerStateless(t *testing.T) {
|
|||||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||||
ip, network, err := net.ParseCIDR("172.0.0.1/32")
|
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||||
IP: ip,
|
IP: network.Addr(),
|
||||||
Network: network,
|
Network: network,
|
||||||
}).AnyTimes()
|
}).AnyTimes()
|
||||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
@ -372,12 +370,11 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
|||||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||||
ip, network, err := net.ParseCIDR("172.0.0.1/32")
|
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||||
IP: ip,
|
IP: network.Addr(),
|
||||||
Network: network,
|
Network: network,
|
||||||
}).AnyTimes()
|
}).AnyTimes()
|
||||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
@ -2,7 +2,7 @@ package internal
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net/netip"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@ -12,13 +12,14 @@ import (
|
|||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
func createPTRRecord(aRecord nbdns.SimpleRecord, ipNet *net.IPNet) (nbdns.SimpleRecord, bool) {
|
func createPTRRecord(aRecord nbdns.SimpleRecord, prefix netip.Prefix) (nbdns.SimpleRecord, bool) {
|
||||||
ip := net.ParseIP(aRecord.RData)
|
ip, err := netip.ParseAddr(aRecord.RData)
|
||||||
if ip == nil || ip.To4() == nil {
|
if err != nil {
|
||||||
|
log.Warnf("failed to parse IP address %s: %v", aRecord.RData, err)
|
||||||
return nbdns.SimpleRecord{}, false
|
return nbdns.SimpleRecord{}, false
|
||||||
}
|
}
|
||||||
|
|
||||||
if !ipNet.Contains(ip) {
|
if !prefix.Contains(ip) {
|
||||||
return nbdns.SimpleRecord{}, false
|
return nbdns.SimpleRecord{}, false
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -36,16 +37,19 @@ func createPTRRecord(aRecord nbdns.SimpleRecord, ipNet *net.IPNet) (nbdns.Simple
|
|||||||
}
|
}
|
||||||
|
|
||||||
// generateReverseZoneName creates the reverse DNS zone name for a given network
|
// generateReverseZoneName creates the reverse DNS zone name for a given network
|
||||||
func generateReverseZoneName(ipNet *net.IPNet) (string, error) {
|
func generateReverseZoneName(network netip.Prefix) (string, error) {
|
||||||
networkIP := ipNet.IP.Mask(ipNet.Mask)
|
networkIP := network.Masked().Addr()
|
||||||
maskOnes, _ := ipNet.Mask.Size()
|
|
||||||
|
if !networkIP.Is4() {
|
||||||
|
return "", fmt.Errorf("reverse DNS is only supported for IPv4 networks, got: %s", networkIP)
|
||||||
|
}
|
||||||
|
|
||||||
// round up to nearest byte
|
// round up to nearest byte
|
||||||
octetsToUse := (maskOnes + 7) / 8
|
octetsToUse := (network.Bits() + 7) / 8
|
||||||
|
|
||||||
octets := strings.Split(networkIP.String(), ".")
|
octets := strings.Split(networkIP.String(), ".")
|
||||||
if octetsToUse > len(octets) {
|
if octetsToUse > len(octets) {
|
||||||
return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", maskOnes)
|
return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", network.Bits())
|
||||||
}
|
}
|
||||||
|
|
||||||
reverseOctets := make([]string, octetsToUse)
|
reverseOctets := make([]string, octetsToUse)
|
||||||
@ -68,7 +72,7 @@ func zoneExists(config *nbdns.Config, zoneName string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// collectPTRRecords gathers all PTR records for the given network from A records
|
// collectPTRRecords gathers all PTR records for the given network from A records
|
||||||
func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRecord {
|
func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.SimpleRecord {
|
||||||
var records []nbdns.SimpleRecord
|
var records []nbdns.SimpleRecord
|
||||||
|
|
||||||
for _, zone := range config.CustomZones {
|
for _, zone := range config.CustomZones {
|
||||||
@ -77,7 +81,7 @@ func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRec
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if ptrRecord, ok := createPTRRecord(record, ipNet); ok {
|
if ptrRecord, ok := createPTRRecord(record, prefix); ok {
|
||||||
records = append(records, ptrRecord)
|
records = append(records, ptrRecord)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -87,8 +91,8 @@ func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRec
|
|||||||
}
|
}
|
||||||
|
|
||||||
// addReverseZone adds a reverse DNS zone to the configuration for the given network
|
// addReverseZone adds a reverse DNS zone to the configuration for the given network
|
||||||
func addReverseZone(config *nbdns.Config, ipNet *net.IPNet) {
|
func addReverseZone(config *nbdns.Config, network netip.Prefix) {
|
||||||
zoneName, err := generateReverseZoneName(ipNet)
|
zoneName, err := generateReverseZoneName(network)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn(err)
|
log.Warn(err)
|
||||||
return
|
return
|
||||||
@ -99,7 +103,7 @@ func addReverseZone(config *nbdns.Config, ipNet *net.IPNet) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
records := collectPTRRecords(config, ipNet)
|
records := collectPTRRecords(config, network)
|
||||||
|
|
||||||
reverseZone := nbdns.CustomZone{
|
reverseZone := nbdns.CustomZone{
|
||||||
Domain: zoneName,
|
Domain: zoneName,
|
||||||
|
@ -46,10 +46,9 @@ func (w *mocWGIface) Name() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (w *mocWGIface) Address() wgaddr.Address {
|
func (w *mocWGIface) Address() wgaddr.Address {
|
||||||
ip, network, _ := net.ParseCIDR("100.66.100.0/24")
|
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: ip,
|
IP: netip.MustParseAddr("100.66.100.1"),
|
||||||
Network: network,
|
Network: netip.MustParsePrefix("100.66.100.0/24"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -464,17 +463,10 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
|||||||
ctrl := gomock.NewController(t)
|
ctrl := gomock.NewController(t)
|
||||||
defer ctrl.Finish()
|
defer ctrl.Finish()
|
||||||
|
|
||||||
_, ipNet, err := net.ParseCIDR("100.66.100.1/32")
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("parse CIDR: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
||||||
packetfilter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).AnyTimes()
|
packetfilter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).AnyTimes()
|
||||||
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||||
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
|
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
|
||||||
packetfilter.EXPECT().SetNetwork(ipNet)
|
|
||||||
|
|
||||||
if err := wgIface.SetFilter(packetfilter); err != nil {
|
if err := wgIface.SetFilter(packetfilter); err != nil {
|
||||||
t.Errorf("set packet filter: %v", err)
|
t.Errorf("set packet filter: %v", err)
|
||||||
|
@ -24,11 +24,15 @@ type ServiceViaMemory struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
|
func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
|
||||||
|
lastIP, err := nbnet.GetLastIPFromNetwork(wgIface.Address().Network, 1)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("get last ip from network: %v", err)
|
||||||
|
}
|
||||||
s := &ServiceViaMemory{
|
s := &ServiceViaMemory{
|
||||||
wgInterface: wgIface,
|
wgInterface: wgIface,
|
||||||
dnsMux: dns.NewServeMux(),
|
dnsMux: dns.NewServeMux(),
|
||||||
|
|
||||||
runtimeIP: nbnet.GetLastIPFromNetwork(wgIface.Address().Network, 1).String(),
|
runtimeIP: lastIP.String(),
|
||||||
runtimePort: defaultPort,
|
runtimePort: defaultPort,
|
||||||
}
|
}
|
||||||
return s
|
return s
|
||||||
@ -91,7 +95,7 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
firstLayerDecoder := layers.LayerTypeIPv4
|
firstLayerDecoder := layers.LayerTypeIPv4
|
||||||
if s.wgInterface.Address().Network.IP.To4() == nil {
|
if s.wgInterface.Address().IP.Is6() {
|
||||||
firstLayerDecoder = layers.LayerTypeIPv6
|
firstLayerDecoder = layers.LayerTypeIPv6
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,33 +0,0 @@
|
|||||||
package dns
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestGetLastIPFromNetwork(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
addr string
|
|
||||||
ip string
|
|
||||||
}{
|
|
||||||
{"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"},
|
|
||||||
{"192.168.0.0/30", "192.168.0.2"},
|
|
||||||
{"192.168.0.0/16", "192.168.255.254"},
|
|
||||||
{"192.168.0.0/24", "192.168.0.254"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
_, ipnet, err := net.ParseCIDR(tt.addr)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Error parsing CIDR: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
lastIP := nbnet.GetLastIPFromNetwork(ipnet, 1).String()
|
|
||||||
if lastIP != tt.ip {
|
|
||||||
t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
@ -3,6 +3,7 @@ package dns
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -23,8 +24,8 @@ type upstreamResolver struct {
|
|||||||
func newUpstreamResolver(
|
func newUpstreamResolver(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
_ string,
|
_ string,
|
||||||
_ net.IP,
|
_ netip.Addr,
|
||||||
_ *net.IPNet,
|
_ netip.Prefix,
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
hostsDNSHolder *hostsDNSHolder,
|
hostsDNSHolder *hostsDNSHolder,
|
||||||
domain string,
|
domain string,
|
||||||
|
@ -4,7 +4,7 @@ package dns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
@ -19,8 +19,8 @@ type upstreamResolver struct {
|
|||||||
func newUpstreamResolver(
|
func newUpstreamResolver(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
_ string,
|
_ string,
|
||||||
_ net.IP,
|
_ netip.Addr,
|
||||||
_ *net.IPNet,
|
_ netip.Prefix,
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
_ *hostsDNSHolder,
|
_ *hostsDNSHolder,
|
||||||
domain string,
|
domain string,
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -18,16 +19,16 @@ import (
|
|||||||
|
|
||||||
type upstreamResolverIOS struct {
|
type upstreamResolverIOS struct {
|
||||||
*upstreamResolverBase
|
*upstreamResolverBase
|
||||||
lIP net.IP
|
lIP netip.Addr
|
||||||
lNet *net.IPNet
|
lNet netip.Prefix
|
||||||
interfaceName string
|
interfaceName string
|
||||||
}
|
}
|
||||||
|
|
||||||
func newUpstreamResolver(
|
func newUpstreamResolver(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
interfaceName string,
|
interfaceName string,
|
||||||
ip net.IP,
|
ip netip.Addr,
|
||||||
net *net.IPNet,
|
net netip.Prefix,
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
_ *hostsDNSHolder,
|
_ *hostsDNSHolder,
|
||||||
domain string,
|
domain string,
|
||||||
@ -58,8 +59,11 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
|||||||
}
|
}
|
||||||
client.DialTimeout = timeout
|
client.DialTimeout = timeout
|
||||||
|
|
||||||
upstreamIP := net.ParseIP(upstreamHost)
|
upstreamIP, err := netip.ParseAddr(upstreamHost)
|
||||||
if u.lNet.Contains(upstreamIP) || net.IP.IsPrivate(upstreamIP) {
|
if err != nil {
|
||||||
|
log.Warnf("failed to parse upstream host %s: %s", upstreamHost, err)
|
||||||
|
}
|
||||||
|
if u.lNet.Contains(upstreamIP) || upstreamIP.IsPrivate() {
|
||||||
log.Debugf("using private client to query upstream: %s", upstream)
|
log.Debugf("using private client to query upstream: %s", upstream)
|
||||||
client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout)
|
client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -73,7 +77,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
|||||||
|
|
||||||
// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
|
// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
|
||||||
// This method is needed for iOS
|
// This method is needed for iOS
|
||||||
func GetClientPrivate(ip net.IP, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
|
func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
|
||||||
index, err := getInterfaceIndex(interfaceName)
|
index, err := getInterfaceIndex(interfaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
|
log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
|
||||||
@ -82,7 +86,7 @@ func GetClientPrivate(ip net.IP, interfaceName string, dialTimeout time.Duration
|
|||||||
|
|
||||||
dialer := &net.Dialer{
|
dialer := &net.Dialer{
|
||||||
LocalAddr: &net.UDPAddr{
|
LocalAddr: &net.UDPAddr{
|
||||||
IP: ip,
|
IP: ip.AsSlice(),
|
||||||
Port: 0, // Let the OS pick a free port
|
Port: 0, // Let the OS pick a free port
|
||||||
},
|
},
|
||||||
Timeout: dialTimeout,
|
Timeout: dialTimeout,
|
||||||
|
@ -2,7 +2,7 @@ package dns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@ -58,7 +58,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
|||||||
for _, testCase := range testCases {
|
for _, testCase := range testCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
ctx, cancel := context.WithCancel(context.TODO())
|
ctx, cancel := context.WithCancel(context.TODO())
|
||||||
resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}, nil, nil, ".")
|
resolver, _ := newUpstreamResolver(ctx, "", netip.Addr{}, netip.Prefix{}, nil, nil, ".")
|
||||||
resolver.upstreamServers = testCase.InputServers
|
resolver.upstreamServers = testCase.InputServers
|
||||||
resolver.upstreamTimeout = testCase.timeout
|
resolver.upstreamTimeout = testCase.timeout
|
||||||
if testCase.cancelCTX {
|
if testCase.cancelCTX {
|
||||||
|
@ -1008,7 +1008,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
// apply routes first, route related actions might depend on routing being enabled
|
// apply routes first, route related actions might depend on routing being enabled
|
||||||
routes := toRoutes(networkMap.GetRoutes())
|
routes := toRoutes(networkMap.GetRoutes())
|
||||||
if err := e.routeManager.UpdateRoutes(serial, routes, dnsRouteFeatureFlag); err != nil {
|
if err := e.routeManager.UpdateRoutes(serial, routes, dnsRouteFeatureFlag); err != nil {
|
||||||
log.Errorf("failed to update clientRoutes, err: %v", err)
|
log.Errorf("failed to update routes: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if e.acl != nil {
|
if e.acl != nil {
|
||||||
@ -1104,7 +1104,7 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
|
|||||||
|
|
||||||
convertedRoute := &route.Route{
|
convertedRoute := &route.Route{
|
||||||
ID: route.ID(protoRoute.ID),
|
ID: route.ID(protoRoute.ID),
|
||||||
Network: prefix,
|
Network: prefix.Masked(),
|
||||||
Domains: domain.FromPunycodeList(protoRoute.Domains),
|
Domains: domain.FromPunycodeList(protoRoute.Domains),
|
||||||
NetID: route.NetID(protoRoute.NetID),
|
NetID: route.NetID(protoRoute.NetID),
|
||||||
NetworkType: route.NetworkType(protoRoute.NetworkType),
|
NetworkType: route.NetworkType(protoRoute.NetworkType),
|
||||||
@ -1138,7 +1138,7 @@ func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderE
|
|||||||
return entries
|
return entries
|
||||||
}
|
}
|
||||||
|
|
||||||
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network *net.IPNet) nbdns.Config {
|
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config {
|
||||||
dnsUpdate := nbdns.Config{
|
dnsUpdate := nbdns.Config{
|
||||||
ServiceEnable: protoDNSConfig.GetServiceEnable(),
|
ServiceEnable: protoDNSConfig.GetServiceEnable(),
|
||||||
CustomZones: make([]nbdns.CustomZone, 0),
|
CustomZones: make([]nbdns.CustomZone, 0),
|
||||||
@ -1790,9 +1790,9 @@ func (e *Engine) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetWgAddr returns the wireguard address
|
// GetWgAddr returns the wireguard address
|
||||||
func (e *Engine) GetWgAddr() net.IP {
|
func (e *Engine) GetWgAddr() netip.Addr {
|
||||||
if e.wgInterface == nil {
|
if e.wgInterface == nil {
|
||||||
return nil
|
return netip.Addr{}
|
||||||
}
|
}
|
||||||
return e.wgInterface.Address().IP
|
return e.wgInterface.Address().IP
|
||||||
}
|
}
|
||||||
@ -1861,12 +1861,7 @@ func (e *Engine) Address() (netip.Addr, error) {
|
|||||||
return netip.Addr{}, errors.New("wireguard interface not initialized")
|
return netip.Addr{}, errors.New("wireguard interface not initialized")
|
||||||
}
|
}
|
||||||
|
|
||||||
addr := e.wgInterface.Address()
|
return e.wgInterface.Address().IP, nil
|
||||||
ip, ok := netip.AddrFromSlice(addr.IP)
|
|
||||||
if !ok {
|
|
||||||
return netip.Addr{}, errors.New("failed to convert address to netip.Addr")
|
|
||||||
}
|
|
||||||
return ip.Unmap(), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewallManager.ForwardRule, error) {
|
func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewallManager.ForwardRule, error) {
|
||||||
|
@ -371,11 +371,8 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
|||||||
},
|
},
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("10.20.0.1"),
|
IP: netip.MustParseAddr("10.20.0.1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("10.20.0.0/24"),
|
||||||
IP: net.ParseIP("10.20.0.0"),
|
|
||||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
UpdatePeerFunc: func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
UpdatePeerFunc: func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||||
|
@ -232,7 +232,7 @@ func (c *ConnTrack) relevantFlow(mark uint32, srcIP, dstIP netip.Addr) bool {
|
|||||||
|
|
||||||
// fallback if mark rules are not in place
|
// fallback if mark rules are not in place
|
||||||
wgnet := c.iface.Address().Network
|
wgnet := c.iface.Address().Network
|
||||||
return wgnet.Contains(srcIP.AsSlice()) || wgnet.Contains(dstIP.AsSlice())
|
return wgnet.Contains(srcIP) || wgnet.Contains(dstIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
// mapRxPackets maps packet counts to RX based on flow direction
|
// mapRxPackets maps packet counts to RX based on flow direction
|
||||||
@ -293,17 +293,15 @@ func (c *ConnTrack) inferDirection(mark uint32, srcIP, dstIP netip.Addr) nftypes
|
|||||||
// fallback if marks are not set
|
// fallback if marks are not set
|
||||||
wgaddr := c.iface.Address().IP
|
wgaddr := c.iface.Address().IP
|
||||||
wgnetwork := c.iface.Address().Network
|
wgnetwork := c.iface.Address().Network
|
||||||
src, dst := srcIP.AsSlice(), dstIP.AsSlice()
|
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case wgaddr.Equal(src):
|
case wgaddr == srcIP:
|
||||||
return nftypes.Egress
|
return nftypes.Egress
|
||||||
case wgaddr.Equal(dst):
|
case wgaddr == dstIP:
|
||||||
return nftypes.Ingress
|
return nftypes.Ingress
|
||||||
case wgnetwork.Contains(src):
|
case wgnetwork.Contains(srcIP):
|
||||||
// netbird network -> resource network
|
// netbird network -> resource network
|
||||||
return nftypes.Ingress
|
return nftypes.Ingress
|
||||||
case wgnetwork.Contains(dst):
|
case wgnetwork.Contains(dstIP):
|
||||||
// resource network -> netbird network
|
// resource network -> netbird network
|
||||||
return nftypes.Egress
|
return nftypes.Egress
|
||||||
}
|
}
|
||||||
|
@ -2,7 +2,7 @@ package logger
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
@ -23,17 +23,16 @@ type Logger struct {
|
|||||||
rcvChan atomic.Pointer[rcvChan]
|
rcvChan atomic.Pointer[rcvChan]
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
wgIfaceIPNet net.IPNet
|
wgIfaceNet netip.Prefix
|
||||||
dnsCollection atomic.Bool
|
dnsCollection atomic.Bool
|
||||||
exitNodeCollection atomic.Bool
|
exitNodeCollection atomic.Bool
|
||||||
Store types.Store
|
Store types.Store
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(statusRecorder *peer.Status, wgIfaceIPNet net.IPNet) *Logger {
|
func New(statusRecorder *peer.Status, wgIfaceIPNet netip.Prefix) *Logger {
|
||||||
|
|
||||||
return &Logger{
|
return &Logger{
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
wgIfaceIPNet: wgIfaceIPNet,
|
wgIfaceNet: wgIfaceIPNet,
|
||||||
Store: store.NewMemoryStore(),
|
Store: store.NewMemoryStore(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -89,11 +88,11 @@ func (l *Logger) startReceiver() {
|
|||||||
var isSrcExitNode bool
|
var isSrcExitNode bool
|
||||||
var isDestExitNode bool
|
var isDestExitNode bool
|
||||||
|
|
||||||
if !l.wgIfaceIPNet.Contains(net.IP(event.SourceIP.AsSlice())) {
|
if !l.wgIfaceNet.Contains(event.SourceIP) {
|
||||||
event.SourceResourceID, isSrcExitNode = l.statusRecorder.CheckRoutes(event.SourceIP)
|
event.SourceResourceID, isSrcExitNode = l.statusRecorder.CheckRoutes(event.SourceIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !l.wgIfaceIPNet.Contains(net.IP(event.DestIP.AsSlice())) {
|
if !l.wgIfaceNet.Contains(event.DestIP) {
|
||||||
event.DestResourceID, isDestExitNode = l.statusRecorder.CheckRoutes(event.DestIP)
|
event.DestResourceID, isDestExitNode = l.statusRecorder.CheckRoutes(event.DestIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
package logger_test
|
package logger_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -12,7 +12,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestStore(t *testing.T) {
|
func TestStore(t *testing.T) {
|
||||||
logger := logger.New(nil, net.IPNet{})
|
logger := logger.New(nil, netip.Prefix{})
|
||||||
logger.Enable()
|
logger.Enable()
|
||||||
|
|
||||||
event := types.EventFields{
|
event := types.EventFields{
|
||||||
|
@ -4,7 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net/netip"
|
||||||
"runtime"
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@ -34,11 +34,11 @@ type Manager struct {
|
|||||||
|
|
||||||
// NewManager creates a new netflow manager
|
// NewManager creates a new netflow manager
|
||||||
func NewManager(iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager {
|
func NewManager(iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager {
|
||||||
var ipNet net.IPNet
|
var prefix netip.Prefix
|
||||||
if iface != nil {
|
if iface != nil {
|
||||||
ipNet = *iface.Address().Network
|
prefix = iface.Address().Network
|
||||||
}
|
}
|
||||||
flowLogger := logger.New(statusRecorder, ipNet)
|
flowLogger := logger.New(statusRecorder, prefix)
|
||||||
|
|
||||||
var ct nftypes.ConnTracker
|
var ct nftypes.ConnTracker
|
||||||
if runtime.GOOS == "linux" && iface != nil && !iface.IsUserspaceBind() {
|
if runtime.GOOS == "linux" && iface != nil && !iface.IsUserspaceBind() {
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
package netflow
|
package netflow
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -33,10 +33,7 @@ func (m *mockIFaceMapper) IsUserspaceBind() bool {
|
|||||||
func TestManager_Update(t *testing.T) {
|
func TestManager_Update(t *testing.T) {
|
||||||
mockIFace := &mockIFaceMapper{
|
mockIFace := &mockIFaceMapper{
|
||||||
address: wgaddr.Address{
|
address: wgaddr.Address{
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("192.168.1.1/32"),
|
||||||
IP: net.ParseIP("192.168.1.1"),
|
|
||||||
Mask: net.CIDRMask(24, 32),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
isUserspaceBind: true,
|
isUserspaceBind: true,
|
||||||
}
|
}
|
||||||
@ -102,10 +99,7 @@ func TestManager_Update(t *testing.T) {
|
|||||||
func TestManager_Update_TokenPreservation(t *testing.T) {
|
func TestManager_Update_TokenPreservation(t *testing.T) {
|
||||||
mockIFace := &mockIFaceMapper{
|
mockIFace := &mockIFaceMapper{
|
||||||
address: wgaddr.Address{
|
address: wgaddr.Address{
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("192.168.1.1/32"),
|
||||||
IP: net.ParseIP("192.168.1.1"),
|
|
||||||
Mask: net.CIDRMask(24, 32),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
isUserspaceBind: true,
|
isUserspaceBind: true,
|
||||||
}
|
}
|
||||||
|
@ -264,7 +264,7 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
prefix := netip.PrefixFrom(ip, ip.BitLen())
|
prefix := netip.PrefixFrom(ip.Unmap(), ip.BitLen())
|
||||||
newPrefixes = append(newPrefixes, prefix)
|
newPrefixes = append(newPrefixes, prefix)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -333,11 +333,12 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
|
|||||||
|
|
||||||
newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes)
|
newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes)
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
if !m.disableClientRoutes {
|
if !m.disableClientRoutes {
|
||||||
filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap)
|
filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap)
|
||||||
|
|
||||||
if err := m.updateSystemRoutes(filteredClientRoutes); err != nil {
|
if err := m.updateSystemRoutes(filteredClientRoutes); err != nil {
|
||||||
log.Errorf("Failed to update system routes: %v", err)
|
merr = multierror.Append(merr, fmt.Errorf("update system routes: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
m.updateClientNetworks(updateSerial, filteredClientRoutes)
|
m.updateClientNetworks(updateSerial, filteredClientRoutes)
|
||||||
@ -346,14 +347,14 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
|
|||||||
m.clientRoutes = newClientRoutesIDMap
|
m.clientRoutes = newClientRoutesIDMap
|
||||||
|
|
||||||
if m.serverRouter == nil {
|
if m.serverRouter == nil {
|
||||||
return nil
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.serverRouter.UpdateRoutes(newServerRoutesMap, useNewDNSRoute); err != nil {
|
if err := m.serverRouter.UpdateRoutes(newServerRoutesMap, useNewDNSRoute); err != nil {
|
||||||
return fmt.Errorf("update routes: %w", err)
|
merr = multierror.Append(merr, fmt.Errorf("update server routes: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetRouteChangeListener set RouteListener for route change Notifier
|
// SetRouteChangeListener set RouteListener for route change Notifier
|
||||||
|
@ -44,7 +44,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
ID: "a",
|
ID: "a",
|
||||||
NetID: "routeA",
|
NetID: "routeA",
|
||||||
Peer: remotePeerKey1,
|
Peer: remotePeerKey1,
|
||||||
Network: netip.MustParsePrefix("100.64.251.250/30"),
|
Network: netip.MustParsePrefix("100.64.251.248/30"),
|
||||||
NetworkType: route.IPv4Network,
|
NetworkType: route.IPv4Network,
|
||||||
Metric: 9999,
|
Metric: 9999,
|
||||||
Masquerade: false,
|
Masquerade: false,
|
||||||
@ -71,7 +71,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
ID: "a",
|
ID: "a",
|
||||||
NetID: "routeA",
|
NetID: "routeA",
|
||||||
Peer: localPeerKey,
|
Peer: localPeerKey,
|
||||||
Network: netip.MustParsePrefix("100.64.252.250/30"),
|
Network: netip.MustParsePrefix("100.64.252.248/30"),
|
||||||
NetworkType: route.IPv4Network,
|
NetworkType: route.IPv4Network,
|
||||||
Metric: 9999,
|
Metric: 9999,
|
||||||
Masquerade: false,
|
Masquerade: false,
|
||||||
@ -99,7 +99,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
ID: "a",
|
ID: "a",
|
||||||
NetID: "routeA",
|
NetID: "routeA",
|
||||||
Peer: localPeerKey,
|
Peer: localPeerKey,
|
||||||
Network: netip.MustParsePrefix("100.64.30.250/30"),
|
Network: netip.MustParsePrefix("100.64.30.248/30"),
|
||||||
NetworkType: route.IPv4Network,
|
NetworkType: route.IPv4Network,
|
||||||
Metric: 9999,
|
Metric: 9999,
|
||||||
Masquerade: false,
|
Masquerade: false,
|
||||||
@ -127,7 +127,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
ID: "a",
|
ID: "a",
|
||||||
NetID: "routeA",
|
NetID: "routeA",
|
||||||
Peer: localPeerKey,
|
Peer: localPeerKey,
|
||||||
Network: netip.MustParsePrefix("100.64.30.250/30"),
|
Network: netip.MustParsePrefix("100.64.30.248/30"),
|
||||||
NetworkType: route.IPv4Network,
|
NetworkType: route.IPv4Network,
|
||||||
Metric: 9999,
|
Metric: 9999,
|
||||||
Masquerade: false,
|
Masquerade: false,
|
||||||
@ -211,7 +211,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
ID: "a",
|
ID: "a",
|
||||||
NetID: "routeA",
|
NetID: "routeA",
|
||||||
Peer: remotePeerKey1,
|
Peer: remotePeerKey1,
|
||||||
Network: netip.MustParsePrefix("100.64.251.250/30"),
|
Network: netip.MustParsePrefix("100.64.251.248/30"),
|
||||||
NetworkType: route.IPv4Network,
|
NetworkType: route.IPv4Network,
|
||||||
Metric: 9999,
|
Metric: 9999,
|
||||||
Masquerade: false,
|
Masquerade: false,
|
||||||
@ -233,7 +233,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
ID: "a",
|
ID: "a",
|
||||||
NetID: "routeA",
|
NetID: "routeA",
|
||||||
Peer: remotePeerKey1,
|
Peer: remotePeerKey1,
|
||||||
Network: netip.MustParsePrefix("100.64.251.250/30"),
|
Network: netip.MustParsePrefix("100.64.251.248/30"),
|
||||||
NetworkType: route.IPv4Network,
|
NetworkType: route.IPv4Network,
|
||||||
Metric: 9999,
|
Metric: 9999,
|
||||||
Masquerade: false,
|
Masquerade: false,
|
||||||
@ -250,7 +250,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
ID: "a",
|
ID: "a",
|
||||||
NetID: "routeA",
|
NetID: "routeA",
|
||||||
Peer: remotePeerKey1,
|
Peer: remotePeerKey1,
|
||||||
Network: netip.MustParsePrefix("100.64.251.250/30"),
|
Network: netip.MustParsePrefix("100.64.251.248/30"),
|
||||||
NetworkType: route.IPv4Network,
|
NetworkType: route.IPv4Network,
|
||||||
Metric: 9999,
|
Metric: 9999,
|
||||||
Masquerade: false,
|
Masquerade: false,
|
||||||
@ -272,7 +272,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
ID: "a",
|
ID: "a",
|
||||||
NetID: "routeA",
|
NetID: "routeA",
|
||||||
Peer: remotePeerKey1,
|
Peer: remotePeerKey1,
|
||||||
Network: netip.MustParsePrefix("100.64.251.250/30"),
|
Network: netip.MustParsePrefix("100.64.251.248/30"),
|
||||||
NetworkType: route.IPv4Network,
|
NetworkType: route.IPv4Network,
|
||||||
Metric: 9999,
|
Metric: 9999,
|
||||||
Masquerade: false,
|
Masquerade: false,
|
||||||
@ -282,7 +282,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
ID: "b",
|
ID: "b",
|
||||||
NetID: "routeA",
|
NetID: "routeA",
|
||||||
Peer: remotePeerKey2,
|
Peer: remotePeerKey2,
|
||||||
Network: netip.MustParsePrefix("100.64.251.250/30"),
|
Network: netip.MustParsePrefix("100.64.251.248/30"),
|
||||||
NetworkType: route.IPv4Network,
|
NetworkType: route.IPv4Network,
|
||||||
Metric: 9999,
|
Metric: 9999,
|
||||||
Masquerade: false,
|
Masquerade: false,
|
||||||
@ -299,7 +299,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
ID: "a",
|
ID: "a",
|
||||||
NetID: "routeA",
|
NetID: "routeA",
|
||||||
Peer: remotePeerKey1,
|
Peer: remotePeerKey1,
|
||||||
Network: netip.MustParsePrefix("100.64.251.250/30"),
|
Network: netip.MustParsePrefix("100.64.251.248/30"),
|
||||||
NetworkType: route.IPv4Network,
|
NetworkType: route.IPv4Network,
|
||||||
Metric: 9999,
|
Metric: 9999,
|
||||||
Masquerade: false,
|
Masquerade: false,
|
||||||
@ -327,7 +327,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
ID: "a",
|
ID: "a",
|
||||||
NetID: "routeA",
|
NetID: "routeA",
|
||||||
Peer: localPeerKey,
|
Peer: localPeerKey,
|
||||||
Network: netip.MustParsePrefix("100.64.251.250/30"),
|
Network: netip.MustParsePrefix("100.64.251.248/30"),
|
||||||
NetworkType: route.IPv4Network,
|
NetworkType: route.IPv4Network,
|
||||||
Metric: 9999,
|
Metric: 9999,
|
||||||
Masquerade: false,
|
Masquerade: false,
|
||||||
@ -356,7 +356,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
ID: "l1",
|
ID: "l1",
|
||||||
NetID: "routeA",
|
NetID: "routeA",
|
||||||
Peer: localPeerKey,
|
Peer: localPeerKey,
|
||||||
Network: netip.MustParsePrefix("100.64.251.250/30"),
|
Network: netip.MustParsePrefix("100.64.251.248/30"),
|
||||||
NetworkType: route.IPv4Network,
|
NetworkType: route.IPv4Network,
|
||||||
Metric: 9999,
|
Metric: 9999,
|
||||||
Masquerade: false,
|
Masquerade: false,
|
||||||
@ -376,7 +376,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
ID: "r1",
|
ID: "r1",
|
||||||
NetID: "routeA",
|
NetID: "routeA",
|
||||||
Peer: remotePeerKey1,
|
Peer: remotePeerKey1,
|
||||||
Network: netip.MustParsePrefix("100.64.251.250/30"),
|
Network: netip.MustParsePrefix("100.64.251.248/30"),
|
||||||
NetworkType: route.IPv4Network,
|
NetworkType: route.IPv4Network,
|
||||||
Metric: 9999,
|
Metric: 9999,
|
||||||
Masquerade: false,
|
Masquerade: false,
|
||||||
@ -440,11 +440,11 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(testCase.inputInitRoutes) > 0 {
|
if len(testCase.inputInitRoutes) > 0 {
|
||||||
_ = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes, false)
|
err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes, false)
|
||||||
require.NoError(t, err, "should update routes with init routes")
|
require.NoError(t, err, "should update routes with init routes")
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes, false)
|
err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes, false)
|
||||||
require.NoError(t, err, "should update routes")
|
require.NoError(t, err, "should update routes")
|
||||||
|
|
||||||
expectedWatchers := testCase.clientNetworkWatchersExpected
|
expectedWatchers := testCase.clientNetworkWatchersExpected
|
||||||
|
@ -13,7 +13,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -22,8 +22,13 @@ const (
|
|||||||
srcValidMarkPath = "net.ipv4.conf.all.src_valid_mark"
|
srcValidMarkPath = "net.ipv4.conf.all.src_valid_mark"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type iface interface {
|
||||||
|
Address() wgaddr.Address
|
||||||
|
Name() string
|
||||||
|
}
|
||||||
|
|
||||||
// Setup configures sysctl settings for RP filtering and source validation.
|
// Setup configures sysctl settings for RP filtering and source validation.
|
||||||
func Setup(wgIface iface.WGIface) (map[string]int, error) {
|
func Setup(wgIface iface) (map[string]int, error) {
|
||||||
keys := map[string]int{}
|
keys := map[string]int{}
|
||||||
var result *multierror.Error
|
var result *multierror.Error
|
||||||
|
|
||||||
|
@ -6,9 +6,10 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
|
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Nexthop struct {
|
type Nexthop struct {
|
||||||
@ -30,11 +31,16 @@ func (n Nexthop) String() string {
|
|||||||
return fmt.Sprintf("%s @ %d (%s)", n.IP.String(), n.Intf.Index, n.Intf.Name)
|
return fmt.Sprintf("%s @ %d (%s)", n.IP.String(), n.Intf.Index, n.Intf.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type wgIface interface {
|
||||||
|
Address() wgaddr.Address
|
||||||
|
Name() string
|
||||||
|
}
|
||||||
|
|
||||||
type ExclusionCounter = refcounter.Counter[netip.Prefix, struct{}, Nexthop]
|
type ExclusionCounter = refcounter.Counter[netip.Prefix, struct{}, Nexthop]
|
||||||
|
|
||||||
type SysOps struct {
|
type SysOps struct {
|
||||||
refCounter *ExclusionCounter
|
refCounter *ExclusionCounter
|
||||||
wgInterface iface.WGIface
|
wgInterface wgIface
|
||||||
// prefixes is tracking all the current added prefixes im memory
|
// prefixes is tracking all the current added prefixes im memory
|
||||||
// (this is used in iOS as all route updates require a full table update)
|
// (this is used in iOS as all route updates require a full table update)
|
||||||
//nolint
|
//nolint
|
||||||
@ -45,9 +51,27 @@ type SysOps struct {
|
|||||||
notifier *notifier.Notifier
|
notifier *notifier.Notifier
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSysOps(wgInterface iface.WGIface, notifier *notifier.Notifier) *SysOps {
|
func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
|
||||||
return &SysOps{
|
return &SysOps{
|
||||||
wgInterface: wgInterface,
|
wgInterface: wgInterface,
|
||||||
notifier: notifier,
|
notifier: notifier,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *SysOps) validateRoute(prefix netip.Prefix) error {
|
||||||
|
addr := prefix.Addr()
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case
|
||||||
|
!addr.IsValid(),
|
||||||
|
addr.IsLoopback(),
|
||||||
|
addr.IsLinkLocalUnicast(),
|
||||||
|
addr.IsLinkLocalMulticast(),
|
||||||
|
addr.IsInterfaceLocalMulticast(),
|
||||||
|
addr.IsMulticast(),
|
||||||
|
addr.IsUnspecified() && prefix.Bits() != 0,
|
||||||
|
r.wgInterface.Address().Network.Contains(addr):
|
||||||
|
return vars.ErrRouteNotAllowed
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@ -8,6 +8,8 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@ -33,7 +35,12 @@ func init() {
|
|||||||
|
|
||||||
func TestConcurrentRoutes(t *testing.T) {
|
func TestConcurrentRoutes(t *testing.T) {
|
||||||
baseIP := netip.MustParseAddr("192.0.2.0")
|
baseIP := netip.MustParseAddr("192.0.2.0")
|
||||||
intf := &net.Interface{Name: "lo0"}
|
|
||||||
|
var intf *net.Interface
|
||||||
|
var nexthop Nexthop
|
||||||
|
|
||||||
|
_, intf = setupDummyInterface(t)
|
||||||
|
nexthop = Nexthop{netip.Addr{}, intf}
|
||||||
|
|
||||||
r := NewSysOps(nil, nil)
|
r := NewSysOps(nil, nil)
|
||||||
|
|
||||||
@ -43,7 +50,7 @@ func TestConcurrentRoutes(t *testing.T) {
|
|||||||
go func(ip netip.Addr) {
|
go func(ip netip.Addr) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
prefix := netip.PrefixFrom(ip, 32)
|
prefix := netip.PrefixFrom(ip, 32)
|
||||||
if err := r.addToRouteTable(prefix, Nexthop{netip.Addr{}, intf}); err != nil {
|
if err := r.addToRouteTable(prefix, nexthop); err != nil {
|
||||||
t.Errorf("Failed to add route for %s: %v", prefix, err)
|
t.Errorf("Failed to add route for %s: %v", prefix, err)
|
||||||
}
|
}
|
||||||
}(baseIP)
|
}(baseIP)
|
||||||
@ -59,7 +66,7 @@ func TestConcurrentRoutes(t *testing.T) {
|
|||||||
go func(ip netip.Addr) {
|
go func(ip netip.Addr) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
prefix := netip.PrefixFrom(ip, 32)
|
prefix := netip.PrefixFrom(ip, 32)
|
||||||
if err := r.removeFromRouteTable(prefix, Nexthop{netip.Addr{}, intf}); err != nil {
|
if err := r.removeFromRouteTable(prefix, nexthop); err != nil {
|
||||||
t.Errorf("Failed to remove route for %s: %v", prefix, err)
|
t.Errorf("Failed to remove route for %s: %v", prefix, err)
|
||||||
}
|
}
|
||||||
}(baseIP)
|
}(baseIP)
|
||||||
@ -119,18 +126,39 @@ func TestBits(t *testing.T) {
|
|||||||
func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string {
|
func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run()
|
if runtime.GOOS == "darwin" {
|
||||||
require.NoError(t, err, "Failed to create loopback alias")
|
err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run()
|
||||||
|
require.NoError(t, err, "Failed to create loopback alias")
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run()
|
||||||
|
assert.NoError(t, err, "Failed to remove loopback alias")
|
||||||
|
})
|
||||||
|
|
||||||
|
return intf
|
||||||
|
}
|
||||||
|
|
||||||
|
prefix, err := netip.ParsePrefix(ipAddressCIDR)
|
||||||
|
require.NoError(t, err, "Failed to parse prefix")
|
||||||
|
|
||||||
|
netIntf, err := net.InterfaceByName(intf)
|
||||||
|
require.NoError(t, err, "Failed to get interface by name")
|
||||||
|
|
||||||
|
nexthop := Nexthop{netip.Addr{}, netIntf}
|
||||||
|
|
||||||
|
r := NewSysOps(nil, nil)
|
||||||
|
err = r.addToRouteTable(prefix, nexthop)
|
||||||
|
require.NoError(t, err, "Failed to add route to table")
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run()
|
err := r.removeFromRouteTable(prefix, nexthop)
|
||||||
assert.NoError(t, err, "Failed to remove loopback alias")
|
assert.NoError(t, err, "Failed to remove route from table")
|
||||||
})
|
})
|
||||||
|
|
||||||
return "lo0"
|
return intf
|
||||||
}
|
}
|
||||||
|
|
||||||
func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, _ string) {
|
func addDummyRoute(t *testing.T, dstCIDR string, gw netip.Addr, _ string) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
var originalNexthop net.IP
|
var originalNexthop net.IP
|
||||||
@ -176,12 +204,40 @@ func fetchOriginalGateway() (net.IP, error) {
|
|||||||
return net.ParseIP(matches[1]), nil
|
return net.ParseIP(matches[1]), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// setupDummyInterface creates a dummy tun interface for FreeBSD route testing
|
||||||
|
func setupDummyInterface(t *testing.T) (netip.Addr, *net.Interface) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
if runtime.GOOS == "darwin" {
|
||||||
|
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), &net.Interface{Name: "lo0"}
|
||||||
|
}
|
||||||
|
|
||||||
|
output, err := exec.Command("ifconfig", "tun", "create").CombinedOutput()
|
||||||
|
require.NoError(t, err, "Failed to create tun interface: %s", string(output))
|
||||||
|
|
||||||
|
tunName := strings.TrimSpace(string(output))
|
||||||
|
|
||||||
|
output, err = exec.Command("ifconfig", tunName, "192.168.1.1", "netmask", "255.255.0.0", "192.168.1.2", "up").CombinedOutput()
|
||||||
|
require.NoError(t, err, "Failed to configure tun interface: %s", string(output))
|
||||||
|
|
||||||
|
intf, err := net.InterfaceByName(tunName)
|
||||||
|
require.NoError(t, err, "Failed to get interface by name")
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
if err := exec.Command("ifconfig", tunName, "destroy").Run(); err != nil {
|
||||||
|
t.Logf("Failed to destroy tun interface %s: %v", tunName, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), intf
|
||||||
|
}
|
||||||
|
|
||||||
func setupDummyInterfacesAndRoutes(t *testing.T) {
|
func setupDummyInterfacesAndRoutes(t *testing.T) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24")
|
defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24")
|
||||||
addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy)
|
addDummyRoute(t, "0.0.0.0/0", netip.AddrFrom4([4]byte{192, 168, 0, 1}), defaultDummy)
|
||||||
|
|
||||||
otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24")
|
otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24")
|
||||||
addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 1, 1), otherDummy)
|
addDummyRoute(t, "10.0.0.0/8", netip.AddrFrom4([4]byte{192, 168, 1, 1}), otherDummy)
|
||||||
}
|
}
|
||||||
|
@ -17,7 +17,6 @@ import (
|
|||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||||
@ -106,59 +105,15 @@ func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: fix: for default our wg address now appears as the default gw
|
|
||||||
func (r *SysOps) addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
|
|
||||||
addr := netip.IPv4Unspecified()
|
|
||||||
if prefix.Addr().Is6() {
|
|
||||||
addr = netip.IPv6Unspecified()
|
|
||||||
}
|
|
||||||
|
|
||||||
nexthop, err := GetNextHop(addr)
|
|
||||||
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
|
|
||||||
return fmt.Errorf("get existing route gateway: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !prefix.Contains(nexthop.IP) {
|
|
||||||
log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", nexthop.IP, prefix)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
gatewayPrefix := netip.PrefixFrom(nexthop.IP, 32)
|
|
||||||
if nexthop.IP.Is6() {
|
|
||||||
gatewayPrefix = netip.PrefixFrom(nexthop.IP, 128)
|
|
||||||
}
|
|
||||||
|
|
||||||
ok, err := existsInRouteTable(gatewayPrefix)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if ok {
|
|
||||||
log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
nexthop, err = GetNextHop(nexthop.IP)
|
|
||||||
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
|
|
||||||
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, nexthop.IP)
|
|
||||||
return r.addToRouteTable(gatewayPrefix, nexthop)
|
|
||||||
}
|
|
||||||
|
|
||||||
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
|
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
|
||||||
// If the next hop or interface is pointing to the VPN interface, it will return the initial values.
|
// If the next hop or interface is pointing to the VPN interface, it will return the initial values.
|
||||||
func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf iface.WGIface, initialNextHop Nexthop) (Nexthop, error) {
|
func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf wgIface, initialNextHop Nexthop) (Nexthop, error) {
|
||||||
addr := prefix.Addr()
|
if err := r.validateRoute(prefix); err != nil {
|
||||||
switch {
|
return Nexthop{}, err
|
||||||
case addr.IsLoopback(),
|
}
|
||||||
addr.IsLinkLocalUnicast(),
|
|
||||||
addr.IsLinkLocalMulticast(),
|
|
||||||
addr.IsInterfaceLocalMulticast(),
|
|
||||||
addr.IsUnspecified(),
|
|
||||||
addr.IsMulticast():
|
|
||||||
|
|
||||||
|
addr := prefix.Addr()
|
||||||
|
if addr.IsUnspecified() {
|
||||||
return Nexthop{}, vars.ErrRouteNotAllowed
|
return Nexthop{}, vars.ErrRouteNotAllowed
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -179,10 +134,7 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf iface.WGIface
|
|||||||
Intf: nexthop.Intf,
|
Intf: nexthop.Intf,
|
||||||
}
|
}
|
||||||
|
|
||||||
vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP)
|
vpnAddr := vpnIntf.Address().IP
|
||||||
if !ok {
|
|
||||||
return Nexthop{}, fmt.Errorf("failed to convert vpn address to netip.Addr")
|
|
||||||
}
|
|
||||||
|
|
||||||
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
|
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
|
||||||
if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() {
|
if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() {
|
||||||
@ -271,32 +223,7 @@ func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) er
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return r.addNonExistingRoute(prefix, intf)
|
return r.addToRouteTable(prefix, nextHop)
|
||||||
}
|
|
||||||
|
|
||||||
// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table
|
|
||||||
func (r *SysOps) addNonExistingRoute(prefix netip.Prefix, intf *net.Interface) error {
|
|
||||||
ok, err := existsInRouteTable(prefix)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("exists in route table: %w", err)
|
|
||||||
}
|
|
||||||
if ok {
|
|
||||||
log.Warnf("Skipping adding a new route for network %s because it already exists", prefix)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
ok, err = isSubRange(prefix)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("sub range: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if ok {
|
|
||||||
if err := r.addRouteForCurrentDefaultGateway(prefix); err != nil {
|
|
||||||
log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return r.addToRouteTable(prefix, Nexthop{netip.Addr{}, intf})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given,
|
// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given,
|
||||||
@ -408,12 +335,8 @@ func GetNextHop(ip netip.Addr) (Nexthop, error) {
|
|||||||
|
|
||||||
log.Debugf("Route for %s: interface %v nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc)
|
log.Debugf("Route for %s: interface %v nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc)
|
||||||
if gateway == nil {
|
if gateway == nil {
|
||||||
if runtime.GOOS == "freebsd" {
|
|
||||||
return Nexthop{Intf: intf}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if preferredSrc == nil {
|
if preferredSrc == nil {
|
||||||
return Nexthop{}, vars.ErrRouteNotFound
|
return Nexthop{Intf: intf}, nil
|
||||||
}
|
}
|
||||||
log.Debugf("No next hop found for IP %s, using preferred source %s", ip, preferredSrc)
|
log.Debugf("No next hop found for IP %s, using preferred source %s", ip, preferredSrc)
|
||||||
|
|
||||||
@ -457,32 +380,6 @@ func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
|
|||||||
return addr.Unmap(), nil
|
return addr.Unmap(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
|
|
||||||
routes, err := GetRoutesFromTable()
|
|
||||||
if err != nil {
|
|
||||||
return false, fmt.Errorf("get routes from table: %w", err)
|
|
||||||
}
|
|
||||||
for _, tableRoute := range routes {
|
|
||||||
if tableRoute == prefix {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isSubRange(prefix netip.Prefix) (bool, error) {
|
|
||||||
routes, err := GetRoutesFromTable()
|
|
||||||
if err != nil {
|
|
||||||
return false, fmt.Errorf("get routes from table: %w", err)
|
|
||||||
}
|
|
||||||
for _, tableRoute := range routes {
|
|
||||||
if tableRoute.Bits() > vars.MinRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsAddrRouted checks if the candidate address would route to the vpn, in which case it returns true and the matched prefix.
|
// IsAddrRouted checks if the candidate address would route to the vpn, in which case it returns true and the matched prefix.
|
||||||
func IsAddrRouted(addr netip.Addr, vpnRoutes []netip.Prefix) (bool, netip.Prefix) {
|
func IsAddrRouted(addr netip.Addr, vpnRoutes []netip.Prefix) (bool, netip.Prefix) {
|
||||||
localRoutes, err := hasSeparateRouting()
|
localRoutes, err := hasSeparateRouting()
|
||||||
|
@ -3,23 +3,25 @@
|
|||||||
package systemops
|
package systemops
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os/exec"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"syscall"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/pion/transport/v3/stdnet"
|
"github.com/pion/transport/v3/stdnet"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||||
)
|
)
|
||||||
|
|
||||||
type dialer interface {
|
type dialer interface {
|
||||||
@ -27,105 +29,370 @@ type dialer interface {
|
|||||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAddRemoveRoutes(t *testing.T) {
|
func TestAddVPNRoute(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
prefix netip.Prefix
|
prefix netip.Prefix
|
||||||
shouldRouteToWireguard bool
|
expectError bool
|
||||||
shouldBeRemoved bool
|
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Should Add And Remove Route 100.66.120.0/24",
|
name: "IPv4 - Private network route",
|
||||||
prefix: netip.MustParsePrefix("100.66.120.0/24"),
|
prefix: netip.MustParsePrefix("10.10.100.0/24"),
|
||||||
shouldRouteToWireguard: true,
|
|
||||||
shouldBeRemoved: true,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Should Not Add Or Remove Route 127.0.0.1/32",
|
name: "IPv4 Single host",
|
||||||
prefix: netip.MustParsePrefix("127.0.0.1/32"),
|
prefix: netip.MustParsePrefix("10.111.111.111/32"),
|
||||||
shouldRouteToWireguard: false,
|
},
|
||||||
shouldBeRemoved: false,
|
{
|
||||||
|
name: "IPv4 RFC3927 test range",
|
||||||
|
prefix: netip.MustParsePrefix("198.51.100.0/24"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv4 Default route",
|
||||||
|
prefix: netip.MustParsePrefix("0.0.0.0/0"),
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
name: "IPv6 Subnet",
|
||||||
|
prefix: netip.MustParsePrefix("fdb1:848a:7e16::/48"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 Single host",
|
||||||
|
prefix: netip.MustParsePrefix("fdb1:848a:7e16:a::b/128"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 Default route",
|
||||||
|
prefix: netip.MustParsePrefix("::/0"),
|
||||||
|
},
|
||||||
|
|
||||||
|
// IPv4 addresses that should be rejected (matches validateRoute logic)
|
||||||
|
{
|
||||||
|
name: "IPv4 Loopback",
|
||||||
|
prefix: netip.MustParsePrefix("127.0.0.1/32"),
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv4 Link-local unicast",
|
||||||
|
prefix: netip.MustParsePrefix("169.254.1.1/32"),
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv4 Link-local multicast",
|
||||||
|
prefix: netip.MustParsePrefix("224.0.0.251/32"),
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv4 Multicast",
|
||||||
|
prefix: netip.MustParsePrefix("239.255.255.250/32"),
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv4 Unspecified with prefix",
|
||||||
|
prefix: netip.MustParsePrefix("0.0.0.0/32"),
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
|
||||||
|
// IPv6 addresses that should be rejected (matches validateRoute logic)
|
||||||
|
{
|
||||||
|
name: "IPv6 Loopback",
|
||||||
|
prefix: netip.MustParsePrefix("::1/128"),
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 Link-local unicast",
|
||||||
|
prefix: netip.MustParsePrefix("fe80::1/128"),
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 Link-local multicast",
|
||||||
|
prefix: netip.MustParsePrefix("ff02::1/128"),
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 Interface-local multicast",
|
||||||
|
prefix: netip.MustParsePrefix("ff01::1/128"),
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 Multicast",
|
||||||
|
prefix: netip.MustParsePrefix("ff00::1/128"),
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 Unspecified with prefix",
|
||||||
|
prefix: netip.MustParsePrefix("::/128"),
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
name: "IPv4 WireGuard interface network overlap",
|
||||||
|
prefix: netip.MustParsePrefix("100.65.75.0/24"),
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv4 WireGuard interface network subnet",
|
||||||
|
prefix: netip.MustParsePrefix("100.65.75.0/32"),
|
||||||
|
expectError: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for n, testCase := range testCases {
|
for n, testCase := range testCases {
|
||||||
// todo resolve test execution on freebsd
|
|
||||||
if runtime.GOOS == "freebsd" {
|
|
||||||
t.Skip("skipping ", testCase.name, " on freebsd")
|
|
||||||
}
|
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
|
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
|
||||||
|
|
||||||
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
|
wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
|
||||||
newNet, err := stdnet.NewNet()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
opts := iface.WGIFaceOpts{
|
|
||||||
IFaceName: fmt.Sprintf("utun53%d", n),
|
|
||||||
Address: "100.65.75.2/24",
|
|
||||||
WGPrivKey: peerPrivateKey.String(),
|
|
||||||
MTU: iface.DefaultMTU,
|
|
||||||
TransportNet: newNet,
|
|
||||||
}
|
|
||||||
wgInterface, err := iface.NewWGIFace(opts)
|
|
||||||
require.NoError(t, err, "should create testing WGIface interface")
|
|
||||||
defer wgInterface.Close()
|
|
||||||
|
|
||||||
err = wgInterface.Create()
|
|
||||||
require.NoError(t, err, "should create testing wireguard interface")
|
|
||||||
|
|
||||||
r := NewSysOps(wgInterface, nil)
|
r := NewSysOps(wgInterface, nil)
|
||||||
|
_, _, err := r.SetupRouting(nil, nil)
|
||||||
_, _, err = r.SetupRouting(nil, nil)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
assert.NoError(t, r.CleanupRouting(nil))
|
assert.NoError(t, r.CleanupRouting(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
index, err := net.InterfaceByName(wgInterface.Name())
|
intf, err := net.InterfaceByName(wgInterface.Name())
|
||||||
require.NoError(t, err, "InterfaceByName should not return err")
|
require.NoError(t, err)
|
||||||
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
|
|
||||||
|
|
||||||
|
// add the route
|
||||||
err = r.AddVPNRoute(testCase.prefix, intf)
|
err = r.AddVPNRoute(testCase.prefix, intf)
|
||||||
require.NoError(t, err, "genericAddVPNRoute should not return err")
|
if testCase.expectError {
|
||||||
|
assert.ErrorIs(t, err, vars.ErrRouteNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if testCase.shouldRouteToWireguard {
|
// validate it's pointing to the WireGuard interface
|
||||||
assertWGOutInterface(t, testCase.prefix, wgInterface, false)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
nextHop := getNextHop(t, testCase.prefix.Addr())
|
||||||
|
assert.Equal(t, wgInterface.Name(), nextHop.Intf.Name, "next hop interface should be WireGuard interface")
|
||||||
|
|
||||||
|
// remove route again
|
||||||
|
err = r.RemoveVPNRoute(testCase.prefix, intf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// validate it's gone
|
||||||
|
nextHop, err = GetNextHop(testCase.prefix.Addr())
|
||||||
|
require.True(t,
|
||||||
|
errors.Is(err, vars.ErrRouteNotFound) || err == nil && nextHop.Intf != nil && nextHop.Intf.Name != wgInterface.Name(),
|
||||||
|
"err: %v, next hop: %v", err, nextHop)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getNextHop(t *testing.T, addr netip.Addr) Nexthop {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
if runtime.GOOS == "windows" || runtime.GOOS == "linux" {
|
||||||
|
nextHop, err := GetNextHop(addr)
|
||||||
|
|
||||||
|
if runtime.GOOS == "windows" && errors.Is(err, vars.ErrRouteNotFound) && addr.Is6() {
|
||||||
|
// TODO: Fix this test. It doesn't return the route when running in a windows github runner, but it is
|
||||||
|
// present in the route table.
|
||||||
|
t.Skip("Skipping windows test")
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, nextHop.Intf, "next hop interface should not be nil for %s", addr)
|
||||||
|
|
||||||
|
return nextHop
|
||||||
|
}
|
||||||
|
// GetNextHop for bsd is buggy and returns the wrong interface for the default route.
|
||||||
|
|
||||||
|
if addr.IsUnspecified() {
|
||||||
|
// On macOS, querying 0.0.0.0 returns the wrong interface
|
||||||
|
if addr.Is4() {
|
||||||
|
addr = netip.MustParseAddr("1.2.3.4")
|
||||||
|
} else {
|
||||||
|
addr = netip.MustParseAddr("2001:db8::1")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.Command("route", "-n", "get", addr.String())
|
||||||
|
if addr.Is6() {
|
||||||
|
cmd = exec.Command("route", "-n", "get", "-inet6", addr.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
output, err := cmd.CombinedOutput()
|
||||||
|
t.Logf("route output: %s", output)
|
||||||
|
require.NoError(t, err, "%s failed")
|
||||||
|
|
||||||
|
lines := strings.Split(string(output), "\n")
|
||||||
|
var intf string
|
||||||
|
var gateway string
|
||||||
|
|
||||||
|
for _, line := range lines {
|
||||||
|
line = strings.TrimSpace(line)
|
||||||
|
if strings.HasPrefix(line, "interface:") {
|
||||||
|
intf = strings.TrimSpace(strings.TrimPrefix(line, "interface:"))
|
||||||
|
} else if strings.HasPrefix(line, "gateway:") {
|
||||||
|
gateway = strings.TrimSpace(strings.TrimPrefix(line, "gateway:"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NotEmpty(t, intf, "interface should be found in route output")
|
||||||
|
|
||||||
|
iface, err := net.InterfaceByName(intf)
|
||||||
|
require.NoError(t, err, "interface %s should exist", intf)
|
||||||
|
|
||||||
|
nexthop := Nexthop{Intf: iface}
|
||||||
|
|
||||||
|
if gateway != "" && gateway != "link#"+strconv.Itoa(iface.Index) {
|
||||||
|
addr, err := netip.ParseAddr(gateway)
|
||||||
|
if err == nil {
|
||||||
|
nexthop.IP = addr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nexthop
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddRouteToNonVPNIntf(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
prefix netip.Prefix
|
||||||
|
expectError bool
|
||||||
|
errorType error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "IPv4 RFC3927 test range",
|
||||||
|
prefix: netip.MustParsePrefix("198.51.100.0/24"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv4 Single host",
|
||||||
|
prefix: netip.MustParsePrefix("8.8.8.8/32"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 External network route",
|
||||||
|
prefix: netip.MustParsePrefix("2001:db8:1000::/48"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 Single host",
|
||||||
|
prefix: netip.MustParsePrefix("2001:db8::1/128"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 Subnet",
|
||||||
|
prefix: netip.MustParsePrefix("2a05:d014:1f8d::/48"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 Single host",
|
||||||
|
prefix: netip.MustParsePrefix("2a05:d014:1f8d:7302:ebca:ec15:b24d:d07e/128"),
|
||||||
|
},
|
||||||
|
|
||||||
|
// Addresses that should be rejected
|
||||||
|
{
|
||||||
|
name: "IPv4 Loopback",
|
||||||
|
prefix: netip.MustParsePrefix("127.0.0.1/32"),
|
||||||
|
expectError: true,
|
||||||
|
errorType: vars.ErrRouteNotAllowed,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv4 Link-local unicast",
|
||||||
|
prefix: netip.MustParsePrefix("169.254.1.1/32"),
|
||||||
|
expectError: true,
|
||||||
|
errorType: vars.ErrRouteNotAllowed,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv4 Multicast",
|
||||||
|
prefix: netip.MustParsePrefix("239.255.255.250/32"),
|
||||||
|
expectError: true,
|
||||||
|
errorType: vars.ErrRouteNotAllowed,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv4 Unspecified",
|
||||||
|
prefix: netip.MustParsePrefix("0.0.0.0/0"),
|
||||||
|
expectError: true,
|
||||||
|
errorType: vars.ErrRouteNotAllowed,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 Loopback",
|
||||||
|
prefix: netip.MustParsePrefix("::1/128"),
|
||||||
|
expectError: true,
|
||||||
|
errorType: vars.ErrRouteNotAllowed,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 Link-local unicast",
|
||||||
|
prefix: netip.MustParsePrefix("fe80::1/128"),
|
||||||
|
expectError: true,
|
||||||
|
errorType: vars.ErrRouteNotAllowed,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 Multicast",
|
||||||
|
prefix: netip.MustParsePrefix("ff00::1/128"),
|
||||||
|
expectError: true,
|
||||||
|
errorType: vars.ErrRouteNotAllowed,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 Unspecified",
|
||||||
|
prefix: netip.MustParsePrefix("::/0"),
|
||||||
|
expectError: true,
|
||||||
|
errorType: vars.ErrRouteNotAllowed,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv4 WireGuard interface network overlap",
|
||||||
|
prefix: netip.MustParsePrefix("100.65.75.0/24"),
|
||||||
|
expectError: true,
|
||||||
|
errorType: vars.ErrRouteNotAllowed,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for n, testCase := range testCases {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
|
||||||
|
|
||||||
|
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
|
||||||
|
|
||||||
|
r := NewSysOps(wgInterface, nil)
|
||||||
|
_, _, err := r.SetupRouting(nil, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
assert.NoError(t, r.CleanupRouting(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
|
||||||
|
require.NoError(t, err, "Should be able to get IPv4 default route")
|
||||||
|
t.Logf("Initial IPv4 next hop: %s", initialNextHopV4)
|
||||||
|
|
||||||
|
initialNextHopV6, err := GetNextHop(netip.IPv6Unspecified())
|
||||||
|
if testCase.prefix.Addr().Is6() &&
|
||||||
|
(errors.Is(err, vars.ErrRouteNotFound) || initialNextHopV6.Intf != nil && strings.HasPrefix(initialNextHopV6.Intf.Name, "utun")) {
|
||||||
|
t.Skip("Skipping test as no ipv6 default route is available")
|
||||||
|
}
|
||||||
|
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
|
||||||
|
t.Fatalf("Failed to get IPv6 default route: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var initialNextHop Nexthop
|
||||||
|
if testCase.prefix.Addr().Is6() {
|
||||||
|
initialNextHop = initialNextHopV6
|
||||||
} else {
|
} else {
|
||||||
assertWGOutInterface(t, testCase.prefix, wgInterface, true)
|
initialNextHop = initialNextHopV4
|
||||||
}
|
}
|
||||||
exists, err := existsInRouteTable(testCase.prefix)
|
|
||||||
require.NoError(t, err, "existsInRouteTable should not return err")
|
|
||||||
if exists && testCase.shouldRouteToWireguard {
|
|
||||||
err = r.RemoveVPNRoute(testCase.prefix, intf)
|
|
||||||
require.NoError(t, err, "genericRemoveVPNRoute should not return err")
|
|
||||||
|
|
||||||
prefixNexthop, err := GetNextHop(testCase.prefix.Addr())
|
nexthop, err := r.addRouteToNonVPNIntf(testCase.prefix, wgInterface, initialNextHop)
|
||||||
require.NoError(t, err, "GetNextHop should not return err")
|
|
||||||
|
|
||||||
internetNexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
|
if testCase.expectError {
|
||||||
require.NoError(t, err)
|
require.ErrorIs(t, err, vars.ErrRouteNotAllowed)
|
||||||
|
return
|
||||||
if testCase.shouldBeRemoved {
|
|
||||||
require.Equal(t, internetNexthop.IP, prefixNexthop.IP, "route should be pointing to default internet gateway")
|
|
||||||
} else {
|
|
||||||
require.NotEqual(t, internetNexthop.IP, prefixNexthop.IP, "route should be pointing to a different gateway than the internet gateway")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Logf("Next hop for %s: %s", testCase.prefix, nexthop)
|
||||||
|
|
||||||
|
// Verify the route was added and points to non-VPN interface
|
||||||
|
currentNextHop, err := GetNextHop(testCase.prefix.Addr())
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotEqual(t, wgInterface.Name(), currentNextHop.Intf.Name, "Route should not point to VPN interface")
|
||||||
|
|
||||||
|
err = r.removeFromRouteTable(testCase.prefix, nexthop)
|
||||||
|
assert.NoError(t, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetNextHop(t *testing.T) {
|
func TestGetNextHop(t *testing.T) {
|
||||||
if runtime.GOOS == "freebsd" {
|
defaultNh, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
|
||||||
t.Skip("skipping on freebsd")
|
|
||||||
}
|
|
||||||
nexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("shouldn't return error when fetching the gateway: ", err)
|
t.Fatal("shouldn't return error when fetching the gateway: ", err)
|
||||||
}
|
}
|
||||||
if !nexthop.IP.IsValid() {
|
if !defaultNh.IP.IsValid() {
|
||||||
t.Fatal("should return a gateway")
|
t.Fatal("should return a gateway")
|
||||||
}
|
}
|
||||||
addresses, err := net.InterfaceAddrs()
|
addresses, err := net.InterfaceAddrs()
|
||||||
@ -133,7 +400,6 @@ func TestGetNextHop(t *testing.T) {
|
|||||||
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
|
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var testingIP string
|
|
||||||
var testingPrefix netip.Prefix
|
var testingPrefix netip.Prefix
|
||||||
for _, address := range addresses {
|
for _, address := range addresses {
|
||||||
if address.Network() != "ip+net" {
|
if address.Network() != "ip+net" {
|
||||||
@ -141,213 +407,23 @@ func TestGetNextHop(t *testing.T) {
|
|||||||
}
|
}
|
||||||
prefix := netip.MustParsePrefix(address.String())
|
prefix := netip.MustParsePrefix(address.String())
|
||||||
if !prefix.Addr().IsLoopback() && prefix.Addr().Is4() {
|
if !prefix.Addr().IsLoopback() && prefix.Addr().Is4() {
|
||||||
testingIP = prefix.Addr().String()
|
|
||||||
testingPrefix = prefix.Masked()
|
testingPrefix = prefix.Masked()
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
localIP, err := GetNextHop(testingPrefix.Addr())
|
nh, err := GetNextHop(testingPrefix.Addr())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("shouldn't return error: ", err)
|
t.Fatal("shouldn't return error: ", err)
|
||||||
}
|
}
|
||||||
if !localIP.IP.IsValid() {
|
if nh.Intf == nil {
|
||||||
t.Fatal("should return a gateway for local network")
|
t.Fatal("should return a gateway for local network")
|
||||||
}
|
}
|
||||||
if localIP.IP.String() == nexthop.IP.String() {
|
if nh.IP.String() == defaultNh.IP.String() {
|
||||||
t.Fatal("local IP should not match with gateway IP")
|
t.Fatal("next hop IP should not match with default gateway IP")
|
||||||
}
|
}
|
||||||
if localIP.IP.String() != testingIP {
|
if nh.Intf.Name != defaultNh.Intf.Name {
|
||||||
t.Fatalf("local IP should match with testing IP: want %s got %s", testingIP, localIP.IP.String())
|
t.Fatalf("next hop interface name should match with default gateway interface name, got: %s, want: %s", nh.Intf.Name, defaultNh.Intf.Name)
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAddExistAndRemoveRoute(t *testing.T) {
|
|
||||||
defaultNexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
|
|
||||||
t.Log("defaultNexthop: ", defaultNexthop)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("shouldn't return error when fetching the gateway: ", err)
|
|
||||||
}
|
|
||||||
testCases := []struct {
|
|
||||||
name string
|
|
||||||
prefix netip.Prefix
|
|
||||||
preExistingPrefix netip.Prefix
|
|
||||||
shouldAddRoute bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Should Add And Remove random Route",
|
|
||||||
prefix: netip.MustParsePrefix("99.99.99.99/32"),
|
|
||||||
shouldAddRoute: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Should Not Add Route if overlaps with default gateway",
|
|
||||||
prefix: netip.MustParsePrefix(defaultNexthop.IP.String() + "/31"),
|
|
||||||
shouldAddRoute: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Should Add Route if bigger network exists",
|
|
||||||
prefix: netip.MustParsePrefix("100.100.100.0/24"),
|
|
||||||
preExistingPrefix: netip.MustParsePrefix("100.100.0.0/16"),
|
|
||||||
shouldAddRoute: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Should Add Route if smaller network exists",
|
|
||||||
prefix: netip.MustParsePrefix("100.100.0.0/16"),
|
|
||||||
preExistingPrefix: netip.MustParsePrefix("100.100.100.0/24"),
|
|
||||||
shouldAddRoute: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Should Not Add Route if same network exists",
|
|
||||||
prefix: netip.MustParsePrefix("100.100.0.0/16"),
|
|
||||||
preExistingPrefix: netip.MustParsePrefix("100.100.0.0/16"),
|
|
||||||
shouldAddRoute: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for n, testCase := range testCases {
|
|
||||||
|
|
||||||
var buf bytes.Buffer
|
|
||||||
log.SetOutput(&buf)
|
|
||||||
defer func() {
|
|
||||||
log.SetOutput(os.Stderr)
|
|
||||||
}()
|
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
|
||||||
t.Setenv("NB_USE_LEGACY_ROUTING", "true")
|
|
||||||
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
|
|
||||||
|
|
||||||
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
|
|
||||||
newNet, err := stdnet.NewNet()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
opts := iface.WGIFaceOpts{
|
|
||||||
IFaceName: fmt.Sprintf("utun53%d", n),
|
|
||||||
Address: "100.65.75.2/24",
|
|
||||||
WGPort: 33100,
|
|
||||||
WGPrivKey: peerPrivateKey.String(),
|
|
||||||
MTU: iface.DefaultMTU,
|
|
||||||
TransportNet: newNet,
|
|
||||||
}
|
|
||||||
wgInterface, err := iface.NewWGIFace(opts)
|
|
||||||
require.NoError(t, err, "should create testing WGIface interface")
|
|
||||||
defer wgInterface.Close()
|
|
||||||
|
|
||||||
err = wgInterface.Create()
|
|
||||||
require.NoError(t, err, "should create testing wireguard interface")
|
|
||||||
|
|
||||||
index, err := net.InterfaceByName(wgInterface.Name())
|
|
||||||
require.NoError(t, err, "InterfaceByName should not return err")
|
|
||||||
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
|
|
||||||
|
|
||||||
r := NewSysOps(wgInterface, nil)
|
|
||||||
|
|
||||||
// Prepare the environment
|
|
||||||
if testCase.preExistingPrefix.IsValid() {
|
|
||||||
err := r.AddVPNRoute(testCase.preExistingPrefix, intf)
|
|
||||||
require.NoError(t, err, "should not return err when adding pre-existing route")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add the route
|
|
||||||
err = r.AddVPNRoute(testCase.prefix, intf)
|
|
||||||
require.NoError(t, err, "should not return err when adding route")
|
|
||||||
|
|
||||||
if testCase.shouldAddRoute {
|
|
||||||
// test if route exists after adding
|
|
||||||
ok, err := existsInRouteTable(testCase.prefix)
|
|
||||||
require.NoError(t, err, "should not return err")
|
|
||||||
require.True(t, ok, "route should exist")
|
|
||||||
|
|
||||||
// remove route again if added
|
|
||||||
err = r.RemoveVPNRoute(testCase.prefix, intf)
|
|
||||||
require.NoError(t, err, "should not return err")
|
|
||||||
}
|
|
||||||
|
|
||||||
// route should either not have been added or should have been removed
|
|
||||||
// In case of already existing route, it should not have been added (but still exist)
|
|
||||||
ok, err := existsInRouteTable(testCase.prefix)
|
|
||||||
t.Log("Buffer string: ", buf.String())
|
|
||||||
require.NoError(t, err, "should not return err")
|
|
||||||
|
|
||||||
if !strings.Contains(buf.String(), "because it already exists") {
|
|
||||||
require.False(t, ok, "route should not exist")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIsSubRange(t *testing.T) {
|
|
||||||
addresses, err := net.InterfaceAddrs()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var subRangeAddressPrefixes []netip.Prefix
|
|
||||||
var nonSubRangeAddressPrefixes []netip.Prefix
|
|
||||||
for _, address := range addresses {
|
|
||||||
p := netip.MustParsePrefix(address.String())
|
|
||||||
if !p.Addr().IsLoopback() && p.Addr().Is4() && p.Bits() < 32 {
|
|
||||||
p2 := netip.PrefixFrom(p.Masked().Addr(), p.Bits()+1)
|
|
||||||
subRangeAddressPrefixes = append(subRangeAddressPrefixes, p2)
|
|
||||||
nonSubRangeAddressPrefixes = append(nonSubRangeAddressPrefixes, p.Masked())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, prefix := range subRangeAddressPrefixes {
|
|
||||||
isSubRangePrefix, err := isSubRange(prefix)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("shouldn't return error when checking if address is sub-range: ", err)
|
|
||||||
}
|
|
||||||
if !isSubRangePrefix {
|
|
||||||
t.Fatalf("address %s should be sub-range of an existing route in the table", prefix)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, prefix := range nonSubRangeAddressPrefixes {
|
|
||||||
isSubRangePrefix, err := isSubRange(prefix)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("shouldn't return error when checking if address is sub-range: ", err)
|
|
||||||
}
|
|
||||||
if isSubRangePrefix {
|
|
||||||
t.Fatalf("address %s should not be sub-range of an existing route in the table", prefix)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExistsInRouteTable(t *testing.T) {
|
|
||||||
addresses, err := net.InterfaceAddrs()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var addressPrefixes []netip.Prefix
|
|
||||||
for _, address := range addresses {
|
|
||||||
p := netip.MustParsePrefix(address.String())
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case p.Addr().Is6():
|
|
||||||
continue
|
|
||||||
// Windows sometimes has hidden interface link local addrs that don't turn up on any interface
|
|
||||||
case runtime.GOOS == "windows" && p.Addr().IsLinkLocalUnicast():
|
|
||||||
continue
|
|
||||||
// Linux loopback 127/8 is in the local table, not in the main table and always takes precedence
|
|
||||||
case runtime.GOOS == "linux" && p.Addr().IsLoopback():
|
|
||||||
continue
|
|
||||||
// FreeBSD loopback 127/8 is not added to the routing table
|
|
||||||
case runtime.GOOS == "freebsd" && p.Addr().IsLoopback():
|
|
||||||
continue
|
|
||||||
default:
|
|
||||||
addressPrefixes = append(addressPrefixes, p.Masked())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, prefix := range addressPrefixes {
|
|
||||||
exists, err := existsInRouteTable(prefix)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("shouldn't return error when checking if address exists in route table: ", err)
|
|
||||||
}
|
|
||||||
if !exists {
|
|
||||||
t.Fatalf("address %s should exist in route table", prefix)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -384,11 +460,16 @@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen
|
|||||||
func setupRouteAndCleanup(t *testing.T, r *SysOps, prefix netip.Prefix, intf *net.Interface) {
|
func setupRouteAndCleanup(t *testing.T, r *SysOps, prefix netip.Prefix, intf *net.Interface) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
err := r.AddVPNRoute(prefix, intf)
|
if err := r.AddVPNRoute(prefix, intf); err != nil {
|
||||||
require.NoError(t, err, "addVPNRoute should not return err")
|
if !errors.Is(err, syscall.EEXIST) && !errors.Is(err, vars.ErrRouteNotAllowed) {
|
||||||
|
t.Fatalf("addVPNRoute should not return err: %v", err)
|
||||||
|
}
|
||||||
|
t.Logf("addVPNRoute %v returned: %v", prefix, err)
|
||||||
|
}
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
err = r.RemoveVPNRoute(prefix, intf)
|
if err := r.RemoveVPNRoute(prefix, intf); err != nil && !errors.Is(err, vars.ErrRouteNotAllowed) {
|
||||||
assert.NoError(t, err, "removeVPNRoute should not return err")
|
t.Fatalf("removeVPNRoute should not return err: %v", err)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -422,28 +503,10 @@ func setupTestEnv(t *testing.T) {
|
|||||||
// 10.10.0.0/24 more specific route exists in vpn table
|
// 10.10.0.0/24 more specific route exists in vpn table
|
||||||
setupRouteAndCleanup(t, r, netip.MustParsePrefix("10.10.0.0/24"), intf)
|
setupRouteAndCleanup(t, r, netip.MustParsePrefix("10.10.0.0/24"), intf)
|
||||||
|
|
||||||
// 127.0.10.0/24 more specific route exists in vpn table
|
|
||||||
setupRouteAndCleanup(t, r, netip.MustParsePrefix("127.0.10.0/24"), intf)
|
|
||||||
|
|
||||||
// unique route in vpn table
|
// unique route in vpn table
|
||||||
setupRouteAndCleanup(t, r, netip.MustParsePrefix("172.16.0.0/12"), intf)
|
setupRouteAndCleanup(t, r, netip.MustParsePrefix("172.16.0.0/12"), intf)
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) {
|
|
||||||
t.Helper()
|
|
||||||
if runtime.GOOS == "linux" && prefix.Addr().IsLoopback() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
prefixNexthop, err := GetNextHop(prefix.Addr())
|
|
||||||
require.NoError(t, err, "GetNextHop should not return err")
|
|
||||||
if invert {
|
|
||||||
assert.NotEqual(t, wgIface.Address().IP.String(), prefixNexthop.IP.String(), "route should not point to wireguard interface IP")
|
|
||||||
} else {
|
|
||||||
assert.Equal(t, wgIface.Address().IP.String(), prefixNexthop.IP.String(), "route should point to wireguard interface IP")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIsVpnRoute(t *testing.T) {
|
func TestIsVpnRoute(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
@ -149,6 +149,10 @@ func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||||
|
if err := r.validateRoute(prefix); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
if !nbnet.AdvancedRouting() {
|
if !nbnet.AdvancedRouting() {
|
||||||
return r.genericAddVPNRoute(prefix, intf)
|
return r.genericAddVPNRoute(prefix, intf)
|
||||||
}
|
}
|
||||||
@ -172,6 +176,10 @@ func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||||
|
if err := r.validateRoute(prefix); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
if !nbnet.AdvancedRouting() {
|
if !nbnet.AdvancedRouting() {
|
||||||
return r.genericRemoveVPNRoute(prefix, intf)
|
return r.genericRemoveVPNRoute(prefix, intf)
|
||||||
}
|
}
|
||||||
@ -219,7 +227,7 @@ func getRoutes(tableID, family int) ([]netip.Prefix, error) {
|
|||||||
|
|
||||||
ones, _ := route.Dst.Mask.Size()
|
ones, _ := route.Dst.Mask.Size()
|
||||||
|
|
||||||
prefix := netip.PrefixFrom(addr, ones)
|
prefix := netip.PrefixFrom(addr.Unmap(), ones)
|
||||||
if prefix.IsValid() {
|
if prefix.IsValid() {
|
||||||
prefixList = append(prefixList, prefix)
|
prefixList = append(prefixList, prefix)
|
||||||
}
|
}
|
||||||
@ -247,7 +255,7 @@ func addRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error {
|
|||||||
return fmt.Errorf("add gateway and device: %w", err)
|
return fmt.Errorf("add gateway and device: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !isOpErr(err) {
|
if err := netlink.RouteAdd(route); err != nil && !isOpErr(err) {
|
||||||
return fmt.Errorf("netlink add route: %w", err)
|
return fmt.Errorf("netlink add route: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -270,7 +278,7 @@ func addUnreachableRoute(prefix netip.Prefix, tableID int) error {
|
|||||||
Dst: ipNet,
|
Dst: ipNet,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !isOpErr(err) {
|
if err := netlink.RouteAdd(route); err != nil && !isOpErr(err) {
|
||||||
return fmt.Errorf("netlink add unreachable route: %w", err)
|
return fmt.Errorf("netlink add unreachable route: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -19,7 +19,6 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var expectedVPNint = "wgtest0"
|
var expectedVPNint = "wgtest0"
|
||||||
var expectedLoopbackInt = "lo"
|
|
||||||
var expectedExternalInt = "dummyext0"
|
var expectedExternalInt = "dummyext0"
|
||||||
var expectedInternalInt = "dummyint0"
|
var expectedInternalInt = "dummyint0"
|
||||||
|
|
||||||
@ -31,12 +30,6 @@ func init() {
|
|||||||
dialer: &net.Dialer{},
|
dialer: &net.Dialer{},
|
||||||
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.10.0.2", 53),
|
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.10.0.2", 53),
|
||||||
},
|
},
|
||||||
{
|
|
||||||
name: "To more specific route (local) without custom dialer via physical interface",
|
|
||||||
expectedInterface: expectedLoopbackInt,
|
|
||||||
dialer: &net.Dialer{},
|
|
||||||
expectedPacket: createPacketExpectation("127.0.0.1", 12345, "127.0.10.1", 53),
|
|
||||||
},
|
|
||||||
}...)
|
}...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -11,10 +11,16 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||||
|
if err := r.validateRoute(prefix); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
return r.genericAddVPNRoute(prefix, intf)
|
return r.genericAddVPNRoute(prefix, intf)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||||
|
if err := r.validateRoute(prefix); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
return r.genericRemoveVPNRoute(prefix, intf)
|
return r.genericRemoveVPNRoute(prefix, intf)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
268
client/internal/routemanager/systemops/systemops_test.go
Normal file
268
client/internal/routemanager/systemops/systemops_test.go
Normal file
@ -0,0 +1,268 @@
|
|||||||
|
package systemops
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockWGIface struct {
|
||||||
|
address wgaddr.Address
|
||||||
|
name string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockWGIface) Address() wgaddr.Address {
|
||||||
|
return m.address
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockWGIface) Name() string {
|
||||||
|
return m.name
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSysOps_validateRoute(t *testing.T) {
|
||||||
|
wgNetwork := netip.MustParsePrefix("10.0.0.0/24")
|
||||||
|
mockWG := &mockWGIface{
|
||||||
|
address: wgaddr.Address{
|
||||||
|
IP: wgNetwork.Addr(),
|
||||||
|
Network: wgNetwork,
|
||||||
|
},
|
||||||
|
name: "wg0",
|
||||||
|
}
|
||||||
|
|
||||||
|
sysOps := &SysOps{
|
||||||
|
wgInterface: mockWG,
|
||||||
|
notifier: ¬ifier.Notifier{},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
prefix string
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
// Valid routes
|
||||||
|
{
|
||||||
|
name: "valid IPv4 route",
|
||||||
|
prefix: "192.168.1.0/24",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid IPv6 route",
|
||||||
|
prefix: "2001:db8::/32",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid single IPv4 host",
|
||||||
|
prefix: "8.8.8.8/32",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid single IPv6 host",
|
||||||
|
prefix: "2001:4860:4860::8888/128",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
|
||||||
|
// Invalid routes - loopback
|
||||||
|
{
|
||||||
|
name: "IPv4 loopback",
|
||||||
|
prefix: "127.0.0.1/32",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 loopback",
|
||||||
|
prefix: "::1/128",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
|
||||||
|
// Invalid routes - link-local unicast
|
||||||
|
{
|
||||||
|
name: "IPv4 link-local unicast",
|
||||||
|
prefix: "169.254.1.1/32",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 link-local unicast",
|
||||||
|
prefix: "fe80::1/128",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
|
||||||
|
// Invalid routes - multicast
|
||||||
|
{
|
||||||
|
name: "IPv4 multicast",
|
||||||
|
prefix: "224.0.0.1/32",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 multicast",
|
||||||
|
prefix: "ff02::1/128",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
|
||||||
|
// Invalid routes - link-local multicast
|
||||||
|
{
|
||||||
|
name: "IPv4 link-local multicast",
|
||||||
|
prefix: "224.0.0.0/24",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 link-local multicast",
|
||||||
|
prefix: "ff02::/16",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
|
||||||
|
// Invalid routes - interface-local multicast (IPv6 only)
|
||||||
|
{
|
||||||
|
name: "IPv6 interface-local multicast",
|
||||||
|
prefix: "ff01::1/128",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
|
||||||
|
// Invalid routes - overlaps with WG interface network
|
||||||
|
{
|
||||||
|
name: "overlaps with WG network - exact match",
|
||||||
|
prefix: "10.0.0.0/24",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "overlaps with WG network - subset",
|
||||||
|
prefix: "10.0.0.1/32",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "overlaps with WG network - host in range",
|
||||||
|
prefix: "10.0.0.100/32",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
prefix, err := netip.ParsePrefix(tt.prefix)
|
||||||
|
require.NoError(t, err, "Failed to parse test prefix %s", tt.prefix)
|
||||||
|
|
||||||
|
err = sysOps.validateRoute(prefix)
|
||||||
|
|
||||||
|
if tt.expectError {
|
||||||
|
require.Error(t, err, "validateRoute() expected error for %s", tt.prefix)
|
||||||
|
assert.Equal(t, vars.ErrRouteNotAllowed, err, "validateRoute() expected ErrRouteNotAllowed for %s", tt.prefix)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err, "validateRoute() expected no error for %s", tt.prefix)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSysOps_validateRoute_SubnetOverlap(t *testing.T) {
|
||||||
|
wgNetwork := netip.MustParsePrefix("192.168.100.0/24")
|
||||||
|
mockWG := &mockWGIface{
|
||||||
|
address: wgaddr.Address{
|
||||||
|
IP: wgNetwork.Addr(),
|
||||||
|
Network: wgNetwork,
|
||||||
|
},
|
||||||
|
name: "wg0",
|
||||||
|
}
|
||||||
|
|
||||||
|
sysOps := &SysOps{
|
||||||
|
wgInterface: mockWG,
|
||||||
|
notifier: ¬ifier.Notifier{},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
prefix string
|
||||||
|
expectError bool
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "identical subnet",
|
||||||
|
prefix: "192.168.100.0/24",
|
||||||
|
expectError: true,
|
||||||
|
description: "exact same network as WG interface",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "broader subnet containing WG network",
|
||||||
|
prefix: "192.168.0.0/16",
|
||||||
|
expectError: false,
|
||||||
|
description: "broader network that contains WG network should be allowed",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "host within WG network",
|
||||||
|
prefix: "192.168.100.50/32",
|
||||||
|
expectError: true,
|
||||||
|
description: "specific host within WG network",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "subnet within WG network",
|
||||||
|
prefix: "192.168.100.128/25",
|
||||||
|
expectError: true,
|
||||||
|
description: "smaller subnet within WG network",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "adjacent subnet - same /23",
|
||||||
|
prefix: "192.168.101.0/24",
|
||||||
|
expectError: false,
|
||||||
|
description: "adjacent subnet, no overlap",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "adjacent subnet - different /16",
|
||||||
|
prefix: "192.167.100.0/24",
|
||||||
|
expectError: false,
|
||||||
|
description: "different network, no overlap",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "WG network broadcast address",
|
||||||
|
prefix: "192.168.100.255/32",
|
||||||
|
expectError: true,
|
||||||
|
description: "broadcast address of WG network",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "WG network first usable",
|
||||||
|
prefix: "192.168.100.1/32",
|
||||||
|
expectError: true,
|
||||||
|
description: "first usable address in WG network",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
prefix, err := netip.ParsePrefix(tt.prefix)
|
||||||
|
require.NoError(t, err, "Failed to parse test prefix %s", tt.prefix)
|
||||||
|
|
||||||
|
err = sysOps.validateRoute(prefix)
|
||||||
|
|
||||||
|
if tt.expectError {
|
||||||
|
require.Error(t, err, "validateRoute() expected error for %s (%s)", tt.prefix, tt.description)
|
||||||
|
assert.Equal(t, vars.ErrRouteNotAllowed, err, "validateRoute() expected ErrRouteNotAllowed for %s (%s)", tt.prefix, tt.description)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err, "validateRoute() expected no error for %s (%s)", tt.prefix, tt.description)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSysOps_validateRoute_InvalidPrefix(t *testing.T) {
|
||||||
|
wgNetwork := netip.MustParsePrefix("10.0.0.0/24")
|
||||||
|
mockWG := &mockWGIface{
|
||||||
|
address: wgaddr.Address{
|
||||||
|
IP: wgNetwork.Addr(),
|
||||||
|
Network: wgNetwork,
|
||||||
|
},
|
||||||
|
name: "wt0",
|
||||||
|
}
|
||||||
|
|
||||||
|
sysOps := &SysOps{
|
||||||
|
wgInterface: mockWG,
|
||||||
|
notifier: ¬ifier.Notifier{},
|
||||||
|
}
|
||||||
|
|
||||||
|
var invalidPrefix netip.Prefix
|
||||||
|
err := sysOps.validateRoute(invalidPrefix)
|
||||||
|
|
||||||
|
require.Error(t, err, "validateRoute() expected error for invalid prefix")
|
||||||
|
assert.Equal(t, vars.ErrRouteNotAllowed, err, "validateRoute() expected ErrRouteNotAllowed for invalid prefix")
|
||||||
|
}
|
@ -3,15 +3,19 @@
|
|||||||
package systemops
|
package systemops
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os/exec"
|
"strconv"
|
||||||
"strings"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/net/route"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
@ -26,48 +30,16 @@ func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||||
return r.routeCmd("add", prefix, nexthop)
|
return r.routeSocket(unix.RTM_ADD, prefix, nexthop)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||||
return r.routeCmd("delete", prefix, nexthop)
|
return r.routeSocket(unix.RTM_DELETE, prefix, nexthop)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) routeCmd(action string, prefix netip.Prefix, nexthop Nexthop) error {
|
func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) error {
|
||||||
inet := "-inet"
|
if !prefix.IsValid() {
|
||||||
if prefix.Addr().Is6() {
|
return fmt.Errorf("invalid prefix: %s", prefix)
|
||||||
inet = "-inet6"
|
|
||||||
}
|
|
||||||
|
|
||||||
network := prefix.String()
|
|
||||||
if prefix.IsSingleIP() {
|
|
||||||
network = prefix.Addr().String()
|
|
||||||
}
|
|
||||||
|
|
||||||
args := []string{"-n", action, inet, network}
|
|
||||||
if nexthop.IP.IsValid() {
|
|
||||||
args = append(args, nexthop.IP.Unmap().String())
|
|
||||||
} else if nexthop.Intf != nil {
|
|
||||||
args = append(args, "-interface", nexthop.Intf.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := retryRouteCmd(args); err != nil {
|
|
||||||
return fmt.Errorf("failed to %s route for %s: %w", action, prefix, err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func retryRouteCmd(args []string) error {
|
|
||||||
operation := func() error {
|
|
||||||
out, err := exec.Command("route", args...).CombinedOutput()
|
|
||||||
log.Tracef("route %s: %s", strings.Join(args, " "), out)
|
|
||||||
// https://github.com/golang/go/issues/45736
|
|
||||||
if err != nil && strings.Contains(string(out), "sysctl: cannot allocate memory") {
|
|
||||||
return err
|
|
||||||
} else if err != nil {
|
|
||||||
return backoff.Permanent(err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
expBackOff := backoff.NewExponentialBackOff()
|
expBackOff := backoff.NewExponentialBackOff()
|
||||||
@ -75,9 +47,157 @@ func retryRouteCmd(args []string) error {
|
|||||||
expBackOff.MaxInterval = 500 * time.Millisecond
|
expBackOff.MaxInterval = 500 * time.Millisecond
|
||||||
expBackOff.MaxElapsedTime = 1 * time.Second
|
expBackOff.MaxElapsedTime = 1 * time.Second
|
||||||
|
|
||||||
err := backoff.Retry(operation, expBackOff)
|
if err := backoff.Retry(r.routeOp(action, prefix, nexthop), expBackOff); err != nil {
|
||||||
if err != nil {
|
a := "add"
|
||||||
return fmt.Errorf("route cmd retry failed: %w", err)
|
if action == unix.RTM_DELETE {
|
||||||
|
a = "remove"
|
||||||
|
}
|
||||||
|
return fmt.Errorf("%s route for %s: %w", a, prefix, err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *SysOps) routeOp(action int, prefix netip.Prefix, nexthop Nexthop) func() error {
|
||||||
|
operation := func() error {
|
||||||
|
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("open routing socket: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := unix.Close(fd); err != nil && !errors.Is(err, unix.EBADF) {
|
||||||
|
log.Warnf("failed to close routing socket: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
msg, err := r.buildRouteMessage(action, prefix, nexthop)
|
||||||
|
if err != nil {
|
||||||
|
return backoff.Permanent(fmt.Errorf("build route message: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
msgBytes, err := msg.Marshal()
|
||||||
|
if err != nil {
|
||||||
|
return backoff.Permanent(fmt.Errorf("marshal route message: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err = unix.Write(fd, msgBytes); err != nil {
|
||||||
|
if errors.Is(err, unix.ENOBUFS) || errors.Is(err, unix.EAGAIN) {
|
||||||
|
return fmt.Errorf("write: %w", err)
|
||||||
|
}
|
||||||
|
return backoff.Permanent(fmt.Errorf("write: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
respBuf := make([]byte, 2048)
|
||||||
|
n, err := unix.Read(fd, respBuf)
|
||||||
|
if err != nil {
|
||||||
|
return backoff.Permanent(fmt.Errorf("read route response: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if n > 0 {
|
||||||
|
if err := r.parseRouteResponse(respBuf[:n]); err != nil {
|
||||||
|
return backoff.Permanent(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return operation
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) {
|
||||||
|
msg = &route.RouteMessage{
|
||||||
|
Type: action,
|
||||||
|
Flags: unix.RTF_UP,
|
||||||
|
Version: unix.RTM_VERSION,
|
||||||
|
Seq: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
const numAddrs = unix.RTAX_NETMASK + 1
|
||||||
|
addrs := make([]route.Addr, numAddrs)
|
||||||
|
|
||||||
|
addrs[unix.RTAX_DST], err = addrToRouteAddr(prefix.Addr())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("build destination address for %s: %w", prefix.Addr(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if prefix.IsSingleIP() {
|
||||||
|
msg.Flags |= unix.RTF_HOST
|
||||||
|
} else {
|
||||||
|
addrs[unix.RTAX_NETMASK], err = prefixToRouteNetmask(prefix)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("build netmask for %s: %w", prefix, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if nexthop.IP.IsValid() {
|
||||||
|
msg.Flags |= unix.RTF_GATEWAY
|
||||||
|
addrs[unix.RTAX_GATEWAY], err = addrToRouteAddr(nexthop.IP.Unmap())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("build gateway IP address for %s: %w", nexthop.IP, err)
|
||||||
|
}
|
||||||
|
} else if nexthop.Intf != nil {
|
||||||
|
msg.Index = nexthop.Intf.Index
|
||||||
|
addrs[unix.RTAX_GATEWAY] = &route.LinkAddr{
|
||||||
|
Index: nexthop.Intf.Index,
|
||||||
|
Name: nexthop.Intf.Name,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
msg.Addrs = addrs
|
||||||
|
return msg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *SysOps) parseRouteResponse(buf []byte) error {
|
||||||
|
if len(buf) < int(unsafe.Sizeof(unix.RtMsghdr{})) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rtMsg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
|
||||||
|
if rtMsg.Errno != 0 {
|
||||||
|
return fmt.Errorf("parse: %d", rtMsg.Errno)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addrToRouteAddr converts a netip.Addr to the appropriate route.Addr (*route.Inet4Addr or *route.Inet6Addr).
|
||||||
|
func addrToRouteAddr(addr netip.Addr) (route.Addr, error) {
|
||||||
|
if addr.Is4() {
|
||||||
|
return &route.Inet4Addr{IP: addr.As4()}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if addr.Zone() == "" {
|
||||||
|
return &route.Inet6Addr{IP: addr.As16()}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var zone int
|
||||||
|
// zone can be either a numeric zone ID or an interface name.
|
||||||
|
if z, err := strconv.Atoi(addr.Zone()); err == nil {
|
||||||
|
zone = z
|
||||||
|
} else {
|
||||||
|
iface, err := net.InterfaceByName(addr.Zone())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("resolve zone '%s': %w", addr.Zone(), err)
|
||||||
|
}
|
||||||
|
zone = iface.Index
|
||||||
|
}
|
||||||
|
return &route.Inet6Addr{IP: addr.As16(), ZoneID: zone}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func prefixToRouteNetmask(prefix netip.Prefix) (route.Addr, error) {
|
||||||
|
bits := prefix.Bits()
|
||||||
|
if prefix.Addr().Is4() {
|
||||||
|
m := net.CIDRMask(bits, 32)
|
||||||
|
var maskBytes [4]byte
|
||||||
|
copy(maskBytes[:], m)
|
||||||
|
return &route.Inet4Addr{IP: maskBytes}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if prefix.Addr().Is6() {
|
||||||
|
m := net.CIDRMask(bits, 128)
|
||||||
|
var maskBytes [16]byte
|
||||||
|
copy(maskBytes[:], m)
|
||||||
|
return &route.Inet6Addr{IP: maskBytes}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("unknown IP version in prefix: %s", prefix.Addr().String())
|
||||||
|
}
|
||||||
|
@ -1,5 +1,3 @@
|
|||||||
//go:build windows
|
|
||||||
|
|
||||||
package systemops
|
package systemops
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@ -9,9 +7,8 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"runtime/debug"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
@ -21,11 +18,12 @@ import (
|
|||||||
"github.com/yusufpapurcu/wmi"
|
"github.com/yusufpapurcu/wmi"
|
||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const InfiniteLifetime = 0xffffffff
|
||||||
|
|
||||||
type RouteUpdateType int
|
type RouteUpdateType int
|
||||||
|
|
||||||
// RouteUpdate represents a change in the routing table.
|
// RouteUpdate represents a change in the routing table.
|
||||||
@ -58,9 +56,13 @@ type MSFT_NetRoute struct {
|
|||||||
AddressFamily uint16
|
AddressFamily uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
// MIB_IPFORWARD_ROW2 is defined in https://learn.microsoft.com/en-us/windows/win32/api/netioapi/ns-netioapi-mib_ipforward_row2
|
// luid represents a locally unique identifier for network interfaces
|
||||||
|
type luid uint64
|
||||||
|
|
||||||
|
// MIB_IPFORWARD_ROW2 represents a route entry in the routing table.
|
||||||
|
// It is defined in https://learn.microsoft.com/en-us/windows/win32/api/netioapi/ns-netioapi-mib_ipforward_row2
|
||||||
type MIB_IPFORWARD_ROW2 struct {
|
type MIB_IPFORWARD_ROW2 struct {
|
||||||
InterfaceLuid uint64
|
InterfaceLuid luid
|
||||||
InterfaceIndex uint32
|
InterfaceIndex uint32
|
||||||
DestinationPrefix IP_ADDRESS_PREFIX
|
DestinationPrefix IP_ADDRESS_PREFIX
|
||||||
NextHop SOCKADDR_INET_NEXTHOP
|
NextHop SOCKADDR_INET_NEXTHOP
|
||||||
@ -108,9 +110,14 @@ type SOCKADDR_INET_NEXTHOP struct {
|
|||||||
type MIB_NOTIFICATION_TYPE int32
|
type MIB_NOTIFICATION_TYPE int32
|
||||||
|
|
||||||
var (
|
var (
|
||||||
modiphlpapi = windows.NewLazyDLL("iphlpapi.dll")
|
modiphlpapi = windows.NewLazyDLL("iphlpapi.dll")
|
||||||
procNotifyRouteChange2 = modiphlpapi.NewProc("NotifyRouteChange2")
|
procNotifyRouteChange2 = modiphlpapi.NewProc("NotifyRouteChange2")
|
||||||
procCancelMibChangeNotify2 = modiphlpapi.NewProc("CancelMibChangeNotify2")
|
procCancelMibChangeNotify2 = modiphlpapi.NewProc("CancelMibChangeNotify2")
|
||||||
|
procCreateIpForwardEntry2 = modiphlpapi.NewProc("CreateIpForwardEntry2")
|
||||||
|
procDeleteIpForwardEntry2 = modiphlpapi.NewProc("DeleteIpForwardEntry2")
|
||||||
|
procGetIpForwardEntry2 = modiphlpapi.NewProc("GetIpForwardEntry2")
|
||||||
|
procInitializeIpForwardEntry = modiphlpapi.NewProc("InitializeIpForwardEntry")
|
||||||
|
procConvertInterfaceIndexToLuid = modiphlpapi.NewProc("ConvertInterfaceIndexToLuid")
|
||||||
|
|
||||||
prefixList []netip.Prefix
|
prefixList []netip.Prefix
|
||||||
lastUpdate time.Time
|
lastUpdate time.Time
|
||||||
@ -139,6 +146,8 @@ func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||||
|
log.Debugf("Adding route to %s via %s", prefix, nexthop)
|
||||||
|
// if we don't have an interface but a zone, extract the interface index from the zone
|
||||||
if nexthop.IP.Zone() != "" && nexthop.Intf == nil {
|
if nexthop.IP.Zone() != "" && nexthop.Intf == nil {
|
||||||
zone, err := strconv.Atoi(nexthop.IP.Zone())
|
zone, err := strconv.Atoi(nexthop.IP.Zone())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -147,23 +156,187 @@ func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
|||||||
nexthop.Intf = &net.Interface{Index: zone}
|
nexthop.Intf = &net.Interface{Index: zone}
|
||||||
}
|
}
|
||||||
|
|
||||||
return addRouteCmd(prefix, nexthop)
|
return addRoute(prefix, nexthop)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||||
args := []string{"delete", prefix.String()}
|
log.Debugf("Removing route to %s via %s", prefix, nexthop)
|
||||||
if nexthop.IP.IsValid() {
|
return deleteRoute(prefix, nexthop)
|
||||||
ip := nexthop.IP.WithZone("")
|
}
|
||||||
args = append(args, ip.Unmap().String())
|
|
||||||
|
// setupRouteEntry prepares a route entry with common configuration
|
||||||
|
func setupRouteEntry(prefix netip.Prefix, nexthop Nexthop) (*MIB_IPFORWARD_ROW2, error) {
|
||||||
|
route := &MIB_IPFORWARD_ROW2{}
|
||||||
|
|
||||||
|
initializeIPForwardEntry(route)
|
||||||
|
|
||||||
|
// Convert interface index to luid if interface is specified
|
||||||
|
if nexthop.Intf != nil {
|
||||||
|
var luid luid
|
||||||
|
if err := convertInterfaceIndexToLUID(uint32(nexthop.Intf.Index), &luid); err != nil {
|
||||||
|
return nil, fmt.Errorf("convert interface index to luid: %w", err)
|
||||||
|
}
|
||||||
|
route.InterfaceLuid = luid
|
||||||
|
route.InterfaceIndex = uint32(nexthop.Intf.Index)
|
||||||
}
|
}
|
||||||
|
|
||||||
routeCmd := uspfilter.GetSystem32Command("route")
|
if err := setDestinationPrefix(&route.DestinationPrefix, prefix); err != nil {
|
||||||
|
return nil, fmt.Errorf("set destination prefix: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
out, err := exec.Command(routeCmd, args...).CombinedOutput()
|
if nexthop.IP.IsValid() {
|
||||||
log.Tracef("route %s: %s", strings.Join(args, " "), out)
|
if err := setNextHop(&route.NextHop, nexthop.IP); err != nil {
|
||||||
|
return nil, fmt.Errorf("set next hop: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil {
|
return route, nil
|
||||||
return fmt.Errorf("remove route: %w", err)
|
}
|
||||||
|
|
||||||
|
// addRoute adds a route using Windows iphelper APIs
|
||||||
|
func addRoute(prefix netip.Prefix, nexthop Nexthop) (err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = fmt.Errorf("panic in addRoute: %v, stack trace: %s", r, debug.Stack())
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
route, setupErr := setupRouteEntry(prefix, nexthop)
|
||||||
|
if setupErr != nil {
|
||||||
|
return fmt.Errorf("setup route entry: %w", setupErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
route.Metric = 1
|
||||||
|
route.ValidLifetime = InfiniteLifetime
|
||||||
|
route.PreferredLifetime = InfiniteLifetime
|
||||||
|
|
||||||
|
return createIPForwardEntry2(route)
|
||||||
|
}
|
||||||
|
|
||||||
|
// deleteRoute deletes a route using Windows iphelper APIs
|
||||||
|
func deleteRoute(prefix netip.Prefix, nexthop Nexthop) (err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = fmt.Errorf("panic in deleteRoute: %v, stack trace: %s", r, debug.Stack())
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
route, setupErr := setupRouteEntry(prefix, nexthop)
|
||||||
|
if setupErr != nil {
|
||||||
|
return fmt.Errorf("setup route entry: %w", setupErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := getIPForwardEntry2(route); err != nil {
|
||||||
|
return fmt.Errorf("get route entry: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return deleteIPForwardEntry2(route)
|
||||||
|
}
|
||||||
|
|
||||||
|
// setDestinationPrefix sets the destination prefix in the route structure
|
||||||
|
func setDestinationPrefix(prefix *IP_ADDRESS_PREFIX, dest netip.Prefix) error {
|
||||||
|
addr := dest.Addr()
|
||||||
|
prefix.PrefixLength = uint8(dest.Bits())
|
||||||
|
|
||||||
|
if addr.Is4() {
|
||||||
|
prefix.Prefix.sin6_family = windows.AF_INET
|
||||||
|
ip4 := addr.As4()
|
||||||
|
binary.BigEndian.PutUint32(prefix.Prefix.data[:4],
|
||||||
|
uint32(ip4[0])<<24|uint32(ip4[1])<<16|uint32(ip4[2])<<8|uint32(ip4[3]))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if addr.Is6() {
|
||||||
|
prefix.Prefix.sin6_family = windows.AF_INET6
|
||||||
|
ip6 := addr.As16()
|
||||||
|
copy(prefix.Prefix.data[4:20], ip6[:])
|
||||||
|
|
||||||
|
if zone := addr.Zone(); zone != "" {
|
||||||
|
if scopeID, err := strconv.ParseUint(zone, 10, 32); err == nil {
|
||||||
|
binary.BigEndian.PutUint32(prefix.Prefix.data[20:24], uint32(scopeID))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("invalid address family")
|
||||||
|
}
|
||||||
|
|
||||||
|
// setNextHop sets the next hop address in the route structure
|
||||||
|
func setNextHop(nextHop *SOCKADDR_INET_NEXTHOP, addr netip.Addr) error {
|
||||||
|
if addr.Is4() {
|
||||||
|
nextHop.sin6_family = windows.AF_INET
|
||||||
|
ip4 := addr.As4()
|
||||||
|
binary.BigEndian.PutUint32(nextHop.data[:4],
|
||||||
|
uint32(ip4[0])<<24|uint32(ip4[1])<<16|uint32(ip4[2])<<8|uint32(ip4[3]))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if addr.Is6() {
|
||||||
|
nextHop.sin6_family = windows.AF_INET6
|
||||||
|
ip6 := addr.As16()
|
||||||
|
copy(nextHop.data[4:20], ip6[:])
|
||||||
|
|
||||||
|
// Handle zone if present
|
||||||
|
if zone := addr.Zone(); zone != "" {
|
||||||
|
if scopeID, err := strconv.ParseUint(zone, 10, 32); err == nil {
|
||||||
|
binary.BigEndian.PutUint32(nextHop.data[20:24], uint32(scopeID))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("invalid address family")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Windows API wrappers
|
||||||
|
func createIPForwardEntry2(route *MIB_IPFORWARD_ROW2) error {
|
||||||
|
r1, _, e1 := syscall.SyscallN(procCreateIpForwardEntry2.Addr(), uintptr(unsafe.Pointer(route)))
|
||||||
|
if r1 != 0 {
|
||||||
|
if e1 != 0 {
|
||||||
|
return fmt.Errorf("CreateIpForwardEntry2: %w", e1)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("CreateIpForwardEntry2: code %d", r1)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func deleteIPForwardEntry2(route *MIB_IPFORWARD_ROW2) error {
|
||||||
|
r1, _, e1 := syscall.SyscallN(procDeleteIpForwardEntry2.Addr(), uintptr(unsafe.Pointer(route)))
|
||||||
|
if r1 != 0 {
|
||||||
|
if e1 != 0 {
|
||||||
|
return fmt.Errorf("DeleteIpForwardEntry2: %w", e1)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("DeleteIpForwardEntry2: code %d", r1)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getIPForwardEntry2(route *MIB_IPFORWARD_ROW2) error {
|
||||||
|
r1, _, e1 := syscall.SyscallN(procGetIpForwardEntry2.Addr(), uintptr(unsafe.Pointer(route)))
|
||||||
|
if r1 != 0 {
|
||||||
|
if e1 != 0 {
|
||||||
|
return fmt.Errorf("GetIpForwardEntry2: %w", e1)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("GetIpForwardEntry2: code %d", r1)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://learn.microsoft.com/en-us/windows/win32/api/netioapi/nf-netioapi-initializeipforwardentry
|
||||||
|
func initializeIPForwardEntry(route *MIB_IPFORWARD_ROW2) {
|
||||||
|
// Does not return anything. Trying to handle the error might return an uninitialized value.
|
||||||
|
_, _, _ = syscall.SyscallN(procInitializeIpForwardEntry.Addr(), uintptr(unsafe.Pointer(route)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertInterfaceIndexToLUID(interfaceIndex uint32, interfaceLUID *luid) error {
|
||||||
|
r1, _, e1 := syscall.SyscallN(procConvertInterfaceIndexToLuid.Addr(),
|
||||||
|
uintptr(interfaceIndex), uintptr(unsafe.Pointer(interfaceLUID)))
|
||||||
|
if r1 != 0 {
|
||||||
|
if e1 != 0 {
|
||||||
|
return fmt.Errorf("ConvertInterfaceIndexToLuid: %w", e1)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("ConvertInterfaceIndexToLuid: code %d", r1)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -319,7 +492,7 @@ func cancelMibChangeNotify2(handle windows.Handle) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetRoutesFromTable returns the current routing table from with prefixes only.
|
// GetRoutesFromTable returns the current routing table from with prefixes only.
|
||||||
// It ccaches the result for 2 seconds to avoid blocking the caller.
|
// It caches the result for 2 seconds to avoid blocking the caller.
|
||||||
func GetRoutesFromTable() ([]netip.Prefix, error) {
|
func GetRoutesFromTable() ([]netip.Prefix, error) {
|
||||||
mux.Lock()
|
mux.Lock()
|
||||||
defer mux.Unlock()
|
defer mux.Unlock()
|
||||||
@ -388,35 +561,6 @@ func GetRoutes() ([]Route, error) {
|
|||||||
return routes, nil
|
return routes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func addRouteCmd(prefix netip.Prefix, nexthop Nexthop) error {
|
|
||||||
args := []string{"add", prefix.String()}
|
|
||||||
|
|
||||||
if nexthop.IP.IsValid() {
|
|
||||||
ip := nexthop.IP.WithZone("")
|
|
||||||
args = append(args, ip.Unmap().String())
|
|
||||||
} else {
|
|
||||||
addr := "0.0.0.0"
|
|
||||||
if prefix.Addr().Is6() {
|
|
||||||
addr = "::"
|
|
||||||
}
|
|
||||||
args = append(args, addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
if nexthop.Intf != nil {
|
|
||||||
args = append(args, "if", strconv.Itoa(nexthop.Intf.Index))
|
|
||||||
}
|
|
||||||
|
|
||||||
routeCmd := uspfilter.GetSystem32Command("route")
|
|
||||||
|
|
||||||
out, err := exec.Command(routeCmd, args...).CombinedOutput()
|
|
||||||
log.Tracef("route %s: %s", strings.Join(args, " "), out)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("route add: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isCacheDisabled() bool {
|
func isCacheDisabled() bool {
|
||||||
return os.Getenv("NB_DISABLE_ROUTE_CACHE") == "true"
|
return os.Getenv("NB_DISABLE_ROUTE_CACHE") == "true"
|
||||||
}
|
}
|
||||||
|
@ -5,18 +5,23 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
var expectedExtInt = "Ethernet1"
|
var (
|
||||||
|
expectedExternalInt = "Ethernet1"
|
||||||
|
expectedVPNint = "wgtest0"
|
||||||
|
)
|
||||||
|
|
||||||
type RouteInfo struct {
|
type RouteInfo struct {
|
||||||
NextHop string `json:"nexthop"`
|
NextHop string `json:"nexthop"`
|
||||||
@ -43,8 +48,6 @@ type testCase struct {
|
|||||||
dialer dialer
|
dialer dialer
|
||||||
}
|
}
|
||||||
|
|
||||||
var expectedVPNint = "wgtest0"
|
|
||||||
|
|
||||||
var testCases = []testCase{
|
var testCases = []testCase{
|
||||||
{
|
{
|
||||||
name: "To external host without custom dialer via vpn",
|
name: "To external host without custom dialer via vpn",
|
||||||
@ -52,14 +55,14 @@ var testCases = []testCase{
|
|||||||
expectedSourceIP: "100.64.0.1",
|
expectedSourceIP: "100.64.0.1",
|
||||||
expectedDestPrefix: "128.0.0.0/1",
|
expectedDestPrefix: "128.0.0.0/1",
|
||||||
expectedNextHop: "0.0.0.0",
|
expectedNextHop: "0.0.0.0",
|
||||||
expectedInterface: "wgtest0",
|
expectedInterface: expectedVPNint,
|
||||||
dialer: &net.Dialer{},
|
dialer: &net.Dialer{},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "To external host with custom dialer via physical interface",
|
name: "To external host with custom dialer via physical interface",
|
||||||
destination: "192.0.2.1:53",
|
destination: "192.0.2.1:53",
|
||||||
expectedDestPrefix: "192.0.2.1/32",
|
expectedDestPrefix: "192.0.2.1/32",
|
||||||
expectedInterface: expectedExtInt,
|
expectedInterface: expectedExternalInt,
|
||||||
dialer: nbnet.NewDialer(),
|
dialer: nbnet.NewDialer(),
|
||||||
},
|
},
|
||||||
|
|
||||||
@ -67,24 +70,15 @@ var testCases = []testCase{
|
|||||||
name: "To duplicate internal route with custom dialer via physical interface",
|
name: "To duplicate internal route with custom dialer via physical interface",
|
||||||
destination: "10.0.0.2:53",
|
destination: "10.0.0.2:53",
|
||||||
expectedDestPrefix: "10.0.0.2/32",
|
expectedDestPrefix: "10.0.0.2/32",
|
||||||
expectedInterface: expectedExtInt,
|
expectedInterface: expectedExternalInt,
|
||||||
dialer: nbnet.NewDialer(),
|
dialer: nbnet.NewDialer(),
|
||||||
},
|
},
|
||||||
{
|
|
||||||
name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence
|
|
||||||
destination: "10.0.0.2:53",
|
|
||||||
expectedSourceIP: "127.0.0.1",
|
|
||||||
expectedDestPrefix: "10.0.0.0/8",
|
|
||||||
expectedNextHop: "0.0.0.0",
|
|
||||||
expectedInterface: "Loopback Pseudo-Interface 1",
|
|
||||||
dialer: &net.Dialer{},
|
|
||||||
},
|
|
||||||
|
|
||||||
{
|
{
|
||||||
name: "To unique vpn route with custom dialer via physical interface",
|
name: "To unique vpn route with custom dialer via physical interface",
|
||||||
destination: "172.16.0.2:53",
|
destination: "172.16.0.2:53",
|
||||||
expectedDestPrefix: "172.16.0.2/32",
|
expectedDestPrefix: "172.16.0.2/32",
|
||||||
expectedInterface: expectedExtInt,
|
expectedInterface: expectedExternalInt,
|
||||||
dialer: nbnet.NewDialer(),
|
dialer: nbnet.NewDialer(),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -93,7 +87,7 @@ var testCases = []testCase{
|
|||||||
expectedSourceIP: "100.64.0.1",
|
expectedSourceIP: "100.64.0.1",
|
||||||
expectedDestPrefix: "172.16.0.0/12",
|
expectedDestPrefix: "172.16.0.0/12",
|
||||||
expectedNextHop: "0.0.0.0",
|
expectedNextHop: "0.0.0.0",
|
||||||
expectedInterface: "wgtest0",
|
expectedInterface: expectedVPNint,
|
||||||
dialer: &net.Dialer{},
|
dialer: &net.Dialer{},
|
||||||
},
|
},
|
||||||
|
|
||||||
@ -103,22 +97,14 @@ var testCases = []testCase{
|
|||||||
expectedSourceIP: "100.64.0.1",
|
expectedSourceIP: "100.64.0.1",
|
||||||
expectedDestPrefix: "10.10.0.0/24",
|
expectedDestPrefix: "10.10.0.0/24",
|
||||||
expectedNextHop: "0.0.0.0",
|
expectedNextHop: "0.0.0.0",
|
||||||
expectedInterface: "wgtest0",
|
expectedInterface: expectedVPNint,
|
||||||
dialer: &net.Dialer{},
|
|
||||||
},
|
|
||||||
|
|
||||||
{
|
|
||||||
name: "To more specific route (local) without custom dialer via physical interface",
|
|
||||||
destination: "127.0.10.2:53",
|
|
||||||
expectedSourceIP: "127.0.0.1",
|
|
||||||
expectedDestPrefix: "127.0.0.0/8",
|
|
||||||
expectedNextHop: "0.0.0.0",
|
|
||||||
expectedInterface: "Loopback Pseudo-Interface 1",
|
|
||||||
dialer: &net.Dialer{},
|
dialer: &net.Dialer{},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRouting(t *testing.T) {
|
func TestRouting(t *testing.T) {
|
||||||
|
log.SetLevel(log.DebugLevel)
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
setupTestEnv(t)
|
setupTestEnv(t)
|
||||||
@ -129,7 +115,7 @@ func TestRouting(t *testing.T) {
|
|||||||
require.NoError(t, err, "Failed to fetch interface IP")
|
require.NoError(t, err, "Failed to fetch interface IP")
|
||||||
|
|
||||||
output := testRoute(t, tc.destination, tc.dialer)
|
output := testRoute(t, tc.destination, tc.dialer)
|
||||||
if tc.expectedInterface == expectedExtInt {
|
if tc.expectedInterface == expectedExternalInt {
|
||||||
verifyOutput(t, output, ip, tc.expectedDestPrefix, route.NextHop, route.InterfaceAlias)
|
verifyOutput(t, output, ip, tc.expectedDestPrefix, route.NextHop, route.InterfaceAlias)
|
||||||
} else {
|
} else {
|
||||||
verifyOutput(t, output, tc.expectedSourceIP, tc.expectedDestPrefix, tc.expectedNextHop, tc.expectedInterface)
|
verifyOutput(t, output, tc.expectedSourceIP, tc.expectedDestPrefix, tc.expectedNextHop, tc.expectedInterface)
|
||||||
@ -242,19 +228,23 @@ func setupDummyInterfacesAndRoutes(t *testing.T) {
|
|||||||
func addDummyRoute(t *testing.T, dstCIDR string) {
|
func addDummyRoute(t *testing.T, dstCIDR string) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
script := fmt.Sprintf(`New-NetRoute -DestinationPrefix "%s" -InterfaceIndex 1 -PolicyStore ActiveStore`, dstCIDR)
|
prefix, err := netip.ParsePrefix(dstCIDR)
|
||||||
|
|
||||||
output, err := exec.Command("powershell", "-Command", script).CombinedOutput()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Logf("Failed to add dummy route: %v\nOutput: %s", err, output)
|
t.Fatalf("Failed to parse destination CIDR %s: %v", dstCIDR, err)
|
||||||
t.FailNow()
|
}
|
||||||
|
|
||||||
|
nexthop := Nexthop{
|
||||||
|
Intf: &net.Interface{Index: 1},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = addRoute(prefix, nexthop); err != nil {
|
||||||
|
t.Fatalf("Failed to add dummy route: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
script = fmt.Sprintf(`Remove-NetRoute -DestinationPrefix "%s" -InterfaceIndex 1 -Confirm:$false`, dstCIDR)
|
err := deleteRoute(prefix, nexthop)
|
||||||
output, err := exec.Command("powershell", "-Command", script).CombinedOutput()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Logf("Failed to remove dummy route: %v\nOutput: %s", err, output)
|
t.Logf("Failed to remove dummy route: %v", err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -3,11 +3,11 @@ package server
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -19,81 +19,32 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (
|
|||||||
s.mutex.Lock()
|
s.mutex.Lock()
|
||||||
defer s.mutex.Unlock()
|
defer s.mutex.Unlock()
|
||||||
|
|
||||||
if s.connectClient == nil {
|
tracer, engine, err := s.getPacketTracer()
|
||||||
return nil, fmt.Errorf("connect client not initialized")
|
if err != nil {
|
||||||
}
|
return nil, err
|
||||||
engine := s.connectClient.Engine()
|
|
||||||
if engine == nil {
|
|
||||||
return nil, fmt.Errorf("engine not initialized")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fwManager := engine.GetFirewallManager()
|
srcAddr, err := s.parseAddress(req.GetSourceIp(), engine)
|
||||||
if fwManager == nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("firewall manager not initialized")
|
return nil, fmt.Errorf("invalid source IP address: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
tracer, ok := fwManager.(packetTracer)
|
dstAddr, err := s.parseAddress(req.GetDestinationIp(), engine)
|
||||||
if !ok {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("firewall manager does not support packet tracing")
|
return nil, fmt.Errorf("invalid destination IP address: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
srcIP := net.ParseIP(req.GetSourceIp())
|
protocol, err := s.parseProtocol(req.GetProtocol())
|
||||||
if req.GetSourceIp() == "self" {
|
if err != nil {
|
||||||
srcIP = engine.GetWgAddr()
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
srcAddr, ok := netip.AddrFromSlice(srcIP)
|
direction, err := s.parseDirection(req.GetDirection())
|
||||||
if !ok {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid source IP address")
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
dstIP := net.ParseIP(req.GetDestinationIp())
|
tcpState := s.parseTCPFlags(req.GetTcpFlags())
|
||||||
if req.GetDestinationIp() == "self" {
|
|
||||||
dstIP = engine.GetWgAddr()
|
|
||||||
}
|
|
||||||
|
|
||||||
dstAddr, ok := netip.AddrFromSlice(dstIP)
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("invalid source IP address")
|
|
||||||
}
|
|
||||||
|
|
||||||
if srcIP == nil || dstIP == nil {
|
|
||||||
return nil, fmt.Errorf("invalid IP address")
|
|
||||||
}
|
|
||||||
|
|
||||||
var tcpState *uspfilter.TCPState
|
|
||||||
if flags := req.GetTcpFlags(); flags != nil {
|
|
||||||
tcpState = &uspfilter.TCPState{
|
|
||||||
SYN: flags.GetSyn(),
|
|
||||||
ACK: flags.GetAck(),
|
|
||||||
FIN: flags.GetFin(),
|
|
||||||
RST: flags.GetRst(),
|
|
||||||
PSH: flags.GetPsh(),
|
|
||||||
URG: flags.GetUrg(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var dir fw.RuleDirection
|
|
||||||
switch req.GetDirection() {
|
|
||||||
case "in":
|
|
||||||
dir = fw.RuleDirectionIN
|
|
||||||
case "out":
|
|
||||||
dir = fw.RuleDirectionOUT
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("invalid direction")
|
|
||||||
}
|
|
||||||
|
|
||||||
var protocol fw.Protocol
|
|
||||||
switch req.GetProtocol() {
|
|
||||||
case "tcp":
|
|
||||||
protocol = fw.ProtocolTCP
|
|
||||||
case "udp":
|
|
||||||
protocol = fw.ProtocolUDP
|
|
||||||
case "icmp":
|
|
||||||
protocol = fw.ProtocolICMP
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("invalid protocolcol")
|
|
||||||
}
|
|
||||||
|
|
||||||
builder := &uspfilter.PacketBuilder{
|
builder := &uspfilter.PacketBuilder{
|
||||||
SrcIP: srcAddr,
|
SrcIP: srcAddr,
|
||||||
@ -101,16 +52,96 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (
|
|||||||
Protocol: protocol,
|
Protocol: protocol,
|
||||||
SrcPort: uint16(req.GetSourcePort()),
|
SrcPort: uint16(req.GetSourcePort()),
|
||||||
DstPort: uint16(req.GetDestinationPort()),
|
DstPort: uint16(req.GetDestinationPort()),
|
||||||
Direction: dir,
|
Direction: direction,
|
||||||
TCPState: tcpState,
|
TCPState: tcpState,
|
||||||
ICMPType: uint8(req.GetIcmpType()),
|
ICMPType: uint8(req.GetIcmpType()),
|
||||||
ICMPCode: uint8(req.GetIcmpCode()),
|
ICMPCode: uint8(req.GetIcmpCode()),
|
||||||
}
|
}
|
||||||
|
|
||||||
trace, err := tracer.TracePacketFromBuilder(builder)
|
trace, err := tracer.TracePacketFromBuilder(builder)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("trace packet: %w", err)
|
return nil, fmt.Errorf("trace packet: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return s.buildTraceResponse(trace), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) getPacketTracer() (packetTracer, *internal.Engine, error) {
|
||||||
|
if s.connectClient == nil {
|
||||||
|
return nil, nil, fmt.Errorf("connect client not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
engine := s.connectClient.Engine()
|
||||||
|
if engine == nil {
|
||||||
|
return nil, nil, fmt.Errorf("engine not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
fwManager := engine.GetFirewallManager()
|
||||||
|
if fwManager == nil {
|
||||||
|
return nil, nil, fmt.Errorf("firewall manager not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
tracer, ok := fwManager.(packetTracer)
|
||||||
|
if !ok {
|
||||||
|
return nil, nil, fmt.Errorf("firewall manager does not support packet tracing")
|
||||||
|
}
|
||||||
|
|
||||||
|
return tracer, engine, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) parseAddress(addr string, engine *internal.Engine) (netip.Addr, error) {
|
||||||
|
if addr == "self" {
|
||||||
|
return engine.GetWgAddr(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
a, err := netip.ParseAddr(addr)
|
||||||
|
if err != nil {
|
||||||
|
return netip.Addr{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return a.Unmap(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) parseProtocol(protocol string) (fw.Protocol, error) {
|
||||||
|
switch protocol {
|
||||||
|
case "tcp":
|
||||||
|
return fw.ProtocolTCP, nil
|
||||||
|
case "udp":
|
||||||
|
return fw.ProtocolUDP, nil
|
||||||
|
case "icmp":
|
||||||
|
return fw.ProtocolICMP, nil
|
||||||
|
default:
|
||||||
|
return "", fmt.Errorf("invalid protocol")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) parseDirection(direction string) (fw.RuleDirection, error) {
|
||||||
|
switch direction {
|
||||||
|
case "in":
|
||||||
|
return fw.RuleDirectionIN, nil
|
||||||
|
case "out":
|
||||||
|
return fw.RuleDirectionOUT, nil
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("invalid direction")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) parseTCPFlags(flags *proto.TCPFlags) *uspfilter.TCPState {
|
||||||
|
if flags == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &uspfilter.TCPState{
|
||||||
|
SYN: flags.GetSyn(),
|
||||||
|
ACK: flags.GetAck(),
|
||||||
|
FIN: flags.GetFin(),
|
||||||
|
RST: flags.GetRst(),
|
||||||
|
PSH: flags.GetPsh(),
|
||||||
|
URG: flags.GetUrg(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) buildTraceResponse(trace *uspfilter.PacketTrace) *proto.TracePacketResponse {
|
||||||
resp := &proto.TracePacketResponse{}
|
resp := &proto.TracePacketResponse{}
|
||||||
|
|
||||||
for _, result := range trace.Results {
|
for _, result := range trace.Results {
|
||||||
@ -119,10 +150,12 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (
|
|||||||
Message: result.Message,
|
Message: result.Message,
|
||||||
Allowed: result.Allowed,
|
Allowed: result.Allowed,
|
||||||
}
|
}
|
||||||
|
|
||||||
if result.ForwarderAction != nil {
|
if result.ForwarderAction != nil {
|
||||||
details := fmt.Sprintf("%s to %s", result.ForwarderAction.Action, result.ForwarderAction.RemoteAddr)
|
details := fmt.Sprintf("%s to %s", result.ForwarderAction.Action, result.ForwarderAction.RemoteAddr)
|
||||||
stage.ForwardingDetails = &details
|
stage.ForwardingDetails = &details
|
||||||
}
|
}
|
||||||
|
|
||||||
resp.Stages = append(resp.Stages, stage)
|
resp.Stages = append(resp.Stages, stage)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -130,5 +163,5 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (
|
|||||||
resp.FinalDisposition = trace.Results[len(trace.Results)-1].Allowed
|
resp.FinalDisposition = trace.Results[len(trace.Results)-1].Allowed
|
||||||
}
|
}
|
||||||
|
|
||||||
return resp, nil
|
return resp
|
||||||
}
|
}
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
package net
|
package net
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"math/big"
|
"math/big"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
@ -54,11 +56,13 @@ func GenerateConnID() ConnectionID {
|
|||||||
return ConnectionID(uuid.NewString())
|
return ConnectionID(uuid.NewString())
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetLastIPFromNetwork(network *net.IPNet, fromEnd int) net.IP {
|
func GetLastIPFromNetwork(network netip.Prefix, fromEnd int) (netip.Addr, error) {
|
||||||
// Calculate the last IP in the CIDR range
|
|
||||||
var endIP net.IP
|
var endIP net.IP
|
||||||
for i := 0; i < len(network.IP); i++ {
|
addr := network.Addr().AsSlice()
|
||||||
endIP = append(endIP, network.IP[i]|^network.Mask[i])
|
mask := net.CIDRMask(network.Bits(), len(addr)*8)
|
||||||
|
|
||||||
|
for i := 0; i < len(addr); i++ {
|
||||||
|
endIP = append(endIP, addr[i]|^mask[i])
|
||||||
}
|
}
|
||||||
|
|
||||||
// convert to big.Int
|
// convert to big.Int
|
||||||
@ -70,5 +74,10 @@ func GetLastIPFromNetwork(network *net.IPNet, fromEnd int) net.IP {
|
|||||||
resultInt := big.NewInt(0)
|
resultInt := big.NewInt(0)
|
||||||
resultInt.Sub(endInt, fromEndBig)
|
resultInt.Sub(endInt, fromEndBig)
|
||||||
|
|
||||||
return resultInt.Bytes()
|
ip, ok := netip.AddrFromSlice(resultInt.Bytes())
|
||||||
|
if !ok {
|
||||||
|
return netip.Addr{}, fmt.Errorf("invalid IP address from network %s", network)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ip.Unmap(), nil
|
||||||
}
|
}
|
||||||
|
94
util/net/net_test.go
Normal file
94
util/net/net_test.go
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
package net
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetLastIPFromNetwork(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
network string
|
||||||
|
fromEnd int
|
||||||
|
expected string
|
||||||
|
expectErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "IPv4 /24 network - last IP (fromEnd=0)",
|
||||||
|
network: "192.168.1.0/24",
|
||||||
|
fromEnd: 0,
|
||||||
|
expected: "192.168.1.255",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv4 /24 network - fromEnd=1",
|
||||||
|
network: "192.168.1.0/24",
|
||||||
|
fromEnd: 1,
|
||||||
|
expected: "192.168.1.254",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv4 /24 network - fromEnd=5",
|
||||||
|
network: "192.168.1.0/24",
|
||||||
|
fromEnd: 5,
|
||||||
|
expected: "192.168.1.250",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv4 /16 network - last IP",
|
||||||
|
network: "10.0.0.0/16",
|
||||||
|
fromEnd: 0,
|
||||||
|
expected: "10.0.255.255",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv4 /16 network - fromEnd=256",
|
||||||
|
network: "10.0.0.0/16",
|
||||||
|
fromEnd: 256,
|
||||||
|
expected: "10.0.254.255",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv4 /32 network - single host",
|
||||||
|
network: "192.168.1.100/32",
|
||||||
|
fromEnd: 0,
|
||||||
|
expected: "192.168.1.100",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 /64 network - last IP",
|
||||||
|
network: "2001:db8::/64",
|
||||||
|
fromEnd: 0,
|
||||||
|
expected: "2001:db8::ffff:ffff:ffff:ffff",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 /64 network - fromEnd=1",
|
||||||
|
network: "2001:db8::/64",
|
||||||
|
fromEnd: 1,
|
||||||
|
expected: "2001:db8::ffff:ffff:ffff:fffe",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 /128 network - single host",
|
||||||
|
network: "2001:db8::1/128",
|
||||||
|
fromEnd: 0,
|
||||||
|
expected: "2001:db8::1",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
network, err := netip.ParsePrefix(tt.network)
|
||||||
|
require.NoError(t, err, "Failed to parse network prefix")
|
||||||
|
|
||||||
|
result, err := GetLastIPFromNetwork(network, tt.fromEnd)
|
||||||
|
|
||||||
|
if tt.expectErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
expectedIP, err := netip.ParseAddr(tt.expected)
|
||||||
|
require.NoError(t, err, "Failed to parse expected IP")
|
||||||
|
|
||||||
|
assert.Equal(t, expectedIP, result, "IP mismatch for network %s with fromEnd=%d", tt.network, tt.fromEnd)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user