//go:build !android

package routemanager

import (
	"context"
	"fmt"
	"net"
	"net/netip"
	"sync"

	"github.com/google/nftables"
	"github.com/google/nftables/binaryutil"
	"github.com/google/nftables/expr"
	log "github.com/sirupsen/logrus"
)

// Names of the nftables objects managed by this file. The table name is used
// for both the IPv4 and the IPv6 family; the forwarding chain is registered on
// the forward hook as a filter chain and the nat chain on the postrouting hook
// as a NAT chain.
const (
	nftablesTable                  = "netbird-rt"
	nftablesRoutingForwardingChain = "netbird-rt-fwd"
	nftablesRoutingNatChain        = "netbird-rt-nat"
)

// Constants needed to create nftables rules.
//
// The *Len and *Offset values are handed to expr.Payload as the number of
// bytes to load and the byte offset, relative to the network header, of the
// source/destination address for each IP version. The exprDirection* values
// are the direction selectors understood by getPayloadDirectives.
const (
	ipv4Len                  = 4
	ipv4SrcOffset            = 12
	ipv4DestOffset           = 16
	ipv6Len                  = 16
	ipv6SrcOffset            = 8
	ipv6DestOffset           = 24
	exprDirectionSource      = "source"
	exprDirectionDestination = "destination"
)

// Presets shared by the rule builders in this file.
var (
	// zeroXor is the 4-byte all-zero XOR operand used with IPv4-sized bitwise
	// expressions (a zero XOR leaves the masked value unchanged).
	zeroXor = binaryutil.NativeEndian.PutUint32(0)

	// zeroXor6 is the 16-byte all-zero XOR operand used with IPv6-sized
	// bitwise expressions.
	zeroXor6 = append(binaryutil.NativeEndian.PutUint64(0), binaryutil.NativeEndian.PutUint64(0)...)

	// exprAllowRelatedEstablished counts and accepts packets whose conntrack
	// state has the established or related bit set, i.e. the equivalent of
	// `ct state established,related counter accept`.
	exprAllowRelatedEstablished = []expr.Any{
		// load the conntrack state into register 1
		&expr.Ct{
			Register:       1,
			SourceRegister: false,
			Key:            expr.CtKeySTATE,
		},
		// keep only the established (0x2) and related (0x4) bits; the state
		// is a host-endian 32-bit value, so the mask is built the same way
		// as the other operands instead of being spelled out byte by byte
		&expr.Bitwise{
			DestRegister:   1,
			SourceRegister: 1,
			Len:            4,
			Mask:           binaryutil.NativeEndian.PutUint32(0x6),
			Xor:            zeroXor,
		},
		// match when at least one of those bits is set. The comparison must
		// be "not equal to zero": with the default (equal) operator the rule
		// would accept exactly the packets that are NOT established/related.
		&expr.Cmp{
			Op:       expr.CmpOpNeq,
			Register: 1,
			Data:     binaryutil.NativeEndian.PutUint32(0),
		},
		&expr.Counter{},
		&expr.Verdict{
			Kind: expr.VerdictAccept,
		},
	}

	// exprCounterAccept is the common rule tail: count the packet, then accept it.
	exprCounterAccept = []expr.Any{
		&expr.Counter{},
		&expr.Verdict{
			Kind: expr.VerdictAccept,
		},
	}
)

// nftablesManager keeps the nftables connection together with the tables,
// chains and rules that the methods in this file create or look up.
//
// tableIPv4 and tableIPv6 stay nil until RestoreOrCreateContainers sets them.
// chains is indexed first by the ipv4/ipv6 family key and then by chain name.
// rules is indexed by the rule's UserData string.
// mux is held by the exported methods for their whole duration.
// ctx and stop are not referenced in this part of the file; presumably they
// tie the manager to its owner's lifetime — TODO confirm.
type nftablesManager struct {
	ctx       context.Context
	stop      context.CancelFunc
	conn      *nftables.Conn
	tableIPv4 *nftables.Table
	tableIPv6 *nftables.Table
	chains    map[string]map[string]*nftables.Chain
	rules     map[string]*nftables.Rule
	mux       sync.Mutex
}

// CleanRoutingRules queues a flush of the IPv6 and IPv4 tables (only when both
// are known) and then sends the pending batch. The outcome of the send is only
// logged at debug level; nothing is returned to the caller.
func (n *nftablesManager) CleanRoutingRules() {
	n.mux.Lock()
	defer n.mux.Unlock()

	log.Debug("flushing tables")

	haveBothTables := n.tableIPv4 != nil && n.tableIPv6 != nil
	if haveBothTables {
		// IPv6 first, then IPv4
		for _, table := range []*nftables.Table{n.tableIPv6, n.tableIPv4} {
			n.conn.FlushTable(table)
		}
	}

	// the batch is sent even when no table flush was queued
	flushErr := n.conn.Flush()
	log.Debugf("flushing tables result in: %v error", flushErr)
}

