package iface

import (
	"encoding/hex"
	"fmt"
	"net"
	"os"
	"runtime"
	"strconv"
	"strings"
	"time"

	log "github.com/sirupsen/logrus"
	"golang.zx2c4.com/wireguard/device"
	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"

	nbnet "github.com/netbirdio/netbird/util/net"
)

// ErrAllowedIPNotFound is returned by removeAllowedIP when no matching
// allowed_ip entry was found for the requested peer in the device's IPC dump.
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")

// wgUSPConfigurer configures a userspace WireGuard device by exchanging
// key=value text with it through the device's IpcSet and IpcGet methods.
type wgUSPConfigurer struct {
	// device is the WireGuard device that every configuration call targets.
	device     *device.Device
	// deviceName is handed to openUAPI and, on Linux, used to build the
	// UAPI socket path that close removes.
	deviceName string

	// uapiListener is the listener assigned by startUAPI; close closes it
	// when it is non-nil.
	uapiListener net.Listener
}

// newWGUSPConfigurer returns a configurer bound to the given device and
// device name. It calls startUAPI before returning.
func newWGUSPConfigurer(device *device.Device, deviceName string) wgConfigurer {
	configurer := new(wgUSPConfigurer)
	configurer.device = device
	configurer.deviceName = deviceName

	configurer.startUAPI()
	return configurer
}

// configureInterface applies the device-level settings: the private key, the
// listen port and the firewall mark from getFwmark. ReplacePeers is set, so
// the request also carries replace_peers=true.
//
// It returns the key parsing error if privateKey is not a valid WireGuard
// key, otherwise whatever device.IpcSet returns.
func (c *wgUSPConfigurer) configureInterface(privateKey string, port int) error {
	log.Debugf("adding Wireguard private key")

	parsedKey, parseErr := wgtypes.ParseKey(privateKey)
	if parseErr != nil {
		return parseErr
	}

	mark := getFwmark()

	var cfg wgtypes.Config
	cfg.PrivateKey = &parsedKey
	cfg.ReplacePeers = true
	cfg.FirewallMark = &mark
	cfg.ListenPort = &port

	request := toWgUserspaceString(cfg)
	return c.device.IpcSet(request)
}

// updatePeer sends a peer configuration for peerKey with its allowed IPs
// replaced by the single network parsed from allowedIps, plus the given
// keepalive interval, endpoint and optional preshared key.
//
// allowedIps must be in CIDR notation; only the network part returned by
// net.ParseCIDR is used. Parsing errors are returned unchanged, otherwise the
// result of device.IpcSet is returned.
func (c *wgUSPConfigurer) updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
	_, allowedNet, cidrErr := net.ParseCIDR(allowedIps)
	if cidrErr != nil {
		return cidrErr
	}

	publicKey, keyErr := wgtypes.ParseKey(peerKey)
	if keyErr != nil {
		return keyErr
	}

	request := toWgUserspaceString(wgtypes.Config{
		Peers: []wgtypes.PeerConfig{{
			PublicKey:                   publicKey,
			ReplaceAllowedIPs:           true,
			AllowedIPs:                  []net.IPNet{*allowedNet},
			PersistentKeepaliveInterval: &keepAlive,
			PresharedKey:                preSharedKey,
			Endpoint:                    endpoint,
		}},
	})

	return c.device.IpcSet(request)
}

// removePeer sends a removal request for the peer identified by peerKey.
// It returns the key parsing error for an invalid key, otherwise the result
// of device.IpcSet.
func (c *wgUSPConfigurer) removePeer(peerKey string) error {
	publicKey, keyErr := wgtypes.ParseKey(peerKey)
	if keyErr != nil {
		return keyErr
	}

	request := toWgUserspaceString(wgtypes.Config{
		Peers: []wgtypes.PeerConfig{{
			PublicKey: publicKey,
			Remove:    true,
		}},
	})
	return c.device.IpcSet(request)
}

