// File: netbird/iface/device_wrapper.go (Go, 100 lines, 2.3 KiB)

package iface
import (
"net"
"sync"
"golang.zx2c4.com/wireguard/tun"
)
// PacketFilter describes the firewall capabilities that can be attached to a
// DeviceWrapper through SetFilter.
type PacketFilter interface {
// DropOutgoing reports whether an outgoing packet (from the host to an
// external destination) must be dropped. packetData is the raw packet.
DropOutgoing(packetData []byte) bool
// DropIncoming reports whether an incoming packet (from an external source
// to the host) must be dropped. packetData is the raw packet.
DropIncoming(packetData []byte) bool
// AddUDPPacketHook registers hook to be called when a UDP packet travelling
// in the given direction matches.
//
// The hook receives the raw network packet data and returns a flag that
// indicates whether the matched packet should be dropped.
//
// NOTE(review): `in` presumably selects the incoming direction, ip/dPort are
// presumably the destination address and port to match, and the returned
// string is presumably the ID accepted by RemovePacketHook - confirm against
// the implementation.
AddUDPPacketHook(in bool, ip net.IP, dPort uint16, hook func(packet []byte) bool) string
// RemovePacketHook removes the hook identified by hookID.
RemovePacketHook(hookID string) error
// SetNetwork sets the network of the WireGuard interface to which filtering
// is applied.
SetNetwork(*net.IPNet)
}
// DeviceWrapper wraps a tun.Device and overrides its Read and Write methods so
// that packets can be passed through an optional PacketFilter. Every other
// tun.Device method is promoted unchanged from the embedded device.
type DeviceWrapper struct {
// Device is the wrapped tunnel device that performs the actual I/O.
tun.Device
// filter is the currently installed packet filter; while it is nil, Read and
// Write forward packets without filtering.
filter PacketFilter
// mutex guards access to filter (written by SetFilter, read by Read and Write).
mutex sync.RWMutex
}
// newDeviceWrapper returns a DeviceWrapper around device. No filter is
// installed initially, so packets pass through untouched until SetFilter is
// called.
func newDeviceWrapper(device tun.Device) *DeviceWrapper {
	wrapper := new(DeviceWrapper)
	wrapper.Device = device
	return wrapper
}
// Read reads a batch of packets from the wrapped device and, when a filter is
// installed, removes every packet for which PacketFilter.DropOutgoing returns
// true.
//
// bufs, sizes and offset are handed to the wrapped device's Read unchanged.
// On success the result n is the number of packets that survived filtering;
// they occupy bufs[:n] and sizes[:n] in their original relative order. If the
// wrapped device returns an error, Read returns (0, err).
//
// Filtering reorders the entries of bufs and sizes in place. Buffers of
// dropped packets are swapped behind the kept ones instead of being
// overwritten, so the caller's slice still contains every buffer exactly once
// and no two slots alias the same memory on the next Read.
func (d *DeviceWrapper) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
	if n, err = d.Device.Read(bufs, sizes, offset); err != nil {
		return 0, err
	}

	// Snapshot the filter so the lock is not held while packets are inspected.
	d.mutex.RLock()
	filter := d.filter
	d.mutex.RUnlock()

	if filter == nil {
		return n, nil
	}

	// Stable in-place partition: kept packets are compacted to the front in a
	// single pass, dropped ones end up in the tail.
	kept := 0
	for i := 0; i < n; i++ {
		if filter.DropOutgoing(bufs[i][offset : offset+sizes[i]]) {
			continue
		}
		if i != kept {
			bufs[kept], bufs[i] = bufs[i], bufs[kept]
			sizes[kept], sizes[i] = sizes[i], sizes[kept]
		}
		kept++
	}

	return kept, nil
}
// Write passes a batch of packets to the wrapped device, first removing every
// packet for which PacketFilter.DropIncoming returns true.
//
// bufs and offset are forwarded to the wrapped device's Write; each packet's
// payload is taken to start at buf[offset:]. When no filter is installed the
// call is forwarded unchanged.
//
// The returned count is the number of packets reported written by the wrapped
// device plus the number of packets the filter dropped, so a fully successful
// call reports len(bufs) and the caller does not mistake filtered packets for
// a short write. The error, if any, is the one returned by the wrapped device.
func (d *DeviceWrapper) Write(bufs [][]byte, offset int) (int, error) {
	// Snapshot the filter so the lock is not held while packets are inspected.
	d.mutex.RLock()
	filter := d.filter
	d.mutex.RUnlock()

	if filter == nil {
		return d.Device.Write(bufs, offset)
	}

	// Collect survivors in a fresh slice; the caller's bufs is left untouched.
	filteredBufs := make([][]byte, 0, len(bufs))
	dropped := 0
	for _, buf := range bufs {
		if filter.DropIncoming(buf[offset:]) {
			dropped++
			continue
		}
		filteredBufs = append(filteredBufs, buf)
	}

	n, err := d.Device.Write(filteredBufs, offset)
	return n + dropped, err
}
// SetFilter installs filter as the packet filter consulted by Read and Write.
// Passing nil removes the current filter. The assignment is made while
// holding the wrapper's write lock.
func (d *DeviceWrapper) SetFilter(filter PacketFilter) {
	d.mutex.Lock()
	defer d.mutex.Unlock()

	d.filter = filter
}