//go:build (linux && !android) || freebsd

package dns

import (
	"bytes"
	"fmt"
	"net/netip"
	"os"
	"strings"
	"time"

	log "github.com/sirupsen/logrus"

	"github.com/netbirdio/netbird/client/internal/statemanager"
)

// Layout of, and limits applied to, the NetBird-managed resolv.conf.
const (
	// fileGeneratedResolvConfContentHeader marks a resolv.conf written by NetBird.
	fileGeneratedResolvConfContentHeader = "# Generated by NetBird"
	// fileGeneratedResolvConfContentHeaderNextLine is the full preamble written at the
	// top of the managed file: the marker line, a hint naming the backup copy, and a
	// trailing blank line.
	fileGeneratedResolvConfContentHeaderNextLine = fileGeneratedResolvConfContentHeader +
		"\n# If needed you can restore the original file by copying back " +
		fileDefaultResolvConfBackupLocation + "\n\n"

	// fileDefaultResolvConfBackupLocation is where the original resolv.conf is copied
	// before it is overwritten.
	fileDefaultResolvConfBackupLocation = defaultResolvConfPath + ".original.netbird"

	// Limits enforced on the generated "search" line; they mirror the traditional
	// resolver limits described in resolv.conf(5): at most six domains and 256
	// characters in total.
	fileMaxLineCharsLimit        = 256
	fileMaxNumberOfSearchDomains = 6
)

// Resolver failover settings. updateConfig passes them to prepareOptionsWithTimeout
// when building the option lines of the managed file.
const (
	dnsFailoverTimeout  = time.Second * 4
	dnsFailoverAttempts = 1
)

// fileConfigurator manages host DNS by overwriting resolv.conf directly, keeping a
// copy of the original at fileDefaultResolvConfBackupLocation.
type fileConfigurator struct {
	// repair is created in newFileConfigurator with updateConfig as its callback and
	// is started/stopped via watchFileChanges/stopWatchFileChanges; presumably it
	// re-applies the managed file when resolv.conf changes — confirm in repair.
	repair *repair

	// originalPerms is the mode of resolv.conf recorded by backup() and used when
	// writing the managed file; it stays zero until backup() has run.
	// nbNameserverIP is the NetBird resolver address last applied; restore() hands
	// it to removeFirstNbNameserver when cleaning up the backup.
	originalPerms  os.FileMode
	nbNameserverIP string
}

// newFileConfigurator builds a file-based configurator whose repair watcher targets
// defaultResolvConfPath and rewrites it through updateConfig. It never fails; the
// error result exists to match the other configurator constructors' shape.
func newFileConfigurator() (*fileConfigurator, error) {
	configurator := new(fileConfigurator)
	configurator.repair = newRepair(defaultResolvConfPath, configurator.updateConfig)
	return configurator, nil
}

// supportCustomPort reports whether the resolver can be reached on a non-default
// port. Only a bare address is written after "nameserver" in the managed file, so
// this configurator always answers false.
func (f *fileConfigurator) supportCustomPort() bool { return false }

// applyDNSConfig makes the NetBird resolver the host resolver by rewriting resolv.conf.
//
// This configurator can only route all DNS traffic. When config.RouteAll is false it
// stops the watcher, restores the original file if a backup exists, and returns an
// error.
//
// Otherwise it:
//   - backs up resolv.conf (only if no backup exists yet);
//   - merges NetBird's search domains with those found in the backup;
//   - writes the managed file with config.ServerIP as the first nameserver;
//   - re-arms the repair watcher.
//
// A failure to parse the backup is only logged; the write still proceeds.
func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
	hasBackup := f.isBackupFileExist()

	if !config.RouteAll {
		if hasBackup {
			f.repair.stopWatchFileChanges()
			if restoreErr := f.restore(); restoreErr != nil {
				return fmt.Errorf("restoring the original resolv.conf file return err: %w", restoreErr)
			}
		}
		return fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
	}

	if !hasBackup {
		if backupErr := f.backup(); backupErr != nil {
			return fmt.Errorf("unable to backup the resolv.conf file: %w", backupErr)
		}
	}

	domains := searchDomains(config)
	f.nbNameserverIP = config.ServerIP

	originalConf, parseErr := parseBackupResolvConf()
	if parseErr != nil {
		log.Errorf("could not read original search domains from %s: %s", fileDefaultResolvConfBackupLocation, parseErr)
	}

	// Pause the watcher while we write the file ourselves.
	f.repair.stopWatchFileChanges()

	if updateErr := f.updateConfig(domains, f.nbNameserverIP, originalConf, stateManager); updateErr != nil {
		return updateErr
	}

	f.repair.watchFileChanges(domains, f.nbNameserverIP, stateManager)
	return nil
}