// addAllowedIP sends a peer configuration that adds the network parsed from
// allowedIP to the peer identified by peerKey, without requesting that the
// existing allowed IPs be replaced.
//
// allowedIP must be in CIDR notation; only the network part returned by
// net.ParseCIDR is used. Parsing errors are returned unchanged, otherwise the
// result of device.IpcSet is returned.
func (c *wgUSPConfigurer) addAllowedIP(peerKey string, allowedIP string) error {
	_, extraNet, cidrErr := net.ParseCIDR(allowedIP)
	if cidrErr != nil {
		return cidrErr
	}

	publicKey, keyErr := wgtypes.ParseKey(peerKey)
	if keyErr != nil {
		return keyErr
	}

	var peerCfg wgtypes.PeerConfig
	peerCfg.PublicKey = publicKey
	peerCfg.UpdateOnly = true
	peerCfg.ReplaceAllowedIPs = false
	peerCfg.AllowedIPs = []net.IPNet{*extraNet}

	request := toWgUserspaceString(wgtypes.Config{
		Peers: []wgtypes.PeerConfig{peerCfg},
	})
	return c.device.IpcSet(request)
}

// removeAllowedIP removes one allowed IP from the peer identified by peerKey.
//
// There is no dedicated "remove allowed IP" request in use here, so the
// method reads the device state with IpcGet, collects every allowed_ip entry
// listed for the peer except the one to drop, and sends the remaining set
// back with ReplaceAllowedIPs set.
//
// ip is matched against the dump both as given and, when it parses as a
// CIDR, in the canonical network form produced by net.IPNet.String. The
// canonical form is what addAllowedIP sends, so a value such as
// "10.0.0.1/24" that was accepted on add can also be removed.
//
// It returns ErrAllowedIPNotFound when no matching entry exists for the peer
// (including when the peer itself is absent from the dump), a parsing error
// for an invalid key or an unparsable allowed_ip line, or the error from
// IpcGet/IpcSet.
func (c *wgUSPConfigurer) removeAllowedIP(peerKey string, ip string) error {
	// Validate the key first so a bad argument does not cost an IPC round trip.
	peerKeyParsed, err := wgtypes.ParseKey(peerKey)
	if err != nil {
		return err
	}

	ipc, err := c.device.IpcGet()
	if err != nil {
		return err
	}

	const (
		publicKeyPrefix = "public_key="
		allowedIPPrefix = "allowed_ip="
	)

	peerLine := publicKeyPrefix + hex.EncodeToString(peerKeyParsed[:])

	// Accept the caller's spelling as well as the canonical network form.
	rawTarget := allowedIPPrefix + ip
	canonicalTarget := rawTarget
	if _, targetNet, parseErr := net.ParseCIDR(ip); parseErr == nil {
		canonicalTarget = allowedIPPrefix + targetNet.String()
	}

	peer := wgtypes.PeerConfig{
		PublicKey:         peerKeyParsed,
		UpdateOnly:        true,
		ReplaceAllowedIPs: true,
		AllowedIPs:        []net.IPNet{},
	}

	foundPeer := false
	removedAllowedIP := false
	for _, line := range strings.Split(ipc, "\n") {
		line = strings.TrimSpace(line)

		// Every public_key line starts a new peer section; we are inside the
		// wanted section only while the most recent one carried our key.
		if strings.HasPrefix(line, publicKeyPrefix) {
			foundPeer = line == peerLine
			continue
		}

		if !foundPeer || !strings.HasPrefix(line, allowedIPPrefix) {
			continue
		}

		// Drop the entry being removed; keep every other allowed IP.
		if line == rawTarget || line == canonicalTarget {
			removedAllowedIP = true
			continue
		}

		_, ipNet, err := net.ParseCIDR(strings.TrimPrefix(line, allowedIPPrefix))
		if err != nil {
			return err
		}
		peer.AllowedIPs = append(peer.AllowedIPs, *ipNet)
	}

	if !removedAllowedIP {
		return ErrAllowedIPNotFound
	}

	config := wgtypes.Config{
		Peers: []wgtypes.PeerConfig{peer},
	}
	return c.device.IpcSet(toWgUserspaceString(config))
}

