package dns

import (
	"fmt"
	"math/big"
	"net"
	"sync"

	"github.com/google/gopacket"
	"github.com/google/gopacket/layers"
	"github.com/miekg/dns"
	log "github.com/sirupsen/logrus"
)

// ServiceViaMemory is a DNS service served from inside the WireGuard
// interface: Listen installs a UDP packet hook on the interface's filter
// (rather than opening a socket) and queries captured by that hook are
// dispatched through dnsMux.
type ServiceViaMemory struct {
	// wgInterface provides the packet filter used to intercept queries and
	// the device used to write responses back.
	wgInterface       WGIface
	// dnsMux routes decoded DNS queries to the handlers registered via
	// RegisterMux.
	dnsMux            *dns.ServeMux
	// runtimeIP is the address whose UDP traffic is intercepted; set once in
	// NewServiceViaMemory.
	runtimeIP         string
	// runtimePort is the UDP port whose traffic is intercepted.
	runtimePort       int
	// udpFilterHookID is the ID returned when the hook was added; Stop uses
	// it to remove the hook.
	udpFilterHookID   string
	// listenerIsRunning records whether the hook is currently installed.
	listenerIsRunning bool
	// listenerFlagLock serializes Listen and Stop, guarding
	// listenerIsRunning and udpFilterHookID.
	listenerFlagLock  sync.Mutex
}

// NewServiceViaMemory creates a DNS service bound to wgIface. The service
// address is one below the last address of the interface's network, and the
// port is defaultPort. No traffic is intercepted until Listen is called.
func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
	mux := dns.NewServeMux()
	serviceIP := getLastIPFromNetwork(wgIface.Address().Network, 1)

	return &ServiceViaMemory{
		wgInterface: wgIface,
		dnsMux:      mux,
		runtimeIP:   serviceIP,
		runtimePort: defaultPort,
	}
}

// Listen starts answering DNS queries by installing the UDP packet hook built
// by filterDNSTraffic. Calling Listen while the hook is already installed is a
// no-op. On failure the service stays stopped and the wrapped error from
// installing the hook is returned.
func (s *ServiceViaMemory) Listen() error {
	s.listenerFlagLock.Lock()
	defer s.listenerFlagLock.Unlock()

	if s.listenerIsRunning {
		return nil
	}

	hookID, err := s.filterDNSTraffic()
	if err != nil {
		return fmt.Errorf("filter dns traffic: %w", err)
	}
	// Record the hook only once it was actually installed so Stop removes
	// the right one.
	s.udpFilterHookID = hookID
	s.listenerIsRunning = true

	log.Debugf("dns service listening on: %s", s.RuntimeIP())
	return nil
}

// Stop removes the packet hook installed by Listen. It does nothing when the
// listener is not running. A failure to remove the hook is logged, and the
// service is marked as stopped either way.
func (s *ServiceViaMemory) Stop() {
	s.listenerFlagLock.Lock()
	defer s.listenerFlagLock.Unlock()

	if s.listenerIsRunning {
		filter := s.wgInterface.GetFilter()
		if removeErr := filter.RemovePacketHook(s.udpFilterHookID); removeErr != nil {
			log.Errorf("unable to remove DNS packet hook: %s", removeErr)
		}
		s.listenerIsRunning = false
	}
}

// RegisterMux routes queries matching pattern to handler by delegating to the
// service's dns.ServeMux.
func (s *ServiceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
	mux := s.dnsMux
	mux.Handle(pattern, handler)
}

// DeregisterMux drops the handler registered for pattern from the service's
// dns.ServeMux.
func (s *ServiceViaMemory) DeregisterMux(pattern string) {
	mux := s.dnsMux
	mux.HandleRemove(pattern)
}

// RuntimePort reports the UDP port whose traffic the service intercepts.
func (s *ServiceViaMemory) RuntimePort() int {
	port := s.runtimePort
	return port
}