// updateConfig renders and writes the NetBird-managed resolv.conf.
//
// The written file contains, in order:
//   - the generated header;
//   - cfg's extra lines, passed through prepareOptionsWithTimeout with the failover
//     timeout/attempts;
//   - a "search" line merging nbSearchDomains with cfg's search domains, subject to
//     the length/count limits;
//   - nbNameserverIP followed by cfg's other nameservers.
//
// It is also the callback handed to the repair watcher in newFileConfigurator.
//
// If the write fails, the original file is restored from the backup and an error is
// returned. After a successful write an unclean-shutdown indicator is created; a
// failure there is only logged.
func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP string, cfg *resolvConf, stateManager *statemanager.Manager) error {
	// applyDNSConfig only logs a failure to parse the backup and still passes the
	// result here, so guard against a nil config instead of dereferencing it.
	// NOTE(review): assumes the downstream helpers accept nil slices (true for
	// range/len/append) — confirm for prepareOptionsWithTimeout.
	if cfg == nil {
		cfg = &resolvConf{}
	}

	searchDomainList := mergeSearchDomains(nbSearchDomains, cfg.searchDomains)
	nameServers := generateNsList(nbNameserverIP, cfg)

	options := prepareOptionsWithTimeout(cfg.others, int(dnsFailoverTimeout.Seconds()), dnsFailoverAttempts)
	buf := prepareResolvConfContent(searchDomainList, nameServers, options)

	// originalPerms is only recorded by backup(). If a backup already existed (e.g.
	// left over from a previous run), it is still zero here. os.WriteFile applies the
	// mode only when it has to create the file, so never create resolv.conf as 0000.
	perms := f.originalPerms
	if perms == 0 {
		perms = 0o644
	}

	log.Debugf("creating managed file %s", defaultResolvConfPath)
	if err := os.WriteFile(defaultResolvConfPath, buf.Bytes(), perms); err != nil {
		if restoreErr := f.restore(); restoreErr != nil {
			log.Errorf("attempt to restore default file failed with error: %s", restoreErr)
		}
		return fmt.Errorf("creating resolver file %s. Error: %w", defaultResolvConfPath, err)
	}

	log.Infof("created a NetBird managed %s file with the DNS settings. Added %d search domains. Search list: %s", defaultResolvConfPath, len(searchDomainList), searchDomainList)

	// create another backup for unclean shutdown detection right after overwriting the original resolv.conf
	if err := createUncleanShutdownIndicator(fileDefaultResolvConfBackupLocation, nbNameserverIP, stateManager); err != nil {
		log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err)
	}

	return nil
}

// restoreHostDNS stops the repair watcher and then puts the backed-up resolv.conf
// back in place. The watcher is stopped first, presumably so it does not react to
// the restore itself.
func (f *fileConfigurator) restoreHostDNS() error {
	f.repair.stopWatchFileChanges()
	if err := f.restore(); err != nil {
		return err
	}
	return nil
}

// backup records the current mode of resolv.conf in f.originalPerms and copies the
// file to fileDefaultResolvConfBackupLocation.
func (f *fileConfigurator) backup() error {
	info, statErr := os.Stat(defaultResolvConfPath)
	if statErr != nil {
		return fmt.Errorf("checking stats for %s file. Error: %w", defaultResolvConfPath, statErr)
	}
	f.originalPerms = info.Mode()

	if copyErr := copyFile(defaultResolvConfPath, fileDefaultResolvConfBackupLocation); copyErr != nil {
		return fmt.Errorf("backing up %s: %w", defaultResolvConfPath, copyErr)
	}
	return nil
}

