mirror of
https://github.com/netbirdio/netbird.git
synced 2025-02-16 18:21:24 +01:00
Feat fake dns address (#902)
Works only with userspace implementation: 1. Configure host to solve DNS requests via a fake DSN server address in the Netbird network. 2. Add to firewall catch rule for these DNS requests. 3. Resolve these DNS requests and respond by writing directly to wireguard device.
This commit is contained in:
parent
2c9583dfe1
commit
1d9feab2d9
@ -20,6 +20,8 @@ type Rule struct {
|
|||||||
dPort uint16
|
dPort uint16
|
||||||
drop bool
|
drop bool
|
||||||
comment string
|
comment string
|
||||||
|
|
||||||
|
udpHook func([]byte) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRuleID returns the rule id
|
// GetRuleID returns the rule id
|
||||||
|
@ -18,7 +18,7 @@ const layerTypeAll = 0
|
|||||||
|
|
||||||
// IFaceMapper defines subset methods of interface required for manager
|
// IFaceMapper defines subset methods of interface required for manager
|
||||||
type IFaceMapper interface {
|
type IFaceMapper interface {
|
||||||
SetFiltering(iface.PacketFilter) error
|
SetFilter(iface.PacketFilter) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Manager userspace firewall manager
|
// Manager userspace firewall manager
|
||||||
@ -64,7 +64,7 @@ func Create(iface IFaceMapper) (*Manager, error) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := iface.SetFiltering(m); err != nil {
|
if err := iface.SetFilter(m); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return m, nil
|
return m, nil
|
||||||
@ -273,6 +273,12 @@ func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket b
|
|||||||
return rule.drop
|
return rule.drop
|
||||||
}
|
}
|
||||||
case layers.LayerTypeUDP:
|
case layers.LayerTypeUDP:
|
||||||
|
// if rule has UDP hook (and if we are here we match this rule)
|
||||||
|
// we ignore rule.drop and call this hook
|
||||||
|
if rule.udpHook != nil {
|
||||||
|
return rule.udpHook(packetData)
|
||||||
|
}
|
||||||
|
|
||||||
if rule.sPort == 0 && rule.dPort == 0 {
|
if rule.sPort == 0 && rule.dPort == 0 {
|
||||||
return rule.drop
|
return rule.drop
|
||||||
}
|
}
|
||||||
@ -296,3 +302,58 @@ func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket b
|
|||||||
func (m *Manager) SetNetwork(network *net.IPNet) {
|
func (m *Manager) SetNetwork(network *net.IPNet) {
|
||||||
m.wgNetwork = network
|
m.wgNetwork = network
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||||
|
//
|
||||||
|
// Hook function returns flag which indicates should be the matched package dropped or not
|
||||||
|
func (m *Manager) AddUDPPacketHook(
|
||||||
|
in bool, ip net.IP, dPort uint16, hook func([]byte) bool,
|
||||||
|
) string {
|
||||||
|
r := Rule{
|
||||||
|
id: uuid.New().String(),
|
||||||
|
ip: ip,
|
||||||
|
protoLayer: layers.LayerTypeUDP,
|
||||||
|
dPort: dPort,
|
||||||
|
ipLayer: layers.LayerTypeIPv6,
|
||||||
|
direction: fw.RuleDirectionOUT,
|
||||||
|
comment: fmt.Sprintf("UDP Hook direction: %v, ip:%v, dport:%d", in, ip, dPort),
|
||||||
|
udpHook: hook,
|
||||||
|
}
|
||||||
|
|
||||||
|
if ip.To4() != nil {
|
||||||
|
r.ipLayer = layers.LayerTypeIPv4
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mutex.Lock()
|
||||||
|
var toUpdate []Rule
|
||||||
|
if in {
|
||||||
|
r.direction = fw.RuleDirectionIN
|
||||||
|
m.incomingRules = append([]Rule{r}, m.incomingRules...)
|
||||||
|
toUpdate = m.incomingRules
|
||||||
|
} else {
|
||||||
|
m.outgoingRules = append([]Rule{r}, m.outgoingRules...)
|
||||||
|
toUpdate = m.outgoingRules
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range toUpdate {
|
||||||
|
m.rulesIndex[toUpdate[i].id] = i
|
||||||
|
}
|
||||||
|
m.mutex.Unlock()
|
||||||
|
|
||||||
|
return r.id
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemovePacketHook removes packet hook by given ID
|
||||||
|
func (m *Manager) RemovePacketHook(hookID string) error {
|
||||||
|
for _, r := range m.incomingRules {
|
||||||
|
if r.id == hookID {
|
||||||
|
return m.DeleteRule(&r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, r := range m.outgoingRules {
|
||||||
|
if r.id == hookID {
|
||||||
|
return m.DeleteRule(&r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fmt.Errorf("hook with given id not found")
|
||||||
|
}
|
||||||
|
@ -15,19 +15,19 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type IFaceMock struct {
|
type IFaceMock struct {
|
||||||
SetFilteringFunc func(iface.PacketFilter) error
|
SetFilterFunc func(iface.PacketFilter) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *IFaceMock) SetFiltering(iface iface.PacketFilter) error {
|
func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error {
|
||||||
if i.SetFilteringFunc == nil {
|
if i.SetFilterFunc == nil {
|
||||||
return fmt.Errorf("not implemented")
|
return fmt.Errorf("not implemented")
|
||||||
}
|
}
|
||||||
return i.SetFilteringFunc(iface)
|
return i.SetFilterFunc(iface)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManagerCreate(t *testing.T) {
|
func TestManagerCreate(t *testing.T) {
|
||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilteringFunc: func(iface.PacketFilter) error { return nil },
|
SetFilterFunc: func(iface.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock)
|
m, err := Create(ifaceMock)
|
||||||
@ -42,10 +42,10 @@ func TestManagerCreate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestManagerAddFiltering(t *testing.T) {
|
func TestManagerAddFiltering(t *testing.T) {
|
||||||
isSetFilteringCalled := false
|
isSetFilterCalled := false
|
||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilteringFunc: func(iface.PacketFilter) error {
|
SetFilterFunc: func(iface.PacketFilter) error {
|
||||||
isSetFilteringCalled = true
|
isSetFilterCalled = true
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -74,15 +74,15 @@ func TestManagerAddFiltering(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !isSetFilteringCalled {
|
if !isSetFilterCalled {
|
||||||
t.Error("SetFiltering was not called")
|
t.Error("SetFilter was not called")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManagerDeleteRule(t *testing.T) {
|
func TestManagerDeleteRule(t *testing.T) {
|
||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilteringFunc: func(iface.PacketFilter) error { return nil },
|
SetFilterFunc: func(iface.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock)
|
m, err := Create(ifaceMock)
|
||||||
@ -138,9 +138,97 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAddUDPPacketHook(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
in bool
|
||||||
|
expDir fw.RuleDirection
|
||||||
|
ip net.IP
|
||||||
|
dPort uint16
|
||||||
|
hook func([]byte) bool
|
||||||
|
expectedID string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Test Outgoing UDP Packet Hook",
|
||||||
|
in: false,
|
||||||
|
expDir: fw.RuleDirectionOUT,
|
||||||
|
ip: net.IPv4(10, 168, 0, 1),
|
||||||
|
dPort: 8000,
|
||||||
|
hook: func([]byte) bool { return true },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Test Incoming UDP Packet Hook",
|
||||||
|
in: true,
|
||||||
|
expDir: fw.RuleDirectionIN,
|
||||||
|
ip: net.IPv6loopback,
|
||||||
|
dPort: 9000,
|
||||||
|
hook: func([]byte) bool { return false },
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
manager := &Manager{
|
||||||
|
incomingRules: []Rule{},
|
||||||
|
outgoingRules: []Rule{},
|
||||||
|
rulesIndex: make(map[string]int),
|
||||||
|
}
|
||||||
|
|
||||||
|
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
||||||
|
|
||||||
|
var addedRule Rule
|
||||||
|
if tt.in {
|
||||||
|
if len(manager.incomingRules) != 1 {
|
||||||
|
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
addedRule = manager.incomingRules[0]
|
||||||
|
} else {
|
||||||
|
if len(manager.outgoingRules) != 1 {
|
||||||
|
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
addedRule = manager.outgoingRules[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
if !tt.ip.Equal(addedRule.ip) {
|
||||||
|
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if tt.dPort != addedRule.dPort {
|
||||||
|
t.Errorf("expected dPort %d, got %d", tt.dPort, addedRule.dPort)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if layers.LayerTypeUDP != addedRule.protoLayer {
|
||||||
|
t.Errorf("expected protoLayer %s, got %s", layers.LayerTypeUDP, addedRule.protoLayer)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if tt.expDir != addedRule.direction {
|
||||||
|
t.Errorf("expected direction %d, got %d", tt.expDir, addedRule.direction)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if addedRule.udpHook == nil {
|
||||||
|
t.Errorf("expected udpHook to be set")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure rulesIndex is correctly updated
|
||||||
|
index, ok := manager.rulesIndex[addedRule.id]
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("expected rule to be in rulesIndex")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if index != 0 {
|
||||||
|
t.Errorf("expected rule index to be 0, got %d", index)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestManagerReset(t *testing.T) {
|
func TestManagerReset(t *testing.T) {
|
||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilteringFunc: func(iface.PacketFilter) error { return nil },
|
SetFilterFunc: func(iface.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock)
|
m, err := Create(ifaceMock)
|
||||||
@ -175,7 +263,7 @@ func TestManagerReset(t *testing.T) {
|
|||||||
|
|
||||||
func TestNotMatchByIP(t *testing.T) {
|
func TestNotMatchByIP(t *testing.T) {
|
||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilteringFunc: func(iface.PacketFilter) error { return nil },
|
SetFilterFunc: func(iface.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock)
|
m, err := Create(ifaceMock)
|
||||||
@ -239,12 +327,56 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestRemovePacketHook tests the functionality of the RemovePacketHook method
|
||||||
|
func TestRemovePacketHook(t *testing.T) {
|
||||||
|
// creating mock iface
|
||||||
|
iface := &IFaceMock{
|
||||||
|
SetFilterFunc: func(iface.PacketFilter) error { return nil },
|
||||||
|
}
|
||||||
|
|
||||||
|
// creating manager instance
|
||||||
|
manager, err := Create(iface)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create Manager: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add a UDP packet hook
|
||||||
|
hookFunc := func(data []byte) bool { return true }
|
||||||
|
hookID := manager.AddUDPPacketHook(false, net.IPv4(192, 168, 0, 1), 8080, hookFunc)
|
||||||
|
|
||||||
|
// Assert the hook is added by finding it in the manager's outgoing rules
|
||||||
|
found := false
|
||||||
|
for _, rule := range manager.outgoingRules {
|
||||||
|
if rule.id == hookID {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !found {
|
||||||
|
t.Fatalf("The hook was not added properly.")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now remove the packet hook
|
||||||
|
err = manager.RemovePacketHook(hookID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to remove hook: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assert the hook is removed by checking it in the manager's outgoing rules
|
||||||
|
for _, rule := range manager.outgoingRules {
|
||||||
|
if rule.id == hookID {
|
||||||
|
t.Fatalf("The hook was not removed properly.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestUSPFilterCreatePerformance(t *testing.T) {
|
func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||||
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
||||||
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
||||||
// just check on the local interface
|
// just check on the local interface
|
||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilteringFunc: func(iface.PacketFilter) error { return nil },
|
SetFilterFunc: func(iface.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
manager, err := Create(ifaceMock)
|
manager, err := Create(ifaceMock)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -19,7 +19,7 @@ type IFaceMapper interface {
|
|||||||
Name() string
|
Name() string
|
||||||
Address() iface.WGAddress
|
Address() iface.WGAddress
|
||||||
IsUserspaceBind() bool
|
IsUserspaceBind() bool
|
||||||
SetFiltering(iface.PacketFilter) error
|
SetFilter(iface.PacketFilter) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Manager is a ACL rules manager
|
// Manager is a ACL rules manager
|
||||||
|
@ -35,7 +35,7 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
iface := mocks.NewMockIFaceMapper(ctrl)
|
iface := mocks.NewMockIFaceMapper(ctrl)
|
||||||
iface.EXPECT().IsUserspaceBind().Return(true)
|
iface.EXPECT().IsUserspaceBind().Return(true)
|
||||||
// iface.EXPECT().Name().Return("lo")
|
// iface.EXPECT().Name().Return("lo")
|
||||||
iface.EXPECT().SetFiltering(gomock.Any())
|
iface.EXPECT().SetFilter(gomock.Any())
|
||||||
|
|
||||||
// we receive one rule from the management so for testing purposes ignore it
|
// we receive one rule from the management so for testing purposes ignore it
|
||||||
acl, err := Create(iface)
|
acl, err := Create(iface)
|
||||||
@ -311,7 +311,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
|||||||
iface := mocks.NewMockIFaceMapper(ctrl)
|
iface := mocks.NewMockIFaceMapper(ctrl)
|
||||||
iface.EXPECT().IsUserspaceBind().Return(true)
|
iface.EXPECT().IsUserspaceBind().Return(true)
|
||||||
// iface.EXPECT().Name().Return("lo")
|
// iface.EXPECT().Name().Return("lo")
|
||||||
iface.EXPECT().SetFiltering(gomock.Any())
|
iface.EXPECT().SetFilter(gomock.Any())
|
||||||
|
|
||||||
// we receive one rule from the management so for testing purposes ignore it
|
// we receive one rule from the management so for testing purposes ignore it
|
||||||
acl, err := Create(iface)
|
acl, err := Create(iface)
|
||||||
|
@ -76,16 +76,16 @@ func (mr *MockIFaceMapperMockRecorder) Name() *gomock.Call {
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockIFaceMapper)(nil).Name))
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockIFaceMapper)(nil).Name))
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetFiltering mocks base method.
|
// SetFilter mocks base method.
|
||||||
func (m *MockIFaceMapper) SetFiltering(arg0 iface.PacketFilter) error {
|
func (m *MockIFaceMapper) SetFilter(arg0 iface.PacketFilter) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "SetFiltering", arg0)
|
ret := m.ctrl.Call(m, "SetFilter", arg0)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(error)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetFiltering indicates an expected call of SetFiltering.
|
// SetFilter indicates an expected call of SetFilter.
|
||||||
func (mr *MockIFaceMapperMockRecorder) SetFiltering(arg0 interface{}) *gomock.Call {
|
func (mr *MockIFaceMapperMockRecorder) SetFilter(arg0 interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetFiltering", reflect.TypeOf((*MockIFaceMapper)(nil).SetFiltering), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetFilter", reflect.TypeOf((*MockIFaceMapper)(nil).SetFilter), arg0)
|
||||||
}
|
}
|
||||||
|
103
client/internal/dns/response_writer.go
Normal file
103
client/internal/dns/response_writer.go
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/google/gopacket"
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
|
)
|
||||||
|
|
||||||
|
type responseWriter struct {
|
||||||
|
local net.Addr
|
||||||
|
remote net.Addr
|
||||||
|
packet gopacket.Packet
|
||||||
|
device tun.Device
|
||||||
|
}
|
||||||
|
|
||||||
|
// LocalAddr returns the net.Addr of the server
|
||||||
|
func (r *responseWriter) LocalAddr() net.Addr {
|
||||||
|
return r.local
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoteAddr returns the net.Addr of the client that sent the current request.
|
||||||
|
func (r *responseWriter) RemoteAddr() net.Addr {
|
||||||
|
return r.remote
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteMsg writes a reply back to the client.
|
||||||
|
func (r *responseWriter) WriteMsg(msg *dns.Msg) error {
|
||||||
|
buff, err := msg.Pack()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = r.Write(buff)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write writes a raw buffer back to the client.
|
||||||
|
func (r *responseWriter) Write(data []byte) (int, error) {
|
||||||
|
var ip gopacket.SerializableLayer
|
||||||
|
|
||||||
|
// Get the UDP layer
|
||||||
|
udpLayer := r.packet.Layer(layers.LayerTypeUDP)
|
||||||
|
udp := udpLayer.(*layers.UDP)
|
||||||
|
|
||||||
|
// Swap the source and destination addresses for the response
|
||||||
|
udp.SrcPort, udp.DstPort = udp.DstPort, udp.SrcPort
|
||||||
|
|
||||||
|
// Check if it's an IPv4 packet
|
||||||
|
if ipv4Layer := r.packet.Layer(layers.LayerTypeIPv4); ipv4Layer != nil {
|
||||||
|
ipv4 := ipv4Layer.(*layers.IPv4)
|
||||||
|
ipv4.SrcIP, ipv4.DstIP = ipv4.DstIP, ipv4.SrcIP
|
||||||
|
ip = ipv4
|
||||||
|
} else if ipv6Layer := r.packet.Layer(layers.LayerTypeIPv6); ipv6Layer != nil {
|
||||||
|
ipv6 := ipv6Layer.(*layers.IPv6)
|
||||||
|
ipv6.SrcIP, ipv6.DstIP = ipv6.DstIP, ipv6.SrcIP
|
||||||
|
ip = ipv6
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := udp.SetNetworkLayerForChecksum(ip.(gopacket.NetworkLayer)); err != nil {
|
||||||
|
return 0, fmt.Errorf("failed to set network layer for checksum: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serialize the packet
|
||||||
|
buffer := gopacket.NewSerializeBuffer()
|
||||||
|
options := gopacket.SerializeOptions{
|
||||||
|
ComputeChecksums: true,
|
||||||
|
FixLengths: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := gopacket.Payload(data)
|
||||||
|
err := gopacket.SerializeLayers(buffer, options, ip, udp, payload)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("failed to serialize packet: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
send := buffer.Bytes()
|
||||||
|
sendBuffer := make([]byte, 40, len(send)+40)
|
||||||
|
sendBuffer = append(sendBuffer, send...)
|
||||||
|
|
||||||
|
return r.device.Write([][]byte{sendBuffer}, 40)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the connection.
|
||||||
|
func (r *responseWriter) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TsigStatus returns the status of the Tsig.
|
||||||
|
func (r *responseWriter) TsigStatus() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TsigTimersOnly sets the tsig timers only boolean.
|
||||||
|
func (r *responseWriter) TsigTimersOnly(bool) {
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hijack lets the caller take over the connection.
|
||||||
|
// After a call to Hijack(), the DNS package will not do anything with the connection.
|
||||||
|
func (r *responseWriter) Hijack() {
|
||||||
|
}
|
93
client/internal/dns/response_writer_test.go
Normal file
93
client/internal/dns/response_writer_test.go
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/google/gopacket"
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/iface/mocks"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestResponseWriterLocalAddr(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
device := mocks.NewMockDevice(ctrl)
|
||||||
|
device.EXPECT().Write(gomock.Any(), gomock.Any())
|
||||||
|
|
||||||
|
request := &dns.Msg{
|
||||||
|
Question: []dns.Question{{
|
||||||
|
Name: "google.com.",
|
||||||
|
Qtype: dns.TypeA,
|
||||||
|
Qclass: dns.TypeA,
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
replyMessage := &dns.Msg{}
|
||||||
|
replyMessage.SetReply(request)
|
||||||
|
replyMessage.RecursionAvailable = true
|
||||||
|
replyMessage.Rcode = dns.RcodeSuccess
|
||||||
|
replyMessage.Answer = []dns.RR{
|
||||||
|
&dns.A{
|
||||||
|
A: net.IPv4(8, 8, 8, 8),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
ipv4 := &layers.IPv4{
|
||||||
|
Protocol: layers.IPProtocolUDP,
|
||||||
|
SrcIP: net.IPv4(127, 0, 0, 1),
|
||||||
|
DstIP: net.IPv4(127, 0, 0, 2),
|
||||||
|
}
|
||||||
|
udp := &layers.UDP{
|
||||||
|
DstPort: 53,
|
||||||
|
SrcPort: 45223,
|
||||||
|
}
|
||||||
|
if err := udp.SetNetworkLayerForChecksum(ipv4); err != nil {
|
||||||
|
t.Error("failed to set network layer for checksum")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serialize the packet
|
||||||
|
buffer := gopacket.NewSerializeBuffer()
|
||||||
|
options := gopacket.SerializeOptions{
|
||||||
|
ComputeChecksums: true,
|
||||||
|
FixLengths: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
requestData, err := request.Pack()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("got an error while packing the request message, error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
payload := gopacket.Payload(requestData)
|
||||||
|
|
||||||
|
if err := gopacket.SerializeLayers(buffer, options, ipv4, udp, payload); err != nil {
|
||||||
|
t.Errorf("failed to serialize packet: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rw := &responseWriter{
|
||||||
|
local: &net.UDPAddr{
|
||||||
|
IP: net.IPv4(127, 0, 0, 1),
|
||||||
|
Port: 55223,
|
||||||
|
},
|
||||||
|
remote: &net.UDPAddr{
|
||||||
|
IP: net.IPv4(127, 0, 0, 1),
|
||||||
|
Port: 53,
|
||||||
|
},
|
||||||
|
packet: gopacket.NewPacket(
|
||||||
|
buffer.Bytes(),
|
||||||
|
layers.LayerTypeIPv4,
|
||||||
|
gopacket.Default,
|
||||||
|
),
|
||||||
|
device: device,
|
||||||
|
}
|
||||||
|
if err := rw.WriteMsg(replyMessage); err != nil {
|
||||||
|
t.Errorf("got an error while writing the local resolver response, error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
@ -5,12 +5,15 @@ package dns
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math/big"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"runtime"
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/gopacket"
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/mitchellh/hashstructure/v2"
|
"github.com/mitchellh/hashstructure/v2"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@ -33,6 +36,7 @@ type DefaultServer struct {
|
|||||||
ctx context.Context
|
ctx context.Context
|
||||||
ctxCancel context.CancelFunc
|
ctxCancel context.CancelFunc
|
||||||
mux sync.Mutex
|
mux sync.Mutex
|
||||||
|
fakeResolverWG sync.WaitGroup
|
||||||
server *dns.Server
|
server *dns.Server
|
||||||
dnsMux *dns.ServeMux
|
dnsMux *dns.ServeMux
|
||||||
dnsMuxMap registeredHandlerMap
|
dnsMuxMap registeredHandlerMap
|
||||||
@ -105,6 +109,25 @@ func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAdd
|
|||||||
|
|
||||||
// Start runs the listener in a go routine
|
// Start runs the listener in a go routine
|
||||||
func (s *DefaultServer) Start() {
|
func (s *DefaultServer) Start() {
|
||||||
|
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() {
|
||||||
|
s.runtimeIP = getLastIPFromNetwork(s.wgInterface.Address().Network, 1)
|
||||||
|
s.runtimePort = 53
|
||||||
|
|
||||||
|
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
|
||||||
|
s.fakeResolverWG.Add(1)
|
||||||
|
go func() {
|
||||||
|
s.setListenerStatus(true)
|
||||||
|
defer s.setListenerStatus(false)
|
||||||
|
|
||||||
|
hookID := s.filterDNSTraffic()
|
||||||
|
s.fakeResolverWG.Wait()
|
||||||
|
if err := s.wgInterface.GetFilter().RemovePacketHook(hookID); err != nil {
|
||||||
|
log.Errorf("unable to remove DNS packet hook: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if s.customAddress != nil {
|
if s.customAddress != nil {
|
||||||
s.runtimeIP = s.customAddress.Addr().String()
|
s.runtimeIP = s.customAddress.Addr().String()
|
||||||
s.runtimePort = int(s.customAddress.Port())
|
s.runtimePort = int(s.customAddress.Port())
|
||||||
@ -172,6 +195,10 @@ func (s *DefaultServer) Stop() {
|
|||||||
log.Error(err)
|
log.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() && s.listenerIsRunning {
|
||||||
|
s.fakeResolverWG.Done()
|
||||||
|
}
|
||||||
|
|
||||||
err = s.stopListener()
|
err = s.stopListener()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
@ -235,12 +262,15 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||||
// is the service should be disabled, we stop the listener
|
// is the service should be disabled, we stop the listener or fake resolver
|
||||||
// and proceed with a regular update to clean up the handlers and records
|
// and proceed with a regular update to clean up the handlers and records
|
||||||
if !update.ServiceEnable {
|
if !update.ServiceEnable {
|
||||||
err := s.stopListener()
|
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() && s.listenerIsRunning {
|
||||||
if err != nil {
|
s.fakeResolverWG.Done()
|
||||||
log.Error(err)
|
} else {
|
||||||
|
if err := s.stopListener(); err != nil {
|
||||||
|
log.Error(err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else if !s.listenerIsRunning {
|
} else if !s.listenerIsRunning {
|
||||||
s.Start()
|
s.Start()
|
||||||
@ -477,3 +507,59 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) filterDNSTraffic() string {
|
||||||
|
filter := s.wgInterface.GetFilter()
|
||||||
|
if filter == nil {
|
||||||
|
log.Error("can't set DNS filter, filter not initialized")
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
firstLayerDecoder := layers.LayerTypeIPv4
|
||||||
|
if s.wgInterface.Address().Network.IP.To4() == nil {
|
||||||
|
firstLayerDecoder = layers.LayerTypeIPv6
|
||||||
|
}
|
||||||
|
|
||||||
|
hook := func(packetData []byte) bool {
|
||||||
|
// Decode the packet
|
||||||
|
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
|
||||||
|
|
||||||
|
// Get the UDP layer
|
||||||
|
udpLayer := packet.Layer(layers.LayerTypeUDP)
|
||||||
|
udp := udpLayer.(*layers.UDP)
|
||||||
|
|
||||||
|
msg := new(dns.Msg)
|
||||||
|
if err := msg.Unpack(udp.Payload); err != nil {
|
||||||
|
log.Tracef("parse DNS request: %v", err)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
writer := responseWriter{
|
||||||
|
packet: packet,
|
||||||
|
device: s.wgInterface.GetDevice().Device,
|
||||||
|
}
|
||||||
|
go s.dnsMux.ServeDNS(&writer, msg)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getLastIPFromNetwork(network *net.IPNet, fromEnd int) string {
|
||||||
|
// Calculate the last IP in the CIDR range
|
||||||
|
var endIP net.IP
|
||||||
|
for i := 0; i < len(network.IP); i++ {
|
||||||
|
endIP = append(endIP, network.IP[i]|^network.Mask[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
// convert to big.Int
|
||||||
|
endInt := big.NewInt(0)
|
||||||
|
endInt.SetBytes(endIP)
|
||||||
|
|
||||||
|
// subtract fromEnd from the last ip
|
||||||
|
fromEndBig := big.NewInt(int64(fromEnd))
|
||||||
|
resultInt := big.NewInt(0)
|
||||||
|
resultInt.Sub(endInt, fromEndBig)
|
||||||
|
|
||||||
|
return net.IP(resultInt.Bytes()).String()
|
||||||
|
}
|
||||||
|
31
client/internal/dns/server_nonandroid_test.go
Normal file
31
client/internal/dns/server_nonandroid_test.go
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetLastIPFromNetwork(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
addr string
|
||||||
|
ip string
|
||||||
|
}{
|
||||||
|
{"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"},
|
||||||
|
{"192.168.0.0/30", "192.168.0.2"},
|
||||||
|
{"192.168.0.0/16", "192.168.255.254"},
|
||||||
|
{"192.168.0.0/24", "192.168.0.254"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
_, ipnet, err := net.ParseCIDR(tt.addr)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Error parsing CIDR: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
lastIP := getLastIPFromNetwork(ipnet, 1)
|
||||||
|
if lastIP != tt.ip {
|
||||||
|
t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -9,10 +9,9 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
)
|
)
|
||||||
@ -238,6 +237,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
dnsServer.updateSerial = testCase.initSerial
|
dnsServer.updateSerial = testCase.initSerial
|
||||||
// pretend we are running
|
// pretend we are running
|
||||||
dnsServer.listenerIsRunning = true
|
dnsServer.listenerIsRunning = true
|
||||||
|
dnsServer.fakeResolverWG.Add(1)
|
||||||
|
|
||||||
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
|
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -15,6 +15,15 @@ type PacketFilter interface {
|
|||||||
// DropIncoming filter incoming packets from external sources to host
|
// DropIncoming filter incoming packets from external sources to host
|
||||||
DropIncoming(packetData []byte) bool
|
DropIncoming(packetData []byte) bool
|
||||||
|
|
||||||
|
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||||
|
//
|
||||||
|
// Hook function returns flag which indicates should be the matched package dropped or not.
|
||||||
|
// Hook function receives raw network packet data as argument.
|
||||||
|
AddUDPPacketHook(in bool, ip net.IP, dPort uint16, hook func(packet []byte) bool) string
|
||||||
|
|
||||||
|
// RemovePacketHook removes hook by ID
|
||||||
|
RemovePacketHook(hookID string) error
|
||||||
|
|
||||||
// SetNetwork of the wireguard interface to which filtering applied
|
// SetNetwork of the wireguard interface to which filtering applied
|
||||||
SetNetwork(*net.IPNet)
|
SetNetwork(*net.IPNet)
|
||||||
}
|
}
|
||||||
@ -82,8 +91,8 @@ func (d *DeviceWrapper) Write(bufs [][]byte, offset int) (int, error) {
|
|||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetFiltering sets packet filter to device
|
// SetFilter sets packet filter to device
|
||||||
func (d *DeviceWrapper) SetFiltering(filter PacketFilter) {
|
func (d *DeviceWrapper) SetFilter(filter PacketFilter) {
|
||||||
d.mutex.Lock()
|
d.mutex.Lock()
|
||||||
d.filter = filter
|
d.filter = filter
|
||||||
d.mutex.Unlock()
|
d.mutex.Unlock()
|
||||||
|
@ -14,13 +14,6 @@ func TestDeviceWrapperRead(t *testing.T) {
|
|||||||
ctrl := gomock.NewController(t)
|
ctrl := gomock.NewController(t)
|
||||||
defer ctrl.Finish()
|
defer ctrl.Finish()
|
||||||
|
|
||||||
tun := mocks.NewMockDevice(ctrl)
|
|
||||||
filter := mocks.NewMockPacketFilter(ctrl)
|
|
||||||
|
|
||||||
mockBufs := [][]byte{{}}
|
|
||||||
mockSizes := []int{0}
|
|
||||||
mockOffset := 0
|
|
||||||
|
|
||||||
t.Run("read ICMP", func(t *testing.T) {
|
t.Run("read ICMP", func(t *testing.T) {
|
||||||
ipLayer := &layers.IPv4{
|
ipLayer := &layers.IPv4{
|
||||||
Version: 4,
|
Version: 4,
|
||||||
@ -46,6 +39,11 @@ func TestDeviceWrapperRead(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mockBufs := [][]byte{{}}
|
||||||
|
mockSizes := []int{0}
|
||||||
|
mockOffset := 0
|
||||||
|
|
||||||
|
tun := mocks.NewMockDevice(ctrl)
|
||||||
tun.EXPECT().Read(mockBufs, mockSizes, mockOffset).
|
tun.EXPECT().Read(mockBufs, mockSizes, mockOffset).
|
||||||
DoAndReturn(func(bufs [][]byte, sizes []int, offset int) (int, error) {
|
DoAndReturn(func(bufs [][]byte, sizes []int, offset int) (int, error) {
|
||||||
bufs[0] = buffer.Bytes()
|
bufs[0] = buffer.Bytes()
|
||||||
@ -95,7 +93,10 @@ func TestDeviceWrapperRead(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mockBufs := [][]byte{buffer.Bytes()}
|
||||||
|
|
||||||
mockBufs[0] = buffer.Bytes()
|
mockBufs[0] = buffer.Bytes()
|
||||||
|
tun := mocks.NewMockDevice(ctrl)
|
||||||
tun.EXPECT().Write(mockBufs, 0).Return(1, nil)
|
tun.EXPECT().Write(mockBufs, 0).Return(1, nil)
|
||||||
|
|
||||||
wrapped := newDeviceWrapper(tun)
|
wrapped := newDeviceWrapper(tun)
|
||||||
@ -138,10 +139,13 @@ func TestDeviceWrapperRead(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
mockBufs = [][]byte{}
|
mockBufs := [][]byte{}
|
||||||
|
|
||||||
|
tun := mocks.NewMockDevice(ctrl)
|
||||||
tun.EXPECT().Write(mockBufs, 0).Return(0, nil)
|
tun.EXPECT().Write(mockBufs, 0).Return(0, nil)
|
||||||
filter.EXPECT().DropOutput(gomock.Any()).Return(true)
|
|
||||||
|
filter := mocks.NewMockPacketFilter(ctrl)
|
||||||
|
filter.EXPECT().DropIncoming(gomock.Any()).Return(true)
|
||||||
|
|
||||||
wrapped := newDeviceWrapper(tun)
|
wrapped := newDeviceWrapper(tun)
|
||||||
wrapped.filter = filter
|
wrapped.filter = filter
|
||||||
@ -188,13 +192,15 @@ func TestDeviceWrapperRead(t *testing.T) {
|
|||||||
mockSizes := []int{0}
|
mockSizes := []int{0}
|
||||||
mockOffset := 0
|
mockOffset := 0
|
||||||
|
|
||||||
|
tun := mocks.NewMockDevice(ctrl)
|
||||||
tun.EXPECT().Read(mockBufs, mockSizes, mockOffset).
|
tun.EXPECT().Read(mockBufs, mockSizes, mockOffset).
|
||||||
DoAndReturn(func(bufs [][]byte, sizes []int, offset int) (int, error) {
|
DoAndReturn(func(bufs [][]byte, sizes []int, offset int) (int, error) {
|
||||||
bufs[0] = buffer.Bytes()
|
bufs[0] = buffer.Bytes()
|
||||||
sizes[0] = len(bufs[0])
|
sizes[0] = len(bufs[0])
|
||||||
return 1, nil
|
return 1, nil
|
||||||
})
|
})
|
||||||
filter.EXPECT().DropInput(gomock.Any()).Return(true)
|
filter := mocks.NewMockPacketFilter(ctrl)
|
||||||
|
filter.EXPECT().DropOutgoing(gomock.Any()).Return(true)
|
||||||
|
|
||||||
wrapped := newDeviceWrapper(tun)
|
wrapped := newDeviceWrapper(tun)
|
||||||
wrapped.filter = filter
|
wrapped.filter = filter
|
||||||
|
@ -23,6 +23,7 @@ type WGIface struct {
|
|||||||
configurer wGConfigurer
|
configurer wGConfigurer
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
userspaceBind bool
|
userspaceBind bool
|
||||||
|
filter PacketFilter
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind
|
// IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind
|
||||||
@ -120,8 +121,8 @@ func (w *WGIface) Close() error {
|
|||||||
return w.tun.Close()
|
return w.tun.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetFiltering sets packet filters for the userspace impelemntation
|
// SetFilter sets packet filters for the userspace impelemntation
|
||||||
func (w *WGIface) SetFiltering(filter PacketFilter) error {
|
func (w *WGIface) SetFilter(filter PacketFilter) error {
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
defer w.mu.Unlock()
|
defer w.mu.Unlock()
|
||||||
|
|
||||||
@ -129,7 +130,25 @@ func (w *WGIface) SetFiltering(filter PacketFilter) error {
|
|||||||
return fmt.Errorf("userspace packet filtering not handled on this device")
|
return fmt.Errorf("userspace packet filtering not handled on this device")
|
||||||
}
|
}
|
||||||
|
|
||||||
filter.SetNetwork(w.tun.address.Network)
|
w.filter = filter
|
||||||
w.tun.wrapper.SetFiltering(filter)
|
w.filter.SetNetwork(w.tun.address.Network)
|
||||||
|
|
||||||
|
w.tun.wrapper.SetFilter(filter)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetFilter returns packet filter used by interface if it uses userspace device implementation
|
||||||
|
func (w *WGIface) GetFilter() PacketFilter {
|
||||||
|
w.mu.Lock()
|
||||||
|
defer w.mu.Unlock()
|
||||||
|
|
||||||
|
return w.filter
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDevice to interact with raw device (with filtering)
|
||||||
|
func (w *WGIface) GetDevice() *DeviceWrapper {
|
||||||
|
w.mu.Lock()
|
||||||
|
defer w.mu.Unlock()
|
||||||
|
|
||||||
|
return w.tun.wrapper
|
||||||
|
}
|
||||||
|
7
iface/mocks/README.md
Normal file
7
iface/mocks/README.md
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
## Mocks
|
||||||
|
|
||||||
|
To generate (or refresh) mocks from iface package interfaces please install [mockgen](https://github.com/golang/mock).
|
||||||
|
Run this command to update PacketFilter mock:
|
||||||
|
```bash
|
||||||
|
mockgen -destination iface/mocks/filter.go -package mocks github.com/netbirdio/netbird/iface PacketFilter
|
||||||
|
```
|
@ -34,21 +34,21 @@ func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
|
|||||||
return m.recorder
|
return m.recorder
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropInput mocks base method.
|
// AddUDPPacketHook mocks base method.
|
||||||
func (m *MockPacketFilter) DropOutgoing(arg0 []byte) bool {
|
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 net.IP, arg2 uint16, arg3 func([]byte) bool) string {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "DropOutgoing", arg0)
|
ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
|
||||||
ret0, _ := ret[0].(bool)
|
ret0, _ := ret[0].(string)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropInput indicates an expected call of DropInput.
|
// AddUDPPacketHook indicates an expected call of AddUDPPacketHook.
|
||||||
func (mr *MockPacketFilterMockRecorder) DropInput(arg0 interface{}) *gomock.Call {
|
func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropOutput mocks base method.
|
// DropIncoming mocks base method.
|
||||||
func (m *MockPacketFilter) DropIncoming(arg0 []byte) bool {
|
func (m *MockPacketFilter) DropIncoming(arg0 []byte) bool {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "DropIncoming", arg0)
|
ret := m.ctrl.Call(m, "DropIncoming", arg0)
|
||||||
@ -56,12 +56,40 @@ func (m *MockPacketFilter) DropIncoming(arg0 []byte) bool {
|
|||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropOutput indicates an expected call of DropOutput.
|
// DropIncoming indicates an expected call of DropIncoming.
|
||||||
func (mr *MockPacketFilterMockRecorder) DropOutput(arg0 interface{}) *gomock.Call {
|
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DropOutgoing mocks base method.
|
||||||
|
func (m *MockPacketFilter) DropOutgoing(arg0 []byte) bool {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "DropOutgoing", arg0)
|
||||||
|
ret0, _ := ret[0].(bool)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// DropOutgoing indicates an expected call of DropOutgoing.
|
||||||
|
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemovePacketHook mocks base method.
|
||||||
|
func (m *MockPacketFilter) RemovePacketHook(arg0 string) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "RemovePacketHook", arg0)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemovePacketHook indicates an expected call of RemovePacketHook.
|
||||||
|
func (mr *MockPacketFilterMockRecorder) RemovePacketHook(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
// SetNetwork mocks base method.
|
// SetNetwork mocks base method.
|
||||||
func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) {
|
func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
87
iface/mocks/iface/mocks/filter.go
Normal file
87
iface/mocks/iface/mocks/filter.go
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
// Code generated by MockGen. DO NOT EDIT.
|
||||||
|
// Source: github.com/netbirdio/netbird/iface (interfaces: PacketFilter)
|
||||||
|
|
||||||
|
// Package mocks is a generated GoMock package.
|
||||||
|
package mocks
|
||||||
|
|
||||||
|
import (
|
||||||
|
net "net"
|
||||||
|
reflect "reflect"
|
||||||
|
|
||||||
|
gomock "github.com/golang/mock/gomock"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockPacketFilter is a mock of PacketFilter interface.
|
||||||
|
type MockPacketFilter struct {
|
||||||
|
ctrl *gomock.Controller
|
||||||
|
recorder *MockPacketFilterMockRecorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockPacketFilterMockRecorder is the mock recorder for MockPacketFilter.
|
||||||
|
type MockPacketFilterMockRecorder struct {
|
||||||
|
mock *MockPacketFilter
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockPacketFilter creates a new mock instance.
|
||||||
|
func NewMockPacketFilter(ctrl *gomock.Controller) *MockPacketFilter {
|
||||||
|
mock := &MockPacketFilter{ctrl: ctrl}
|
||||||
|
mock.recorder = &MockPacketFilterMockRecorder{mock}
|
||||||
|
return mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||||
|
func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
|
||||||
|
return m.recorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddUDPPacketHook mocks base method.
|
||||||
|
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 net.IP, arg2 uint16, arg3 func(*net.UDPAddr, []byte) bool) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddUDPPacketHook indicates an expected call of AddUDPPacketHook.
|
||||||
|
func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DropIncoming mocks base method.
|
||||||
|
func (m *MockPacketFilter) DropIncoming(arg0 []byte) bool {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "DropIncoming", arg0)
|
||||||
|
ret0, _ := ret[0].(bool)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// DropIncoming indicates an expected call of DropIncoming.
|
||||||
|
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DropOutgoing mocks base method.
|
||||||
|
func (m *MockPacketFilter) DropOutgoing(arg0 []byte) bool {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "DropOutgoing", arg0)
|
||||||
|
ret0, _ := ret[0].(bool)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// DropOutgoing indicates an expected call of DropOutgoing.
|
||||||
|
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNetwork mocks base method.
|
||||||
|
func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
m.ctrl.Call(m, "SetNetwork", arg0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNetwork indicates an expected call of SetNetwork.
|
||||||
|
func (mr *MockPacketFilterMockRecorder) SetNetwork(arg0 interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetwork", reflect.TypeOf((*MockPacketFilter)(nil).SetNetwork), arg0)
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user