package routemanager

import (
	"context"
	"fmt"
	"reflect"
	"time"

	"github.com/hashicorp/go-multierror"
	log "github.com/sirupsen/logrus"

	nberrors "github.com/netbirdio/netbird/client/errors"
	"github.com/netbirdio/netbird/client/iface"
	nbdns "github.com/netbirdio/netbird/client/internal/dns"
	"github.com/netbirdio/netbird/client/internal/peer"
	"github.com/netbirdio/netbird/client/internal/peerstore"
	"github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor"
	"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
	"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
	"github.com/netbirdio/netbird/client/internal/routemanager/static"
	"github.com/netbirdio/netbird/route"
)

const (
	handlerTypeDynamic = iota
	handlerTypeDomain
	handlerTypeStatic
)

type routerPeerStatus struct {
	connected bool
	relayed   bool
	latency   time.Duration
}

type routesUpdate struct {
	updateSerial uint64
	routes       []*route.Route
}

// RouteHandler defines the interface for handling routes
type RouteHandler interface {
	String() string
	AddRoute(ctx context.Context) error
	RemoveRoute() error
	AddAllowedIPs(peerKey string) error
	RemoveAllowedIPs() error
}

type clientNetwork struct {
	ctx                 context.Context
	cancel              context.CancelFunc
	statusRecorder      *peer.Status
	wgInterface         iface.IWGIface
	routes              map[route.ID]*route.Route
	routeUpdate         chan routesUpdate
	peerStateUpdate     chan struct{}
	routePeersNotifiers map[string]chan struct{}
	currentChosen       *route.Route
	handler             RouteHandler
	updateSerial        uint64
}

func newClientNetworkWatcher(
	ctx context.Context,
	dnsRouteInterval time.Duration,
	wgInterface iface.IWGIface,
	statusRecorder *peer.Status,
	rt *route.Route,
	routeRefCounter *refcounter.RouteRefCounter,
	allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
	dnsServer nbdns.Server,
	peerStore *peerstore.Store,
	useNewDNSRoute bool,
) *clientNetwork {
	ctx, cancel := context.WithCancel(ctx)

	client := &clientNetwork{
		ctx:                 ctx,
		cancel:              cancel,
		statusRecorder:      statusRecorder,
		wgInterface:         wgInterface,
		routes:              make(map[route.ID]*route.Route),
		routePeersNotifiers: make(map[string]chan struct{}),
		routeUpdate:         make(chan routesUpdate),
		peerStateUpdate:     make(chan struct{}),
		handler: handlerFromRoute(
			rt,
			routeRefCounter,
			allowedIPsRefCounter,
			dnsRouteInterval,
			statusRecorder,
			wgInterface,
			dnsServer,
			peerStore,
			useNewDNSRoute,
		),
	}
	return client
}

func (c *clientNetwork) getRouterPeerStatuses() map[route.ID]routerPeerStatus {
	routePeerStatuses := make(map[route.ID]routerPeerStatus)
	for _, r := range c.routes {
		peerStatus, err := c.statusRecorder.GetPeer(r.Peer)
		if err != nil {
			log.Debugf("couldn't fetch peer state: %v", err)
			continue
		}
		routePeerStatuses[r.ID] = routerPeerStatus{
			connected: peerStatus.ConnStatus == peer.StatusConnected,
			relayed:   peerStatus.Relayed,
			latency:   peerStatus.Latency,
		}
	}
	return routePeerStatuses
}