// restore copies the backup back over resolv.conf and then deletes the backup.
//
// Before copying, it calls removeFirstNbNameserver on the backup with the last
// applied NetBird address. A failure there is only logged, so the restore still
// proceeds. The error from removing the backup is returned as-is.
func (f *fileConfigurator) restore() error {
	if cleanErr := removeFirstNbNameserver(fileDefaultResolvConfBackupLocation, f.nbNameserverIP); cleanErr != nil {
		log.Errorf("Failed to remove netbird nameserver from %s on backup restore: %s", fileDefaultResolvConfBackupLocation, cleanErr)
	}

	if copyErr := copyFile(fileDefaultResolvConfBackupLocation, defaultResolvConfPath); copyErr != nil {
		return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, copyErr)
	}

	return os.RemoveAll(fileDefaultResolvConfBackupLocation)
}

// restoreUncleanShutdownDNS decides whether to restore resolv.conf after NetBird
// exited without cleaning up.
//
// The current file is parsed. It is replaced with the unclean-shutdown copy (see
// restoreResolvConfFile) when any of these holds:
//   - it lists no nameservers;
//   - its first nameserver is not a valid IP;
//   - its first nameserver equals storedDNSAddress, i.e. it still points at
//     NetBird's now-unavailable resolver.
//
// Otherwise someone else changed the file since, and it is left alone.
// storedDNSAddress must be non-nil whenever the comparison is reached.
func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error {
	current, err := parseDefaultResolvConf()
	if err != nil {
		return fmt.Errorf("parse current resolv.conf: %w", err)
	}

	if len(current.nameServers) == 0 {
		return restoreResolvConfFile()
	}

	first := current.nameServers[0]
	firstAddr, parseErr := netip.ParseAddr(first)

	switch {
	case parseErr != nil:
		log.Errorf("restoring unclean shutdown: parse dns address %s failed: %s", first, parseErr)
		return restoreResolvConfFile()
	case firstAddr.String() == storedDNSAddress.String():
		// Both sides are compared in their canonical parsed form.
		return restoreResolvConfFile()
	}

	log.Infof("restoring unclean shutdown: first current nameserver differs from saved nameserver pre-netbird: %s (current) vs %s (stored): not restoring", firstAddr, storedDNSAddress)
	return nil
}

// isBackupFileExist reports whether the resolv.conf backup can be stat'ed. Any stat
// error (not only "does not exist") is treated as "no backup".
func (f *fileConfigurator) isBackupFileExist() bool {
	if _, statErr := os.Stat(fileDefaultResolvConfBackupLocation); statErr != nil {
		return false
	}
	return true
}

// restoreResolvConfFile copies the unclean-shutdown copy
// (fileUncleanShutdownResolvConfLocation) over resolv.conf.
func restoreResolvConfFile() error {
	log.Debugf("restoring unclean shutdown: restoring %s from %s", defaultResolvConfPath, fileUncleanShutdownResolvConfLocation)

	copyErr := copyFile(fileUncleanShutdownResolvConfLocation, defaultResolvConfPath)
	if copyErr == nil {
		return nil
	}
	return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileUncleanShutdownResolvConfLocation, copyErr)
}

// generateNsList returns the nameserver list for the managed file: nbNameserverIP
// first, followed by cfg's nameservers in their original order with any entry
// equal to nbNameserverIP removed. cfg must be non-nil.
func generateNsList(nbNameserverIP string, cfg *resolvConf) []string {
	servers := make([]string, 0, len(cfg.nameServers)+1)
	servers = append(servers, nbNameserverIP)
	for _, existing := range cfg.nameServers {
		if existing == nbNameserverIP {
			continue
		}
		servers = append(servers, existing)
	}
	return servers
}

