Feat fake dns address (#902)

Works only with userspace implementation:
1. Configure host to solve DNS requests via a fake DSN server address in the Netbird network.
2. Add to firewall catch rule for these DNS requests.
3. Resolve these DNS requests and respond by writing directly to wireguard device.
This commit is contained in:
Givi Khojanashvili 2023-06-08 13:46:57 +04:00 committed by GitHub
parent 2c9583dfe1
commit 1d9feab2d9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 721 additions and 57 deletions

View File

@ -20,6 +20,8 @@ type Rule struct {
dPort uint16
drop bool
comment string
udpHook func([]byte) bool
}
// GetRuleID returns the rule id

View File

@ -18,7 +18,7 @@ const layerTypeAll = 0
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
SetFiltering(iface.PacketFilter) error
SetFilter(iface.PacketFilter) error
}
// Manager userspace firewall manager
@ -64,7 +64,7 @@ func Create(iface IFaceMapper) (*Manager, error) {
},
}
if err := iface.SetFiltering(m); err != nil {
if err := iface.SetFilter(m); err != nil {
return nil, err
}
return m, nil
@ -273,6 +273,12 @@ func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket b
return rule.drop
}
case layers.LayerTypeUDP:
// if rule has UDP hook (and if we are here we match this rule)
// we ignore rule.drop and call this hook
if rule.udpHook != nil {
return rule.udpHook(packetData)
}
if rule.sPort == 0 && rule.dPort == 0 {
return rule.drop
}
@ -296,3 +302,58 @@ func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket b
func (m *Manager) SetNetwork(network *net.IPNet) {
m.wgNetwork = network
}
// AddUDPPacketHook calls hook when UDP packet from given direction matched
//
// Hook function returns flag which indicates should be the matched package dropped or not
func (m *Manager) AddUDPPacketHook(
in bool, ip net.IP, dPort uint16, hook func([]byte) bool,
) string {
r := Rule{
id: uuid.New().String(),
ip: ip,
protoLayer: layers.LayerTypeUDP,
dPort: dPort,
ipLayer: layers.LayerTypeIPv6,
direction: fw.RuleDirectionOUT,
comment: fmt.Sprintf("UDP Hook direction: %v, ip:%v, dport:%d", in, ip, dPort),
udpHook: hook,
}
if ip.To4() != nil {
r.ipLayer = layers.LayerTypeIPv4
}
m.mutex.Lock()
var toUpdate []Rule
if in {
r.direction = fw.RuleDirectionIN
m.incomingRules = append([]Rule{r}, m.incomingRules...)
toUpdate = m.incomingRules
} else {
m.outgoingRules = append([]Rule{r}, m.outgoingRules...)
toUpdate = m.outgoingRules
}
for i := range toUpdate {
m.rulesIndex[toUpdate[i].id] = i
}
m.mutex.Unlock()
return r.id
}
// RemovePacketHook removes packet hook by given ID
func (m *Manager) RemovePacketHook(hookID string) error {
for _, r := range m.incomingRules {
if r.id == hookID {
return m.DeleteRule(&r)
}
}
for _, r := range m.outgoingRules {
if r.id == hookID {
return m.DeleteRule(&r)
}
}
return fmt.Errorf("hook with given id not found")
}

View File