// getBestRouteFromStatuses determines the most optimal route from the available routes
// within a clientNetwork, taking into account peer connection status, route metrics, and
// preference for non-relayed and direct connections.
//
// It follows these prioritization rules:
// * Connected peers: Only routes with connected peers are considered.
// * Metric: Routes with lower metrics (better) are prioritized.
// * Non-relayed: Routes without relays are preferred.
// * Latency: Routes with lower latency are prioritized.
// * we compare the current score + 10ms to the chosen score to avoid flapping between routes
// * Stability: In case of equal scores, the currently active route (if any) is maintained.
//
// It returns the ID of the selected optimal route.
func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]routerPeerStatus) route.ID {
	chosen := route.ID("")
	chosenScore := float64(0)
	currScore := float64(0)

	currID := route.ID("")
	if c.currentChosen != nil {
		currID = c.currentChosen.ID
	}

	for _, r := range c.routes {
		tempScore := float64(0)
		peerStatus, found := routePeerStatuses[r.ID]
		if !found || !peerStatus.connected {
			continue
		}

		if r.Metric < route.MaxMetric {
			metricDiff := route.MaxMetric - r.Metric
			tempScore = float64(metricDiff) * 10
		}

		// in some temporal cases, latency can be 0, so we set it to 999ms to not block but try to avoid this route
		latency := 999 * time.Millisecond
		if peerStatus.latency != 0 {
			latency = peerStatus.latency
		} else {
			log.Tracef("peer %s has 0 latency, range %s", r.Peer, c.handler)
		}

		// avoid negative tempScore on the higher latency calculation
		if latency > 1*time.Second {
			latency = 999 * time.Millisecond
		}

		// higher latency is worse score
		tempScore += 1 - latency.Seconds()

		if !peerStatus.relayed {
			tempScore++
		}

		if tempScore > chosenScore || (tempScore == chosenScore && chosen == "") {
			chosen = r.ID
			chosenScore = tempScore
		}

		if chosen == "" && currID == "" {
			chosen = r.ID
			chosenScore = tempScore
		}

		if r.ID == currID {
			currScore = tempScore
		}
	}

	log.Debugf("chosen route: %s, chosen score: %f, current route: %s, current score: %f", chosen, chosenScore, currID, currScore)

	switch {
	case chosen == "":
		var peers []string
		for _, r := range c.routes {
			peers = append(peers, r.Peer)
		}

		log.Warnf("The network [%v] has not been assigned a routing peer as no peers from the list %s are currently connected", c.handler, peers)
	case chosen != currID:
		// we compare the current score + 10ms to the chosen score to avoid flapping between routes
		if currScore != 0 && currScore+0.01 > chosenScore {
			log.Debugf("Keeping current routing peer because the score difference with latency is less than 0.01(10ms), current: %f, new: %f", currScore, chosenScore)
			return currID
		}
		var p string
		if rt := c.routes[chosen]; rt != nil {
			p = rt.Peer
		}
		log.Infof("New chosen route is %s with peer %s with score %f for network [%v]", chosen, p, chosenScore, c.handler)
	}

	return chosen
}

func (c *clientNetwork) watchPeerStatusChanges(ctx context.Context, peerKey string, peerStateUpdate chan struct{}, closer chan struct{}) {
	for {
		select {
		case <-ctx.Done():
			return
		case <-closer:
			return
		case <-c.statusRecorder.GetPeerStateChangeNotifier(peerKey):
			state, err := c.statusRecorder.GetPeer(peerKey)
			if err != nil || state.ConnStatus == peer.StatusConnecting {
				continue
			}
			peerStateUpdate <- struct{}{}
			log.Debugf("triggered route state update for Peer %s, state: %s", peerKey, state.ConnStatus)
		}
	}
}

func (c *clientNetwork) startPeersStatusChangeWatcher() {
	for _, r := range c.routes {
		_, found := c.routePeersNotifiers[r.Peer]
		if found {
			continue
		}

		closerChan := make(chan struct{})
		c.routePeersNotifiers[r.Peer] = closerChan
		go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, closerChan)
	}
}

func (c *clientNetwork) removeRouteFromWireGuardPeer() error {
	if err := c.statusRecorder.RemovePeerStateRoute(c.currentChosen.Peer, c.handler.String()); err != nil {
		log.Warnf("Failed to update peer state: %v", err)
	}

	if err := c.handler.RemoveAllowedIPs(); err != nil {
		return fmt.Errorf("remove allowed IPs: %w", err)
	}
	return nil
}

func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
	if c.currentChosen == nil {
		return nil
	}

	var merr *multierror.Error

	if err := c.removeRouteFromWireGuardPeer(); err != nil {
		merr = multierror.Append(merr, fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err))
	}
	if err := c.handler.RemoveRoute(); err != nil {
		merr = multierror.Append(merr, fmt.Errorf("remove route: %w", err))
	}

	return nberrors.FormatErrorOrNil(merr)
}