// startUAPI opens the UAPI listener for the device name and serves it in the
// background so the WireGuard interface can be managed by an external tool.
//
// If the listener cannot be opened the error is logged and nothing is
// started. Otherwise a goroutine accepts connections until Accept fails,
// handing each connection to device.IpcHandle in its own goroutine.
func (c *wgUSPConfigurer) startUAPI() {
	var err error
	c.uapiListener, err = openUAPI(c.deviceName)
	if err != nil {
		log.Errorf("failed to open uapi listener: %v", err)
		return
	}

	listener := c.uapiListener
	go func() {
		for {
			conn, acceptErr := listener.Accept()
			if acceptErr != nil {
				// Accept failing ends the serving loop.
				log.Tracef("%s", acceptErr)
				return
			}
			go c.device.IpcHandle(conn)
		}
	}()
}

// close closes the UAPI listener if one was set, logging any close error.
// On Linux it then removes the device's socket file under /var/run/wireguard
// when that file exists; a failure to remove it is ignored.
func (c *wgUSPConfigurer) close() {
	if listener := c.uapiListener; listener != nil {
		if closeErr := listener.Close(); closeErr != nil {
			log.Errorf("failed to close uapi listener: %v", closeErr)
		}
	}

	if runtime.GOOS != "linux" {
		return
	}

	socketPath := "/var/run/wireguard/" + c.deviceName + ".sock"
	if _, statErr := os.Stat(socketPath); statErr != nil {
		return
	}
	// Best-effort cleanup: the removal error is deliberately discarded.
	_ = os.Remove(socketPath)
}

// getStats reads the device state with IpcGet and extracts the last handshake
// time and the transmitted/received byte counters for the peer identified by
// peerKey.
//
// It returns a zero WGStats and a wrapped error if the IPC read fails, if
// findPeerInfo cannot supply all four fields, or if any field is not a
// base-10 integer.
func (c *wgUSPConfigurer) getStats(peerKey string) (WGStats, error) {
	const (
		handshakeSecKey  = "last_handshake_time_sec"
		handshakeNsecKey = "last_handshake_time_nsec"
		txBytesKey       = "tx_bytes"
		rxBytesKey       = "rx_bytes"
	)

	dump, err := c.device.IpcGet()
	if err != nil {
		return WGStats{}, fmt.Errorf("ipc get: %w", err)
	}

	fields, err := findPeerInfo(dump, peerKey, []string{
		handshakeSecKey,
		handshakeNsecKey,
		txBytesKey,
		rxBytesKey,
	})
	if err != nil {
		return WGStats{}, fmt.Errorf("find peer info: %w", err)
	}

	// readInt parses the named field as a signed 64-bit decimal.
	readInt := func(name string) (int64, error) {
		return strconv.ParseInt(fields[name], 10, 64)
	}

	handshakeSec, err := readInt(handshakeSecKey)
	if err != nil {
		return WGStats{}, fmt.Errorf("parse handshake sec: %w", err)
	}

	handshakeNsec, err := readInt(handshakeNsecKey)
	if err != nil {
		return WGStats{}, fmt.Errorf("parse handshake nsec: %w", err)
	}

	sent, err := readInt(txBytesKey)
	if err != nil {
		return WGStats{}, fmt.Errorf("parse tx_bytes: %w", err)
	}

	received, err := readInt(rxBytesKey)
	if err != nil {
		return WGStats{}, fmt.Errorf("parse rx_bytes: %w", err)
	}

	result := WGStats{
		LastHandshake: time.Unix(handshakeSec, handshakeNsec),
		TxBytes:       sent,
		RxBytes:       received,
	}
	return result, nil
}