// RuntimeIP reports, in textual form, the address whose UDP traffic the
// service intercepts.
func (s *ServiceViaMemory) RuntimeIP() string {
	addr := s.runtimeIP
	return addr
}

// filterDNSTraffic installs a UDP packet hook on the interface's filter for
// traffic to runtimeIP:runtimePort. Each captured packet is decoded; if its
// UDP payload parses as a DNS message, it is served through dnsMux in a new
// goroutine, with responses written via the interface's device.
//
// It returns the hook ID produced by the filter (stored by Listen and passed
// to RemovePacketHook by Stop), or an error when the filter is not set up.
func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
	filter := s.wgInterface.GetFilter()
	if filter == nil {
		return "", fmt.Errorf("can't set DNS filter, filter not initialized")
	}

	// Packets are decoded starting at the IP header, so pick the IPv4 or
	// IPv6 decoder according to the interface's address family.
	firstLayerDecoder := layers.LayerTypeIPv4
	if s.wgInterface.Address().Network.IP.To4() == nil {
		firstLayerDecoder = layers.LayerTypeIPv6
	}

	hook := func(packetData []byte) bool {
		// gopacket.Default does not set NoCopy, so the decoded packet owns a
		// copy of packetData and remains valid in the goroutine started below
		// after this hook has returned.
		packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)

		// A truncated or malformed packet can decode without a UDP layer, in
		// which case Layer returns nil and an unchecked type assertion would
		// panic inside the packet filter. Handle it like an unparseable DNS
		// payload instead.
		udp, ok := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
		if !ok {
			log.Tracef("parse DNS request: packet has no UDP layer")
			return true
		}

		msg := new(dns.Msg)
		if err := msg.Unpack(udp.Payload); err != nil {
			log.Tracef("parse DNS request: %v", err)
			return true
		}

		writer := responseWriter{
			packet: packet,
			device: s.wgInterface.GetDevice().Device,
		}
		// Serve asynchronously so the hook returns without waiting on the
		// registered DNS handler.
		go s.dnsMux.ServeDNS(&writer, msg)
		return true
	}

	// NOTE(review): the meaning of the leading `false` argument is defined by
	// the filter implementation, which is not visible here — confirm before
	// changing it.
	return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook), nil
}

// getLastIPFromNetwork returns, in textual form, the address fromEnd steps
// below the last address of network (the address with every host bit set).
// For example, fromEnd == 1 on 10.0.0.0/24 yields "10.0.0.254".
//
// An IPv4 network whose IP is held in the 16-byte form but whose mask is
// 4 bytes long is normalized to the 4-byte form first, so address and mask
// line up byte-for-byte. The result always has the full address width, so
// networks with leading zero bytes (0.0.0.0/24, ::/120) format correctly.
// If the requested address falls outside the address space (e.g. fromEnd is
// larger than the last address), an empty string is returned.
func getLastIPFromNetwork(network *net.IPNet, fromEnd int) string {
	ip := network.IP
	if len(network.Mask) == net.IPv4len {
		if ip4 := ip.To4(); ip4 != nil {
			ip = ip4
		}
	}

	// Set every host bit to obtain the last address in the range.
	endIP := make(net.IP, len(ip))
	for i := range ip {
		endIP[i] = ip[i] | ^network.Mask[i]
	}

	endInt := new(big.Int).SetBytes(endIP)
	resultInt := new(big.Int).Sub(endInt, big.NewInt(int64(fromEnd)))

	// Reject results that do not fit in an address of this family instead of
	// formatting a truncated or oversized byte slice.
	if resultInt.Sign() < 0 || resultInt.BitLen() > len(ip)*8 {
		return ""
	}

	// big.Int.Bytes drops leading zero bytes; FillBytes left-pads to the full
	// address width so net.IP formats a proper IPv4/IPv6 address.
	return net.IP(resultInt.FillBytes(make([]byte, len(ip)))).String()
}