package iface

import (
	"net"
	"sync"

	"golang.zx2c4.com/wireguard/tun"
)

// PacketFilter is the firewall hook used to decide which packets are dropped
// and to register per-packet hooks.
type PacketFilter interface {
	// DropOutgoing reports whether an outgoing packet (from the host to an
	// external destination) must be dropped.
	DropOutgoing(packetData []byte) bool

	// DropIncoming reports whether an incoming packet (from an external source
	// to the host) must be dropped.
	DropIncoming(packetData []byte) bool

	// AddUDPPacketHook registers hook to be called when a UDP packet in the
	// given direction matches.
	//
	// The hook receives the raw network packet data and returns a flag
	// indicating whether the matched packet should be dropped.
	// NOTE(review): in presumably selects the incoming direction, ip/dPort the
	// address and destination port to match, and the returned string the ID
	// accepted by RemovePacketHook; confirm against an implementation.
	AddUDPPacketHook(in bool, ip net.IP, dPort uint16, hook func(packet []byte) bool) string

	// RemovePacketHook removes the hook with the given ID.
	RemovePacketHook(hookID string) error

	// SetNetwork sets the network of the WireGuard interface to which
	// filtering is applied.
	SetNetwork(*net.IPNet)
}

// DeviceWrapper embeds a tun.Device and overrides Read and Write so that
// packets can be passed through an optional PacketFilter.
type DeviceWrapper struct {
	tun.Device
	filter PacketFilter // optional; nil disables filtering
	mutex  sync.RWMutex // guards filter
}

// newDeviceWrapper returns a DeviceWrapper around device with no packet
// filter installed.
func newDeviceWrapper(device tun.Device) *DeviceWrapper {
	wrapper := new(DeviceWrapper)
	wrapper.Device = device
	return wrapper
}

// Read reads a batch of packets from the wrapped device and removes the ones
// rejected by the filter's DropOutgoing check.
//
// It returns the number of packets kept; they occupy bufs[0..n) at the given
// offset, with their lengths in sizes[0..n). If the wrapped device returns an
// error, Read returns 0 and that error. When no filter is set, the device's
// result is returned unchanged.
//
// Kept packets are compacted by copying their bytes into the earlier buffers
// rather than by shifting slice headers, so after the call the caller's bufs
// still refers to the same distinct buffers and can safely be reused.
func (d *DeviceWrapper) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
	if n, err = d.Device.Read(bufs, sizes, offset); err != nil {
		return 0, err
	}

	// Snapshot the filter so the lock is not held while packets are inspected.
	d.mutex.RLock()
	filter := d.filter
	d.mutex.RUnlock()

	if filter == nil {
		return n, nil
	}

	// Invariant: slots [0, kept) hold accepted packets, slots [kept, i) are free.
	kept := 0
	for i := 0; i < n; i++ {
		packet := bufs[i][offset : offset+sizes[i]]
		if filter.DropOutgoing(packet) {
			continue
		}
		if kept != i {
			if end := offset + sizes[i]; end <= len(bufs[kept]) {
				// Move the payload, not the slice header: the caller's
				// buffers must stay where the caller put them.
				copy(bufs[kept][offset:end], packet)
			} else {
				// The free slot is too small for this packet; exchange the
				// two headers so every buffer still appears exactly once.
				bufs[kept], bufs[i] = bufs[i], bufs[kept]
			}
			sizes[kept] = sizes[i]
		}
		kept++
	}

	return kept, nil
}

// Write passes packets to the wrapped device, skipping those rejected by the
// filter's DropIncoming check.
//
// The returned count is the number of packets the device reports as written
// plus the number of packets dropped by the filter, so filtered packets are
// reported as handled. The error is the one returned by the wrapped device.
// When no filter is set the call is forwarded unchanged.
func (d *DeviceWrapper) Write(bufs [][]byte, offset int) (int, error) {
	// Snapshot the filter so the lock is not held while packets are inspected.
	d.mutex.RLock()
	filter := d.filter
	d.mutex.RUnlock()

	if filter == nil {
		return d.Device.Write(bufs, offset)
	}

	filteredBufs := make([][]byte, 0, len(bufs))
	dropped := 0
	for _, buf := range bufs {
		if filter.DropIncoming(buf[offset:]) {
			// Count only packets that are actually filtered out.
			dropped++
			continue
		}
		filteredBufs = append(filteredBufs, buf)
	}

	n, err := d.Device.Write(filteredBufs, offset)
	return n + dropped, err
}

// SetFilter installs filter as the packet filter used by Read and Write.
// The assignment is made while holding the write lock.
func (d *DeviceWrapper) SetFilter(filter PacketFilter) {
	d.mutex.Lock()
	defer d.mutex.Unlock()

	d.filter = filter
}