// prepareResolvConfContent renders the managed resolv.conf.
//
// Output order:
//   - the generated header;
//   - every line of others verbatim;
//   - a single "search" line, emitted only when searchDomains is non-empty;
//   - one "nameserver" line per entry of nameServers.
//
// Every line is newline-terminated.
func prepareResolvConfContent(searchDomains, nameServers, others []string) bytes.Buffer {
	var content bytes.Buffer
	content.WriteString(fileGeneratedResolvConfContentHeaderNextLine)

	for _, line := range others {
		content.WriteString(line + "\n")
	}

	if len(searchDomains) != 0 {
		content.WriteString("search " + strings.Join(searchDomains, " ") + "\n")
	}

	for _, server := range nameServers {
		content.WriteString("nameserver " + server + "\n")
	}

	return content
}

// searchDomains collects, in order, the domains of config that should appear on the
// "search" line. Match-only and disabled entries are excluded. The result is never
// nil.
func searchDomains(config HostDNSConfig) []string {
	domains := make([]string, 0)
	for _, entry := range config.Domains {
		if !entry.MatchOnly && !entry.Disabled {
			domains = append(domains, entry.Domain)
		}
	}
	return domains
}

// mergeSearchDomains builds the combined search list. NetBird's domains come first,
// then the original file's domains. Duplicates are dropped, and entries that would
// exceed the line-length or domain-count limits are skipped (see
// validateAndFillSearchDomains).
func mergeSearchDomains(searchDomains []string, originalSearchDomains []string) []string {
	merged := make([]string, 0, len(searchDomains)+len(originalSearchDomains))

	// The line length starts with the "search" keyword; the count carries over into
	// the second pass so both lists share one budget.
	used := validateAndFillSearchDomains(len("search"), &merged, searchDomains)
	validateAndFillSearchDomains(used, &merged, originalSearchDomains)

	return merged
}

// validateAndFillSearchDomains checks if the search Domains list is not too long and if the line is not too long
// extend s slice with vs elements
// return with the number of characters in the searchDomains line
func validateAndFillSearchDomains(initialLineChars int, s *[]string, vs []string) int {
	for _, sd := range vs {
		duplicated := false
		for _, fs := range *s {
			if fs == sd {
				duplicated = true
				break
			}

		}

		if duplicated {
			continue
		}

		tmpCharsNumber := initialLineChars + 1 + len(sd)
		if tmpCharsNumber > fileMaxLineCharsLimit {
			// lets log all skipped Domains
			log.Infof("search list line is larger than %d characters. Skipping append of %s domain", fileMaxLineCharsLimit, sd)
			continue
		}

		initialLineChars = tmpCharsNumber

		if len(*s) >= fileMaxNumberOfSearchDomains {
			// lets log all skipped Domains
			log.Infof("already appended %d domains to search list. Skipping append of %s domain", fileMaxNumberOfSearchDomains, sd)
			continue
		}
		*s = append(*s, sd)
	}

	return initialLineChars
}

// copyFile copies the full contents of src to dest.
//
// If dest does not exist, it is created with src's mode. os.WriteFile only applies
// the mode when creating the file; an existing dest is truncated and keeps its own
// permissions. Errors wrap the underlying os error with %w, so callers can inspect
// them with errors.Is / errors.As (e.g. os.ErrNotExist).
func copyFile(src, dest string) error {
	stats, err := os.Stat(src)
	if err != nil {
		return fmt.Errorf("checking stats for %s file when copying it. Error: %w", src, err)
	}

	bytesRead, err := os.ReadFile(src)
	if err != nil {
		return fmt.Errorf("reading the file %s file for copy. Error: %w", src, err)
	}

	if err := os.WriteFile(dest, bytesRead, stats.Mode()); err != nil {
		return fmt.Errorf("writing the destination file %s for copy. Error: %w", dest, err)
	}
	return nil
}

// isContains reports whether every element of subList is present in list. Order and
// multiplicity are ignored, and an empty subList is trivially contained.
func isContains(subList []string, list []string) bool {
	for _, want := range subList {
		found := false
		for _, have := range list {
			if want == have {
				found = true
				break // one match is enough; no need to scan the rest of list
			}
		}
		if !found {
			return false
		}
	}
	return true
}