// findPeerInfo extracts the values of searchConfigKeys for one peer from
// ipcInput, the newline-separated key=value text passed in by getStats from
// device.IpcGet.
//
// The peer's section starts at the "public_key=<hex>" line matching peerKey
// and ends at the next public_key line. Within it, every line of the form
// "<key>=<value>" whose key is in searchConfigKeys is recorded; if a key
// occurs more than once the last value wins.
//
// Errors:
//   - a wrapped parse error if peerKey is not a valid WireGuard key;
//   - an error wrapping ErrPeerNotFound if the peer is not present in
//     ipcInput (checked first, so a missing peer is never misreported as a
//     missing key);
//   - "config key not found" if the peer is present but one of the requested
//     keys is not; the partially filled map is returned alongside it.
func findPeerInfo(ipcInput string, peerKey string, searchConfigKeys []string) (map[string]string, error) {
	peerKeyParsed, err := wgtypes.ParseKey(peerKey)
	if err != nil {
		return nil, fmt.Errorf("parse key: %w", err)
	}

	const publicKeyPrefix = "public_key="
	peerLine := publicKeyPrefix + hex.EncodeToString(peerKeyParsed[:])

	configFound := make(map[string]string, len(searchConfigKeys))
	foundPeer := false
	for _, line := range strings.Split(ipcInput, "\n") {
		line = strings.TrimSpace(line)

		if strings.HasPrefix(line, publicKeyPrefix) {
			// A public_key line after our peer's section means the next
			// peer has started, so there is nothing more to read.
			if foundPeer {
				break
			}
			foundPeer = line == peerLine
		}

		if !foundPeer {
			continue
		}

		for _, key := range searchConfigKeys {
			if strings.HasPrefix(line, key+"=") {
				configFound[key] = line[len(key)+1:]
			}
		}
	}

	// Report a missing peer before looking at individual keys: with no peer
	// section every key is trivially missing, and callers need to be able to
	// tell the two situations apart via ErrPeerNotFound.
	if !foundPeer {
		return nil, fmt.Errorf("%w: %s", ErrPeerNotFound, peerKey)
	}

	// todo: use multierr
	for _, key := range searchConfigKeys {
		if _, ok := configFound[key]; !ok {
			return configFound, fmt.Errorf("config key not found: %s", key)
		}
	}

	return configFound, nil
}

// toWgUserspaceString renders wgCfg as the key=value text that the configurer
// methods pass to device.IpcSet. Every entry is written on its own line
// terminated by "\n".
//
// Device-level settings come first (private_key, listen_port, replace_peers,
// fwmark), followed by one section per peer that begins with its public_key
// line. Keys are hex-encoded. Nil pointer fields and false flags produce no
// output, and the keepalive interval is written in whole seconds.
//
// NOTE(review): PeerConfig.UpdateOnly is not serialized here, so callers that
// set it (addAllowedIP, removeAllowedIP) do not produce an update_only line —
// confirm whether that is intentional.
func toWgUserspaceString(wgCfg wgtypes.Config) string {
	var sb strings.Builder

	if wgCfg.PrivateKey != nil {
		fmt.Fprintf(&sb, "private_key=%s\n", hex.EncodeToString(wgCfg.PrivateKey[:]))
	}

	if wgCfg.ListenPort != nil {
		fmt.Fprintf(&sb, "listen_port=%d\n", *wgCfg.ListenPort)
	}

	if wgCfg.ReplacePeers {
		sb.WriteString("replace_peers=true\n")
	}

	if wgCfg.FirewallMark != nil {
		fmt.Fprintf(&sb, "fwmark=%d\n", *wgCfg.FirewallMark)
	}

	for _, p := range wgCfg.Peers {
		fmt.Fprintf(&sb, "public_key=%s\n", hex.EncodeToString(p.PublicKey[:]))

		if p.PresharedKey != nil {
			fmt.Fprintf(&sb, "preshared_key=%s\n", hex.EncodeToString(p.PresharedKey[:]))
		}

		if p.Remove {
			// The trailing newline matters: without it the next key written
			// for this peer, or the next peer's public_key, would be glued
			// onto this line.
			sb.WriteString("remove=true\n")
		}

		if p.ReplaceAllowedIPs {
			sb.WriteString("replace_allowed_ips=true\n")
		}

		for _, aip := range p.AllowedIPs {
			fmt.Fprintf(&sb, "allowed_ip=%s\n", aip.String())
		}

		if p.Endpoint != nil {
			fmt.Fprintf(&sb, "endpoint=%s\n", p.Endpoint.String())
		}

		if p.PersistentKeepaliveInterval != nil {
			fmt.Fprintf(&sb, "persistent_keepalive_interval=%d\n", int(p.PersistentKeepaliveInterval.Seconds()))
		}
	}

	return sb.String()
}

// getFwmark returns the firewall mark to configure on the device:
// nbnet.NetbirdFwmark on Linux when custom routing is not disabled, and 0 in
// every other case.
func getFwmark() int {
	if runtime.GOOS != "linux" || nbnet.CustomRoutingDisabled() {
		return 0
	}
	return nbnet.NetbirdFwmark
}