// RestoreOrCreateContainers restores the existing nftables containers (tables
// and chains) and creates the ones that do not exist yet.
//
// It is a no-op when both tables are already known. Otherwise it looks up the
// tables and chains named after the constants of this file, queues the
// creation of whatever is missing, reloads the rules map, queues the default
// forwarding rules and finally sends everything with a single Flush.
//
// It returns an error when listing tables, chains or rules fails, or when the
// final Flush is rejected.
func (n *nftablesManager) RestoreOrCreateContainers() error {
	n.mux.Lock()
	defer n.mux.Unlock()

	if n.tableIPv6 != nil && n.tableIPv4 != nil {
		log.Debugf("nftables: containers already restored, skipping")
		return nil
	}

	tables, err := n.conn.ListTables()
	if err != nil {
		return fmt.Errorf("nftables: unable to list tables: %w", err)
	}

	for _, table := range tables {
		if table.Name != nftablesTable {
			continue
		}
		// Match each family explicitly: a same-named table of any other
		// family (inet, arp, bridge, ...) must not be adopted as the IPv6 table.
		switch table.Family {
		case nftables.TableFamilyIPv4:
			n.tableIPv4 = table
		case nftables.TableFamilyIPv6:
			n.tableIPv6 = table
		}
	}

	if n.tableIPv4 == nil {
		n.tableIPv4 = n.conn.AddTable(&nftables.Table{
			Name:   nftablesTable,
			Family: nftables.TableFamilyIPv4,
		})
	}

	if n.tableIPv6 == nil {
		n.tableIPv6 = n.conn.AddTable(&nftables.Table{
			Name:   nftablesTable,
			Family: nftables.TableFamilyIPv6,
		})
	}

	chains, err := n.conn.ListChains()
	if err != nil {
		return fmt.Errorf("nftables: unable to list chains: %w", err)
	}

	// writing to a nil map panics, so make sure the outer map exists
	if n.chains == nil {
		n.chains = make(map[string]map[string]*nftables.Chain)
	}
	n.chains[ipv4] = make(map[string]*nftables.Chain)
	n.chains[ipv6] = make(map[string]*nftables.Chain)

	for _, chain := range chains {
		if chain.Table.Name != nftablesTable {
			continue
		}
		switch chain.Table.Family {
		case nftables.TableFamilyIPv4:
			n.chains[ipv4][chain.Name] = chain
		case nftables.TableFamilyIPv6:
			n.chains[ipv6][chain.Name] = chain
		}
	}

	// Both families get the same pair of chains; a slice (not a map) keeps the
	// order of the queued messages deterministic: IPv4 first, then IPv6.
	families := []struct {
		name  string
		table *nftables.Table
	}{
		{name: ipv4, table: n.tableIPv4},
		{name: ipv6, table: n.tableIPv6},
	}
	for _, family := range families {
		registered := n.chains[family.name]

		if _, found := registered[nftablesRoutingForwardingChain]; !found {
			registered[nftablesRoutingForwardingChain] = n.conn.AddChain(&nftables.Chain{
				Name:     nftablesRoutingForwardingChain,
				Table:    family.table,
				Hooknum:  nftables.ChainHookForward,
				Priority: nftables.ChainPriorityNATDest + 1,
				Type:     nftables.ChainTypeFilter,
			})
		}

		if _, found := registered[nftablesRoutingNatChain]; !found {
			registered[nftablesRoutingNatChain] = n.conn.AddChain(&nftables.Chain{
				Name:     nftablesRoutingNatChain,
				Table:    family.table,
				Hooknum:  nftables.ChainHookPostrouting,
				Priority: nftables.ChainPriorityNATSource - 1,
				Type:     nftables.ChainTypeNAT,
			})
		}
	}

	if err := n.refreshRulesMap(); err != nil {
		return err
	}

	n.checkOrCreateDefaultForwardingRules()

	if err := n.conn.Flush(); err != nil {
		return fmt.Errorf("nftables: unable to initialize table: %w", err)
	}
	return nil
}

// refreshRulesMap rebuilds the rules map from the rules currently reported for
// every registered chain, keyed by each rule's UserData. Rules without
// UserData are ignored. This avoids duplicates and picks up attributes that
// are not available on the values returned when a rule is merely queued.
//
// The map is rebuilt rather than merged into, so entries for rules that are no
// longer reported (for example after a table flush, or after a batch that was
// never applied) are dropped instead of lingering and later triggering the
// deletion of a rule that does not exist. The new map replaces the old one
// only when every chain was listed successfully; on error the previous map is
// left untouched and the error is returned.
func (n *nftablesManager) refreshRulesMap() error {
	latest := make(map[string]*nftables.Rule, len(n.rules))

	for _, registeredChains := range n.chains {
		for _, chain := range registeredChains {
			rules, err := n.conn.GetRules(chain.Table, chain)
			if err != nil {
				return fmt.Errorf("nftables: unable to list rules: %w", err)
			}
			for _, rule := range rules {
				if len(rule.UserData) > 0 {
					latest[string(rule.UserData)] = rule
				}
			}
		}
	}

	n.rules = latest
	return nil
}