func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
	routerPeerStatuses := c.getRouterPeerStatuses()

	newChosenID := c.getBestRouteFromStatuses(routerPeerStatuses)

	// If no route is chosen, remove the route from the peer and system
	if newChosenID == "" {
		if err := c.removeRouteFromPeerAndSystem(); err != nil {
			return fmt.Errorf("remove route for peer %s: %w", c.currentChosen.Peer, err)
		}

		c.currentChosen = nil

		return nil
	}

	// If the chosen route is the same as the current route, do nothing
	if c.currentChosen != nil && c.currentChosen.ID == newChosenID &&
		c.currentChosen.IsEqual(c.routes[newChosenID]) {
		return nil
	}

	if c.currentChosen == nil {
		// If they were not previously assigned to another peer, add routes to the system first
		if err := c.handler.AddRoute(c.ctx); err != nil {
			return fmt.Errorf("add route: %w", err)
		}
	} else {
		// Otherwise, remove the allowed IPs from the previous peer first
		if err := c.removeRouteFromWireGuardPeer(); err != nil {
			return fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
		}
	}

	c.currentChosen = c.routes[newChosenID]

	if err := c.handler.AddAllowedIPs(c.currentChosen.Peer); err != nil {
		return fmt.Errorf("add allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
	}

	err := c.statusRecorder.AddPeerStateRoute(c.currentChosen.Peer, c.handler.String())
	if err != nil {
		return fmt.Errorf("add peer state route: %w", err)
	}
	return nil
}

func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) {
	go func() {
		c.routeUpdate <- update
	}()
}

func (c *clientNetwork) handleUpdate(update routesUpdate) bool {
	isUpdateMapDifferent := false
	updateMap := make(map[route.ID]*route.Route)

	for _, r := range update.routes {
		updateMap[r.ID] = r
	}

	if len(c.routes) != len(updateMap) {
		isUpdateMapDifferent = true
	}

	for id, r := range c.routes {
		_, found := updateMap[id]
		if !found {
			close(c.routePeersNotifiers[r.Peer])
			delete(c.routePeersNotifiers, r.Peer)
			isUpdateMapDifferent = true
			continue
		}
		if !reflect.DeepEqual(c.routes[id], updateMap[id]) {
			isUpdateMapDifferent = true
		}
	}

	c.routes = updateMap
	return isUpdateMapDifferent
}

// peersStateAndUpdateWatcher is the main point of reacting on client network routing events.
// All the processing related to the client network should be done here. Thread-safe.
func (c *clientNetwork) peersStateAndUpdateWatcher() {
	for {
		select {
		case <-c.ctx.Done():
			log.Debugf("Stopping watcher for network [%v]", c.handler)
			if err := c.removeRouteFromPeerAndSystem(); err != nil {
				log.Errorf("Failed to remove routes for [%v]: %v", c.handler, err)
			}
			return
		case <-c.peerStateUpdate:
			err := c.recalculateRouteAndUpdatePeerAndSystem()
			if err != nil {
				log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err)
			}
		case update := <-c.routeUpdate:
			if update.updateSerial < c.updateSerial {
				log.Warnf("Received a routes update with smaller serial number (%d -> %d), ignoring it", c.updateSerial, update.updateSerial)
				continue
			}

			log.Debugf("Received a new client network route update for [%v]", c.handler)

			// hash update somehow
			isTrueRouteUpdate := c.handleUpdate(update)

			c.updateSerial = update.updateSerial

			if isTrueRouteUpdate {
				log.Debug("Client network update contains different routes, recalculating routes")
				err := c.recalculateRouteAndUpdatePeerAndSystem()
				if err != nil {
					log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err)
				}
			} else {
				log.Debug("Route update is not different, skipping route recalculation")
			}

			c.startPeersStatusChangeWatcher()
		}
	}
}

func handlerFromRoute(
	rt *route.Route,
	routeRefCounter *refcounter.RouteRefCounter,
	allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
	dnsRouterInteval time.Duration,
	statusRecorder *peer.Status,
	wgInterface iface.IWGIface,
	dnsServer nbdns.Server,
	peerStore *peerstore.Store,
	useNewDNSRoute bool,
) RouteHandler {
	switch handlerType(rt, useNewDNSRoute) {
	case handlerTypeDomain:
		return dnsinterceptor.New(
			rt,
			routeRefCounter,
			allowedIPsRefCounter,
			statusRecorder,
			dnsServer,
			peerStore,
		)
	case handlerTypeDynamic:
		dns := nbdns.NewServiceViaMemory(wgInterface)
		return dynamic.NewRoute(
			rt,
			routeRefCounter,
			allowedIPsRefCounter,
			dnsRouterInteval,
			statusRecorder,
			wgInterface,
			fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()),
		)
	default:
		return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter)
	}
}

func handlerType(rt *route.Route, useNewDNSRoute bool) int {
	if !rt.IsDynamic() {
		return handlerTypeStatic
	}

	if useNewDNSRoute {
		return handlerTypeDomain
	}
	return handlerTypeDynamic
}