@ -15,19 +15,19 @@ import (
)
type IFaceMock struct {
SetFilteringFunc func(iface.PacketFilter) error
SetFilterFunc func(iface.PacketFilter) error
}
func (i *IFaceMock) SetFiltering(iface iface.PacketFilter) error {
if i.SetFilteringFunc == nil {
func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error {
if i.SetFilterFunc == nil {
return fmt.Errorf("not implemented")
}
return i.SetFilteringFunc(iface)
return i.SetFilterFunc(iface)
}
func TestManagerCreate(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilteringFunc: func(iface.PacketFilter) error { return nil },
SetFilterFunc: func(iface.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock)
@ -42,10 +42,10 @@ func TestManagerCreate(t *testing.T) {
}
func TestManagerAddFiltering(t *testing.T) {
isSetFilteringCalled := false
isSetFilterCalled := false
ifaceMock := &IFaceMock{
SetFilteringFunc: func(iface.PacketFilter) error {
isSetFilteringCalled = true
SetFilterFunc: func(iface.PacketFilter) error {
isSetFilterCalled = true
return nil
},
}
@ -74,15 +74,15 @@ func TestManagerAddFiltering(t *testing.T) {
return
}
if !isSetFilteringCalled {
t.Error("SetFiltering was not called")
if !isSetFilterCalled {
t.Error("SetFilter was not called")
return
}
}
func TestManagerDeleteRule(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilteringFunc: func(iface.PacketFilter) error { return nil },
SetFilterFunc: func(iface.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock)
@ -138,9 +138,97 @@ func TestManagerDeleteRule(t *testing.T) {
}
}
func TestAddUDPPacketHook(t *testing.T) {
tests := []struct {
name string
in bool
expDir fw.RuleDirection
ip net.IP
dPort uint16
hook func([]byte) bool
expectedID string
}{
{
name: "Test Outgoing UDP Packet Hook",
in: false,
expDir: fw.RuleDirectionOUT,
ip: net.IPv4(10, 168, 0, 1),
dPort: 8000,
hook: func([]byte) bool { return true },
},
{
name: "Test Incoming UDP Packet Hook",
in: true,
expDir: fw.RuleDirectionIN,
ip: net.IPv6loopback,
dPort: 9000,
hook: func([]byte) bool { return false },
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
manager := &Manager{
incomingRules: []Rule{},
outgoingRules: []Rule{},
rulesIndex: make(map[string]int),
}
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
var addedRule Rule
if tt.in {
if len(manager.incomingRules) != 1 {
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
return
}
addedRule = manager.incomingRules[0]
} else {
if len(manager.outgoingRules) != 1 {
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
return
}
addedRule = manager.outgoingRules[0]
}
if !tt.ip.Equal(addedRule.ip) {
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
return
}
if tt.dPort != addedRule.dPort {
t.Errorf("expected dPort %d, got %d", tt.dPort, addedRule.dPort)
return
}
if layers.LayerTypeUDP != addedRule.protoLayer {
t.Errorf("expected protoLayer %s, got %s", layers.LayerTypeUDP, addedRule.protoLayer)
return
}
if tt.expDir != addedRule.direction {
t.Errorf("expected direction %d, got %d", tt.expDir, addedRule.direction)
return
}
if addedRule.udpHook == nil {
t.Errorf("expected udpHook to be set")
return
}
// Ensure rulesIndex is correctly updated
index, ok := manager.rulesIndex[addedRule.id]
if !ok {
t.Errorf("expected rule to be in rulesIndex")
return
}
if index != 0 {
t.Errorf("expected rule index to be 0, got %d", index)
return
}
})
}
}
func TestManagerReset(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilteringFunc: func(iface.PacketFilter) error { return nil },
SetFilterFunc: func(iface.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock)
@ -175,7 +263,7 @@ func TestManagerReset(t *testing.T) {
func TestNotMatchByIP(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilteringFunc: func(iface.PacketFilter) error { return nil },
SetFilterFunc: func(iface.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock)
@ -239,12 +327,56 @@ func TestNotMatchByIP(t *testing.T) {
}
}
// TestRemovePacketHook tests the functionality of the RemovePacketHook method
func TestRemovePacketHook(t *testing.T) {
// creating mock iface
iface := &IFaceMock{
SetFilterFunc: func(iface.PacketFilter) error { return nil },
}
// creating manager instance
manager, err := Create(iface)
if err != nil {
t.Fatalf("Failed to create Manager: %s", err)
}
// Add a UDP packet hook
hookFunc := func(data []byte) bool { return true }
hookID := manager.AddUDPPacketHook(false, net.IPv4(192, 168, 0, 1), 8080, hookFunc)
// Assert the hook is added by finding it in the manager's outgoing rules
found := false
for _, rule := range manager.outgoingRules {
if rule.id == hookID {
found = true
break
}
}
if !found {
t.Fatalf("The hook was not added properly.")
}
// Now remove the packet hook
err = manager.RemovePacketHook(hookID)
if err != nil {
t.Fatalf("Failed to remove hook: %s", err)
}
// Assert the hook is removed by checking it in the manager's outgoing rules
for _, rule := range manager.outgoingRules {
if rule.id == hookID {
t.Fatalf("The hook was not removed properly.")
}
}
}
func TestUSPFilterCreatePerformance(t *testing.T) {
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface
ifaceMock := &IFaceMock{
SetFilteringFunc: func(iface.PacketFilter) error { return nil },
SetFilterFunc: func(iface.PacketFilter) error { return nil },
}
manager, err := Create(ifaceMock)
require.NoError(t, err)

View File

@ -19,7 +19,7 @@ type IFaceMapper interface {
Name() string
Address() iface.WGAddress
IsUserspaceBind() bool
SetFiltering(iface.PacketFilter) error
SetFilter(iface.PacketFilter) error
}
// Manager is a ACL rules manager

View File

@ -35,7 +35,7 @@ func TestDefaultManager(t *testing.T) {
iface := mocks.NewMockIFaceMapper(ctrl)
iface.EXPECT().IsUserspaceBind().Return(true)
// iface.EXPECT().Name().Return("lo")
iface.EXPECT().SetFiltering(gomock.Any())
iface.EXPECT().SetFilter(gomock.Any())
// we receive one rule from the management so for testing purposes ignore it
acl, err := Create(iface)
@ -311,7 +311,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
iface := mocks.NewMockIFaceMapper(ctrl)
iface.EXPECT().IsUserspaceBind().Return(true)
// iface.EXPECT().Name().Return("lo")
iface.EXPECT().SetFiltering(gomock.Any())
iface.EXPECT().SetFilter(gomock.Any())
// we receive one rule from the management so for testing purposes ignore it
acl, err := Create(iface)

View File

@ -76,16 +76,16 @@ func (mr *MockIFaceMapperMockRecorder) Name() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockIFaceMapper)(nil).Name))
}
// SetFiltering mocks base method.
func (m *MockIFaceMapper) SetFiltering(arg0 iface.PacketFilter) error {
// SetFilter mocks base method.
func (m *MockIFaceMapper) SetFilter(arg0 iface.PacketFilter) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SetFiltering", arg0)
ret := m.ctrl.Call(m, "SetFilter", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// SetFiltering indicates an expected call of SetFiltering.
func (mr *MockIFaceMapperMockRecorder) SetFiltering(arg0 interface{}) *gomock.Call {
// SetFilter indicates an expected call of SetFilter.
func (mr *MockIFaceMapperMockRecorder) SetFilter(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetFiltering", reflect.TypeOf((*MockIFaceMapper)(nil).SetFiltering), arg0)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetFilter", reflect.TypeOf((*MockIFaceMapper)(nil).SetFilter), arg0)
}

View File

@ -0,0 +1,103 @@
package dns
import (
"fmt"
"net"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/miekg/dns"
"golang.zx2c4.com/wireguard/tun"
)
type responseWriter struct {
local net.Addr
remote net.Addr
packet gopacket.Packet
device tun.Device
}
// LocalAddr returns the net.Addr of the server
func (r *responseWriter) LocalAddr() net.Addr {
return r.local
}
// RemoteAddr returns the net.Addr of the client that sent the current request.
func (r *responseWriter) RemoteAddr() net.Addr {
return r.remote
}
// WriteMsg writes a reply back to the client.
func (r *responseWriter) WriteMsg(msg *dns.Msg) error {
buff, err := msg.Pack()
if err != nil {
return err
}
_, err = r.Write(buff)
return err
}
// Write writes a raw buffer back to the client.
func (r *responseWriter) Write(data []byte) (int, error) {
var ip gopacket.SerializableLayer
// Get the UDP layer
udpLayer := r.packet.Layer(layers.LayerTypeUDP)
udp := udpLayer.(*layers.UDP)
// Swap the source and destination addresses for the response
udp.SrcPort, udp.DstPort = udp.DstPort, udp.SrcPort
// Check if it's an IPv4 packet
if ipv4Layer := r.packet.Layer(layers.LayerTypeIPv4); ipv4Layer != nil {
ipv4 := ipv4Layer.(*layers.IPv4)
ipv4.SrcIP, ipv4.DstIP = ipv4.DstIP, ipv4.SrcIP
ip = ipv4
} else if ipv6Layer := r.packet.Layer(layers.LayerTypeIPv6); ipv6Layer != nil {
ipv6 := ipv6Layer.(*layers.IPv6)
ipv6.SrcIP, ipv6.DstIP = ipv6.DstIP, ipv6.SrcIP
ip = ipv6
}
if err := udp.SetNetworkLayerForChecksum(ip.(gopacket.NetworkLayer)); err != nil {
return 0, fmt.Errorf("failed to set network layer for checksum: %v", err)
}
// Serialize the packet
buffer := gopacket.NewSerializeBuffer()
options := gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
payload := gopacket.Payload(data)
err := gopacket.SerializeLayers(buffer, options, ip, udp, payload)
if err != nil {
return 0, fmt.Errorf("failed to serialize packet: %v", err)
}
send := buffer.Bytes()
sendBuffer := make([]byte, 40, len(send)+40)
sendBuffer = append(sendBuffer, send...)
return r.device.Write([][]byte{sendBuffer}, 40)
}
// Close closes the connection.
func (r *responseWriter) Close() error {
return nil
}
// TsigStatus returns the status of the Tsig.
func (r *responseWriter) TsigStatus() error {
return nil
}
// TsigTimersOnly sets the tsig timers only boolean.
func (r *responseWriter) TsigTimersOnly(bool) {
}
// Hijack lets the caller take over the connection.
// After a call to Hijack(), the DNS package will not do anything with the connection.
func (r *responseWriter) Hijack() {
}

View File

@ -0,0 +1,93 @@
package dns
import (
"net"
"testing"
"github.com/golang/mock/gomock"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/miekg/dns"
"github.com/netbirdio/netbird/iface/mocks"
)
func TestResponseWriterLocalAddr(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
device := mocks.NewMockDevice(ctrl)
device.EXPECT().Write(gomock.Any(), gomock.Any())
request := &dns.Msg{
Question: []dns.Question{{
Name: "google.com.",
Qtype: dns.TypeA,
Qclass: dns.TypeA,
}},
}
replyMessage := &dns.Msg{}
replyMessage.SetReply(request)
replyMessage.RecursionAvailable = true
replyMessage.Rcode = dns.RcodeSuccess
replyMessage.Answer = []dns.RR{
&dns.A{
A: net.IPv4(8, 8, 8, 8),
},
}
ipv4 := &layers.IPv4{
Protocol: layers.IPProtocolUDP,
SrcIP: net.IPv4(127, 0, 0, 1),
DstIP: net.IPv4(127, 0, 0, 2),
}
udp := &layers.UDP{
DstPort: 53,
SrcPort: 45223,
}
if err := udp.SetNetworkLayerForChecksum(ipv4); err != nil {
t.Error("failed to set network layer for checksum")
return
}
// Serialize the packet
buffer := gopacket.NewSerializeBuffer()
options := gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
requestData, err := request.Pack()
if err != nil {
t.Errorf("got an error while packing the request message, error: %v", err)
return
}
payload := gopacket.Payload(requestData)
if err := gopacket.SerializeLayers(buffer, options, ipv4, udp, payload); err != nil {
t.Errorf("failed to serialize packet: %v", err)
return
}
rw := &responseWriter{
local: &net.UDPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: 55223,
},
remote: &net.UDPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: 53,
},
packet: gopacket.NewPacket(
buffer.Bytes(),
layers.LayerTypeIPv4,
gopacket.Default,
),
device: device,
}
if err := rw.WriteMsg(replyMessage); err != nil {
t.Errorf("got an error while writing the local resolver response, error: %v", err)
return
}
}

View File

@ -5,12 +5,15 @@ package dns
import (
"context"
"fmt"
"math/big"
"net"
"net/netip"
"runtime"
"sync"
"time"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/miekg/dns"
"github.com/mitchellh/hashstructure/v2"
log "github.com/sirupsen/logrus"
@ -33,6 +36,7 @@ type DefaultServer struct {
ctx context.Context
ctxCancel context.CancelFunc
mux sync.Mutex
fakeResolverWG sync.WaitGroup
server *dns.Server
dnsMux *dns.ServeMux
dnsMuxMap registeredHandlerMap
@ -105,6 +109,25 @@ func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAdd
// Start runs the listener in a go routine
func (s *DefaultServer) Start() {
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() {
s.runtimeIP = getLastIPFromNetwork(s.wgInterface.Address().Network, 1)
s.runtimePort = 53
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
s.fakeResolverWG.Add(1)
go func() {
s.setListenerStatus(true)
defer s.setListenerStatus(false)
hookID := s.filterDNSTraffic()
s.fakeResolverWG.Wait()
if err := s.wgInterface.GetFilter().RemovePacketHook(hookID); err != nil {
log.Errorf("unable to remove DNS packet hook: %s", err)
}
}()
return
}
if s.customAddress != nil {
s.runtimeIP = s.customAddress.Addr().String()
s.runtimePort = int(s.customAddress.Port())
@ -172,6 +195,10 @@ func (s *DefaultServer) Stop() {
log.Error(err)
}
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() && s.listenerIsRunning {
s.fakeResolverWG.Done()
}
err = s.stopListener()
if err != nil {
log.Error(err)
@ -235,12 +262,15 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro
}
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
// is the service should be disabled, we stop the listener
// is the service should be disabled, we stop the listener or fake resolver
// and proceed with a regular update to clean up the handlers and records
if !update.ServiceEnable {
err := s.stopListener()
if err != nil {
log.Error(err)
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() && s.listenerIsRunning {
s.fakeResolverWG.Done()
} else {
if err := s.stopListener(); err != nil {
log.Error(err)
}
}
} else if !s.listenerIsRunning {
s.Start()
@ -477,3 +507,59 @@ func (s *DefaultServer) upstreamCallbacks(
}
return
}
func (s *DefaultServer) filterDNSTraffic() string {
filter := s.wgInterface.GetFilter()
if filter == nil {
log.Error("can't set DNS filter, filter not initialized")
return ""
}
firstLayerDecoder := layers.LayerTypeIPv4
if s.wgInterface.Address().Network.IP.To4() == nil {
firstLayerDecoder = layers.LayerTypeIPv6
}
hook := func(packetData []byte) bool {
// Decode the packet
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
// Get the UDP layer
udpLayer := packet.Layer(layers.LayerTypeUDP)
udp := udpLayer.(*layers.UDP)
msg := new(dns.Msg)
if err := msg.Unpack(udp.Payload); err != nil {
log.Tracef("parse DNS request: %v", err)
return true
}
writer := responseWriter{
packet: packet,
device: s.wgInterface.GetDevice().Device,
}
go s.dnsMux.ServeDNS(&writer, msg)
return true
}
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook)
}
func getLastIPFromNetwork(network *net.IPNet, fromEnd int) string {
// Calculate the last IP in the CIDR range
var endIP net.IP
for i := 0; i < len(network.IP); i++ {
endIP = append(endIP, network.IP[i]|^network.Mask[i])
}
// convert to big.Int
endInt := big.NewInt(0)
endInt.SetBytes(endIP)
// subtract fromEnd from the last ip
fromEndBig := big.NewInt(int64(fromEnd))
resultInt := big.NewInt(0)
resultInt.Sub(endInt, fromEndBig)
return net.IP(resultInt.Bytes()).String()
}

View File

@ -0,0 +1,31 @@
package dns
import (
"net"
"testing"
)
func TestGetLastIPFromNetwork(t *testing.T) {
tests := []struct {
addr string
ip string
}{
{"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"},
{"192.168.0.0/30", "192.168.0.2"},
{"192.168.0.0/16", "192.168.255.254"},
{"192.168.0.0/24", "192.168.0.254"},
}
for _, tt := range tests {
_, ipnet, err := net.ParseCIDR(tt.addr)
if err != nil {
t.Errorf("Error parsing CIDR: %v", err)
return
}
lastIP := getLastIPFromNetwork(ipnet, 1)
if lastIP != tt.ip {
t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP)
}
}
}

View File

@ -9,10 +9,9 @@ import (
"testing"
"time"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/miekg/dns"
"github.com/netbirdio/netbird/client/internal/stdnet"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface"
)
@ -238,6 +237,7 @@ func TestUpdateDNSServer(t *testing.T) {
dnsServer.updateSerial = testCase.initSerial
// pretend we are running
dnsServer.listenerIsRunning = true
dnsServer.fakeResolverWG.Add(1)
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
if err != nil {

View File

@ -15,6 +15,15 @@ type PacketFilter interface {
// DropIncoming filter incoming packets from external sources to host
DropIncoming(packetData []byte) bool
// AddUDPPacketHook calls hook when UDP packet from given direction matched
//
// Hook function returns flag which indicates should be the matched package dropped or not.
// Hook function receives raw network packet data as argument.
AddUDPPacketHook(in bool, ip net.IP, dPort uint16, hook func(packet []byte) bool) string
// RemovePacketHook removes hook by ID
RemovePacketHook(hookID string) error
// SetNetwork of the wireguard interface to which filtering applied
SetNetwork(*net.IPNet)
}
@ -82,8 +91,8 @@ func (d *DeviceWrapper) Write(bufs [][]byte, offset int) (int, error) {
return n, err
}
// SetFiltering sets packet filter to device
func (d *DeviceWrapper) SetFiltering(filter PacketFilter) {
// SetFilter sets packet filter to device
func (d *DeviceWrapper) SetFilter(filter PacketFilter) {
d.mutex.Lock()
d.filter = filter
d.mutex.Unlock()

View File

@ -14,13 +14,6 @@ func TestDeviceWrapperRead(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
tun := mocks.NewMockDevice(ctrl)
filter := mocks.NewMockPacketFilter(ctrl)
mockBufs := [][]byte{{}}
mockSizes := []int{0}
mockOffset := 0
t.Run("read ICMP", func(t *testing.T) {
ipLayer := &layers.IPv4{
Version: 4,
@ -46,6 +39,11 @@ func TestDeviceWrapperRead(t *testing.T) {
return
}
mockBufs := [][]byte{{}}
mockSizes := []int{0}
mockOffset := 0
tun := mocks.NewMockDevice(ctrl)
tun.EXPECT().Read(mockBufs, mockSizes, mockOffset).
DoAndReturn(func(bufs [][]byte, sizes []int, offset int) (int, error) {
bufs[0] = buffer.Bytes()
@ -95,7 +93,10 @@ func TestDeviceWrapperRead(t *testing.T) {
return
}
mockBufs := [][]byte{buffer.Bytes()}
mockBufs[0] = buffer.Bytes()
tun := mocks.NewMockDevice(ctrl)
tun.EXPECT().Write(mockBufs, 0).Return(1, nil)
wrapped := newDeviceWrapper(tun)
@ -138,10 +139,13 @@ func TestDeviceWrapperRead(t *testing.T) {
return
}
mockBufs = [][]byte{}
mockBufs := [][]byte{}
tun := mocks.NewMockDevice(ctrl)
tun.EXPECT().Write(mockBufs, 0).Return(0, nil)
filter.EXPECT().DropOutput(gomock.Any()).Return(true)
filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().DropIncoming(gomock.Any()).Return(true)
wrapped := newDeviceWrapper(tun)
wrapped.filter = filter
@ -188,13 +192,15 @@ func TestDeviceWrapperRead(t *testing.T) {
mockSizes := []int{0}
mockOffset := 0
tun := mocks.NewMockDevice(ctrl)
tun.EXPECT().Read(mockBufs, mockSizes, mockOffset).
DoAndReturn(func(bufs [][]byte, sizes []int, offset int) (int, error) {
bufs[0] = buffer.Bytes()
sizes[0] = len(bufs[0])
return 1, nil
})
filter.EXPECT().DropInput(gomock.Any()).Return(true)
filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().DropOutgoing(gomock.Any()).Return(true)
wrapped := newDeviceWrapper(tun)
wrapped.filter = filter

View File

@ -23,6 +23,7 @@ type WGIface struct {
configurer wGConfigurer
mu sync.Mutex
userspaceBind bool
filter PacketFilter
}
// IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind
@ -120,8 +121,8 @@ func (w *WGIface) Close() error {
return w.tun.Close()
}
// SetFiltering sets packet filters for the userspace impelemntation
func (w *WGIface) SetFiltering(filter PacketFilter) error {
// SetFilter sets packet filters for the userspace impelemntation
func (w *WGIface) SetFilter(filter PacketFilter) error {
w.mu.Lock()
defer w.mu.Unlock()
@ -129,7 +130,25 @@ func (w *WGIface) SetFiltering(filter PacketFilter) error {
return fmt.Errorf("userspace packet filtering not handled on this device")
}
filter.SetNetwork(w.tun.address.Network)
w.tun.wrapper.SetFiltering(filter)
w.filter = filter
w.filter.SetNetwork(w.tun.address.Network)
w.tun.wrapper.SetFilter(filter)
return nil
}
// GetFilter returns packet filter used by interface if it uses userspace device implementation
func (w *WGIface) GetFilter() PacketFilter {
w.mu.Lock()
defer w.mu.Unlock()
return w.filter
}
// GetDevice to interact with raw device (with filtering)
func (w *WGIface) GetDevice() *DeviceWrapper {
w.mu.Lock()
defer w.mu.Unlock()
return w.tun.wrapper
}

7
iface/mocks/README.md Normal file
View File

@ -0,0 +1,7 @@
## Mocks
To generate (or refresh) mocks from iface package interfaces please install [mockgen](https://github.com/golang/mock).
Run this command to update PacketFilter mock:
```bash
mockgen -destination iface/mocks/filter.go -package mocks github.com/netbirdio/netbird/iface PacketFilter
```

View File

@ -34,21 +34,21 @@ func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
return m.recorder
}
// DropInput mocks base method.
func (m *MockPacketFilter) DropOutgoing(arg0 []byte) bool {
// AddUDPPacketHook mocks base method.
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 net.IP, arg2 uint16, arg3 func([]byte) bool) string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DropOutgoing", arg0)
ret0, _ := ret[0].(bool)
ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(string)
return ret0
}
// DropInput indicates an expected call of DropInput.
func (mr *MockPacketFilterMockRecorder) DropInput(arg0 interface{}) *gomock.Call {
// AddUDPPacketHook indicates an expected call of AddUDPPacketHook.
func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
}
// DropOutput mocks base method.
// DropIncoming mocks base method.
func (m *MockPacketFilter) DropIncoming(arg0 []byte) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DropIncoming", arg0)
@ -56,12 +56,40 @@ func (m *MockPacketFilter) DropIncoming(arg0 []byte) bool {
return ret0
}
// DropOutput indicates an expected call of DropOutput.
func (mr *MockPacketFilterMockRecorder) DropOutput(arg0 interface{}) *gomock.Call {
// DropIncoming indicates an expected call of DropIncoming.
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0)
}
// DropOutgoing mocks base method.
func (m *MockPacketFilter) DropOutgoing(arg0 []byte) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DropOutgoing", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// DropOutgoing indicates an expected call of DropOutgoing.
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0)
}
// RemovePacketHook mocks base method.
func (m *MockPacketFilter) RemovePacketHook(arg0 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RemovePacketHook", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// RemovePacketHook indicates an expected call of RemovePacketHook.
func (mr *MockPacketFilterMockRecorder) RemovePacketHook(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0)
}
// SetNetwork mocks base method.
func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) {
m.ctrl.T.Helper()

View File

@ -0,0 +1,87 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/netbirdio/netbird/iface (interfaces: PacketFilter)
// Package mocks is a generated GoMock package.
package mocks
import (
net "net"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
)
// MockPacketFilter is a mock of PacketFilter interface.
type MockPacketFilter struct {
ctrl *gomock.Controller
recorder *MockPacketFilterMockRecorder
}
// MockPacketFilterMockRecorder is the mock recorder for MockPacketFilter.
type MockPacketFilterMockRecorder struct {
mock *MockPacketFilter
}
// NewMockPacketFilter creates a new mock instance.
func NewMockPacketFilter(ctrl *gomock.Controller) *MockPacketFilter {
mock := &MockPacketFilter{ctrl: ctrl}
mock.recorder = &MockPacketFilterMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
return m.recorder
}
// AddUDPPacketHook mocks base method.
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 net.IP, arg2 uint16, arg3 func(*net.UDPAddr, []byte) bool) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
}
// AddUDPPacketHook indicates an expected call of AddUDPPacketHook.
func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
}
// DropIncoming mocks base method.
func (m *MockPacketFilter) DropIncoming(arg0 []byte) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DropIncoming", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// DropIncoming indicates an expected call of DropIncoming.
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0)
}
// DropOutgoing mocks base method.
func (m *MockPacketFilter) DropOutgoing(arg0 []byte) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DropOutgoing", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// DropOutgoing indicates an expected call of DropOutgoing.
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0)
}
// SetNetwork mocks base method.
func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetNetwork", arg0)
}
// SetNetwork indicates an expected call of SetNetwork.
func (mr *MockPacketFilterMockRecorder) SetNetwork(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetwork", reflect.TypeOf((*MockPacketFilter)(nil).SetNetwork), arg0)
}