// checkOrCreateDefaultForwardingRules makes sure the rules map holds the
// default forwarding rule of each IP family. A missing rule is queued on the
// family's forwarding chain with the exprAllowRelatedEstablished expressions
// and recorded in the map; nothing is flushed here.
func (n *nftablesManager) checkOrCreateDefaultForwardingRules() {
	defaults := []struct {
		key    string
		family string
		table  *nftables.Table
	}{
		{key: ipv4Forwarding, family: ipv4, table: n.tableIPv4},
		{key: ipv6Forwarding, family: ipv6, table: n.tableIPv6},
	}

	for _, d := range defaults {
		if _, present := n.rules[d.key]; present {
			continue
		}
		n.rules[d.key] = n.conn.AddRule(&nftables.Rule{
			Table:    d.table,
			Chain:    n.chains[d.family][nftablesRoutingForwardingChain],
			Exprs:    exprAllowRelatedEstablished,
			UserData: []byte(d.key),
		})
	}
}

// InsertRoutingRules queues the forwarding rules for pair and for its inverse
// (as produced by getInPair) and, when pair.masquerade is set, the matching
// nat rules; the batch is then sent with a single Flush. The rules map is
// refreshed first. The first error encountered is returned.
func (n *nftablesManager) InsertRoutingRules(pair routerPair) error {
	n.mux.Lock()
	defer n.mux.Unlock()

	if err := n.refreshRulesMap(); err != nil {
		return err
	}

	type insertion struct {
		format  string
		chain   string
		inbound bool
		nat     bool
	}

	plan := []insertion{
		{format: forwardingFormat, chain: nftablesRoutingForwardingChain},
		{format: inForwardingFormat, chain: nftablesRoutingForwardingChain, inbound: true},
	}
	if pair.masquerade {
		plan = append(plan,
			insertion{format: natFormat, chain: nftablesRoutingNatChain, nat: true},
			insertion{format: inNatFormat, chain: nftablesRoutingNatChain, inbound: true, nat: true},
		)
	}

	for _, step := range plan {
		target := pair
		if step.inbound {
			target = getInPair(pair)
		}
		if err := n.insertRoutingRule(step.format, step.chain, target, step.nat); err != nil {
			return err
		}
	}

	if err := n.conn.Flush(); err != nil {
		return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.destination, err)
	}
	return nil
}

// insertRoutingRule queues one nftables rule on the conn client; the caller is
// responsible for flushing.
//
// The rule matches pair.source and pair.destination and ends either with
// counter+masquerade (isNat) or counter+accept. It is keyed by
// genKey(format, pair.ID); when a rule with that key is already known, its
// removal is queued first. The rule goes to the IPv4 table when pair.source is
// an IPv4 (or IPv4-mapped) prefix and to the IPv6 table otherwise; chain
// selects the chain inside that family.
//
// An error is returned when pair.source or pair.destination is not a valid
// CIDR, or when queuing the removal of the previous rule fails.
func (n *nftablesManager) insertRoutingRule(format, chain string, pair routerPair, isNat bool) error {
	// Validate both networks up front and report problems as errors: the
	// expression builder below cannot return one and panics on bad input.
	prefix, err := netip.ParsePrefix(pair.source)
	if err != nil {
		return fmt.Errorf("nftables: invalid source network %q: %w", pair.source, err)
	}
	if _, _, err := net.ParseCIDR(pair.destination); err != nil {
		return fmt.Errorf("nftables: invalid destination network %q: %w", pair.destination, err)
	}

	sourceExp := generateCIDRMatcherExpressions(exprDirectionSource, pair.source)
	destExp := generateCIDRMatcherExpressions(exprDirectionDestination, pair.destination)

	// room for both matchers plus the two-expression tail
	expression := make([]expr.Any, 0, len(sourceExp)+len(destExp)+2)
	expression = append(expression, sourceExp...)
	expression = append(expression, destExp...)
	if isNat {
		expression = append(expression, &expr.Counter{}, &expr.Masq{})
	} else {
		expression = append(expression, exprCounterAccept...)
	}

	ruleKey := genKey(format, pair.ID)

	if _, exists := n.rules[ruleKey]; exists {
		if err := n.removeRoutingRule(format, pair); err != nil {
			return err
		}
	}

	table, family := n.tableIPv6, ipv6
	if prefix.Addr().Unmap().Is4() {
		table, family = n.tableIPv4, ipv4
	}

	n.rules[ruleKey] = n.conn.InsertRule(&nftables.Rule{
		Table:    table,
		Chain:    n.chains[family][chain],
		Exprs:    expression,
		UserData: []byte(ruleKey),
	})
	return nil
}

// RemoveRoutingRules queues the removal of the forwarding and nat rules of
// pair and of its inverse (as produced by getInPair) and sends the batch with
// a single Flush. The rules map is refreshed first; rules that are not known
// are skipped. The first error encountered is returned.
func (n *nftablesManager) RemoveRoutingRules(pair routerPair) error {
	n.mux.Lock()
	defer n.mux.Unlock()

	if err := n.refreshRulesMap(); err != nil {
		return err
	}

	removals := []struct {
		format  string
		inbound bool
	}{
		{format: forwardingFormat},
		{format: inForwardingFormat, inbound: true},
		{format: natFormat},
		{format: inNatFormat, inbound: true},
	}

	for _, removal := range removals {
		target := pair
		if removal.inbound {
			target = getInPair(pair)
		}
		if err := n.removeRoutingRule(removal.format, target); err != nil {
			return err
		}
	}

	if err := n.conn.Flush(); err != nil {
		return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.destination, err)
	}

	log.Debugf("nftables: removed rules for %s", pair.destination)
	return nil
}

// removeRoutingRule queues the deletion of the rule keyed by
// genKey(format, pair.ID) and drops it from the rules map. It does nothing
// when no such rule is known. The deletion is only queued; the caller flushes.
func (n *nftablesManager) removeRoutingRule(format string, pair routerPair) error {
	key := genKey(format, pair.ID)

	existing, known := n.rules[key]
	if !known {
		return nil
	}

	// label used in the log and error messages only
	kind := "forwarding"
	if existing.Chain.Type == nftables.ChainTypeNAT {
		kind = "nat"
	}

	if err := n.conn.DelRule(existing); err != nil {
		return fmt.Errorf("nftables: unable to remove %s rule for %s: %v", kind, pair.destination, err)
	}

	log.Debugf("nftables: removing %s rule for %s", kind, pair.destination)

	delete(n.rules, key)
	return nil
}

// getPayloadDirectives returns, in this order, the payload offset, the payload
// length and the zero XOR operand matching the given direction and IP version.
// IPv4 takes precedence when both version flags are set. It panics when the
// direction is neither exprDirectionSource nor exprDirectionDestination, or
// when neither version flag is set.
func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) {
	fromSource := direction == exprDirectionSource
	toDestination := direction == exprDirectionDestination

	if isIPv4 {
		if fromSource {
			return ipv4SrcOffset, ipv4Len, zeroXor
		}
		if toDestination {
			return ipv4DestOffset, ipv4Len, zeroXor
		}
	}

	if isIPv6 {
		if fromSource {
			return ipv6SrcOffset, ipv6Len, zeroXor6
		}
		if toDestination {
			return ipv6DestOffset, ipv6Len, zeroXor6
		}
	}

	panic("no matched payload directive")
}

// generateCIDRMatcherExpressions generates the nftables expressions that match
// the source or destination address of a packet against a CIDR: load the
// address from the network header, AND it with the netmask, and compare the
// result with the network address.
//
// direction must be exprDirectionSource or exprDirectionDestination. The
// function has no error return, so it panics when cidr cannot be parsed or the
// direction is unknown; callers are expected to validate their input first.
func generateCIDRMatcherExpressions(direction string, cidr string) []expr.Any {
	_, network, err := net.ParseCIDR(cidr)
	if err != nil {
		panic(fmt.Sprintf("nftables: invalid CIDR %q: %v", cidr, err))
	}

	// Compare against the network address, not the address as written: the
	// packet address is masked before the comparison, so a CIDR with host bits
	// set (e.g. 192.168.1.5/24) could otherwise never match.
	networkAddr, ok := netip.AddrFromSlice(network.IP)
	if !ok {
		panic(fmt.Sprintf("nftables: unexpected network address length for CIDR %q", cidr))
	}
	networkAddr = networkAddr.Unmap()

	offSet, packetLen, xor := getPayloadDirectives(direction, networkAddr.Is4(), networkAddr.Is6())

	return []expr.Any{
		// fetch the address from the packet
		&expr.Payload{
			DestRegister: 1,
			Base:         expr.PayloadBaseNetworkHeader,
			Offset:       offSet,
			Len:          packetLen,
		},
		// apply the netmask
		&expr.Bitwise{
			DestRegister:   1,
			SourceRegister: 1,
			Len:            packetLen,
			Mask:           network.Mask,
			Xor:            xor,
		},
		// compare with the network address
		&expr.Cmp{
			Register: 1,
			Data:     networkAddr.AsSlice(),
		},
	}
}