package dns_test

import (
	"net"
	"testing"

	"github.com/miekg/dns"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"

	nbdns "github.com/netbirdio/netbird/client/internal/dns"
)

// TestHandlerChain_ServeDNS_Priorities verifies that when several handlers are
// registered for the same domain at different priorities, only the
// highest-priority one serves the query.
//
// None of the mock handlers writes a response, so none of them asks the chain
// to continue with the next handler.
func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
	chain := nbdns.NewHandlerChain()

	// One mock handler per priority level, all registered for the same domain.
	defaultHandler := &nbdns.MockHandler{}
	matchDomainHandler := &nbdns.MockHandler{}
	dnsRouteHandler := &nbdns.MockHandler{}

	chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault, nil)
	chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain, nil)
	chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute, nil)

	// Query for the exact registered domain.
	r := new(dns.Msg)
	r.SetQuestion("example.com.", dns.TypeA)

	w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}

	// The highest-priority handler must be called exactly once.
	dnsRouteHandler.On("ServeDNS", mock.Anything, r).Once()
	// Maybe() registers an optional expectation so that a stray call is
	// recorded by the mock; the call-count assertions below then report it as
	// an ordinary test failure.
	matchDomainHandler.On("ServeDNS", mock.Anything, r).Maybe()
	defaultHandler.On("ServeDNS", mock.Anything, r).Maybe()

	chain.ServeDNS(w, r)

	dnsRouteHandler.AssertExpectations(t)
	dnsRouteHandler.AssertNumberOfCalls(t, "ServeDNS", 1)

	// Lower-priority handlers must not have been reached.
	matchDomainHandler.AssertNumberOfCalls(t, "ServeDNS", 0)
	defaultHandler.AssertNumberOfCalls(t, "ServeDNS", 0)
}

// TestHandlerChain_ServeDNS_DomainMatching checks which query names a single
// registered pattern serves: exact names, subdomains (with and without
// subdomain matching enabled on the handler), wildcard patterns and the root
// zone.
//
// Each case registers exactly one handler and asserts that it was called once
// when a match is expected and not at all otherwise.
func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
	tests := []struct {
		name            string
		handlerDomain   string // pattern passed to AddHandler; wildcard cases start with "*."
		queryDomain     string // name asked in the DNS question
		matchSubdomains bool   // register a MockSubdomainHandler with Subdomains set to true
		shouldMatch     bool   // whether the handler is expected to serve the query
	}{
		{
			name:            "exact match",
			handlerDomain:   "example.com.",
			queryDomain:     "example.com.",
			matchSubdomains: false,
			shouldMatch:     true,
		},
		{
			name:            "subdomain with non-wildcard and MatchSubdomains true",
			handlerDomain:   "example.com.",
			queryDomain:     "sub.example.com.",
			matchSubdomains: true,
			shouldMatch:     true,
		},
		{
			name:            "subdomain with non-wildcard and MatchSubdomains false",
			handlerDomain:   "example.com.",
			queryDomain:     "sub.example.com.",
			matchSubdomains: false,
			shouldMatch:     false,
		},
		{
			name:            "wildcard match",
			handlerDomain:   "*.example.com.",
			queryDomain:     "sub.example.com.",
			matchSubdomains: false,
			shouldMatch:     true,
		},
		{
			name:            "wildcard no match on apex",
			handlerDomain:   "*.example.com.",
			queryDomain:     "example.com.",
			matchSubdomains: false,
			shouldMatch:     false,
		},
		{
			name:            "root zone match",
			handlerDomain:   ".",
			queryDomain:     "anything.com.",
			matchSubdomains: false,
			shouldMatch:     true,
		},
		{
			name:            "no match different domain",
			handlerDomain:   "example.com.",
			queryDomain:     "example.org.",
			matchSubdomains: false,
			shouldMatch:     false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			chain := nbdns.NewHandlerChain()

			wantCalls := 0
			if tt.shouldMatch {
				wantCalls = 1
			}

			var (
				handler     dns.Handler
				assertCalls func() // checks the call count on the concrete mock
			)

			// The expectation is optional in both branches so that any call
			// is recorded by the mock; the exact count is asserted afterwards.
			if tt.matchSubdomains {
				h := &nbdns.MockSubdomainHandler{Subdomains: true}
				h.On("ServeDNS", mock.Anything, mock.Anything).Maybe()
				handler = h
				assertCalls = func() { h.AssertNumberOfCalls(t, "ServeDNS", wantCalls) }
			} else {
				h := &nbdns.MockHandler{}
				h.On("ServeDNS", mock.Anything, mock.Anything).Maybe()
				handler = h
				assertCalls = func() { h.AssertNumberOfCalls(t, "ServeDNS", wantCalls) }
			}

			// The table already holds the full pattern (including any "*."
			// prefix), so it is registered as-is.
			chain.AddHandler(tt.handlerDomain, handler, nbdns.PriorityDefault, nil)

			r := new(dns.Msg)
			r.SetQuestion(tt.queryDomain, dns.TypeA)
			w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}

			chain.ServeDNS(w, r)

			assertCalls()
		})
	}
}

// TestHandlerChain_ServeDNS_OverlappingDomains tests which handler serves a
// query when several registered patterns (exact, wildcard, root zone) overlap.
//
// For every case exactly one handler, identified by expectedHandler, must be
// called expectedCalls times; every other handler must not be called at all.
func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
	// handlerSpec describes one handler registration in a test case.
	type handlerSpec struct {
		pattern  string
		priority int
	}

	tests := []struct {
		name            string
		handlers        []handlerSpec
		queryDomain     string
		expectedCalls   int // number of ServeDNS calls expected on the selected handler
		expectedHandler int // index of the handler that should be called
	}{
		{
			name: "wildcard and exact same priority - exact should win",
			handlers: []handlerSpec{
				{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
				{pattern: "example.com.", priority: nbdns.PriorityDefault},
			},
			queryDomain:     "example.com.",
			expectedCalls:   1,
			expectedHandler: 1, // exact match handler should be called
		},
		{
			name: "higher priority wildcard over lower priority exact",
			handlers: []handlerSpec{
				{pattern: "example.com.", priority: nbdns.PriorityDefault},
				{pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute},
			},
			queryDomain:     "test.example.com.",
			expectedCalls:   1,
			expectedHandler: 1, // higher priority wildcard handler should be called
		},
		{
			name: "multiple wildcards different priorities",
			handlers: []handlerSpec{
				{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
				{pattern: "*.example.com.", priority: nbdns.PriorityMatchDomain},
				{pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute},
			},
			queryDomain:     "test.example.com.",
			expectedCalls:   1,
			expectedHandler: 2, // highest priority handler should be called
		},
		{
			name: "subdomain with mix of patterns",
			handlers: []handlerSpec{
				{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
				{pattern: "test.example.com.", priority: nbdns.PriorityMatchDomain},
				{pattern: "*.test.example.com.", priority: nbdns.PriorityDNSRoute},
			},
			queryDomain:     "sub.test.example.com.",
			expectedCalls:   1,
			expectedHandler: 2, // highest priority matching handler should be called
		},
		{
			name: "root zone with specific domain",
			handlers: []handlerSpec{
				{pattern: ".", priority: nbdns.PriorityDefault},
				{pattern: "example.com.", priority: nbdns.PriorityDNSRoute},
			},
			queryDomain:     "example.com.",
			expectedCalls:   1,
			expectedHandler: 1, // higher priority specific domain should win over root
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			chain := nbdns.NewHandlerChain()
			handlers := make([]*nbdns.MockHandler, 0, len(tt.handlers))

			// Register one mock per spec. The expectation is optional so that
			// every call is recorded; exact counts are asserted below.
			for _, spec := range tt.handlers {
				handler := &nbdns.MockHandler{}
				handler.On("ServeDNS", mock.Anything, mock.Anything).Maybe()
				handlers = append(handlers, handler)

				chain.AddHandler(spec.pattern, handler, spec.priority, nil)
			}

			// Create and execute request
			r := new(dns.Msg)
			r.SetQuestion(tt.queryDomain, dns.TypeA)
			w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
			chain.ServeDNS(w, r)

			// Only the selected handler may have been called.
			for i, handler := range handlers {
				wantCalls := 0
				if i == tt.expectedHandler {
					wantCalls = tt.expectedCalls
				}
				handler.AssertNumberOfCalls(t, "ServeDNS", wantCalls)
			}
		})
	}
}

// TestHandlerChain_ServeDNS_ChainContinuation tests the chain continuation
// functionality: the two higher-priority handlers reply with the header Zero
// bit set to signal "continue", and the lowest-priority handler replies
// normally. All three must be called exactly once, in priority order.
func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
	chain := nbdns.NewHandlerChain()

	// Create handlers
	handler1 := &nbdns.MockHandler{}
	handler2 := &nbdns.MockHandler{}
	handler3 := &nbdns.MockHandler{}

	// Add handlers in priority order
	chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute, nil)
	chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain, nil)
	chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault, nil)

	// Create test request
	r := new(dns.Msg)
	r.SetQuestion("example.com.", dns.TypeA)

	// callOrder collects handler names in the order their callbacks run.
	var callOrder []string

	// respond builds a mock callback that records the call under name and
	// writes a reply with the given rcode. When passOn is true the reply has
	// the Zero header bit set, which signals the chain to continue.
	respond := func(name string, rcode int, passOn bool) func(mock.Arguments) {
		return func(args mock.Arguments) {
			callOrder = append(callOrder, name)

			w := args.Get(0).(*nbdns.ResponseWriterChain)
			resp := new(dns.Msg)
			resp.SetRcode(r, rcode)
			if passOn {
				resp.MsgHdr.Zero = true // Signal to continue
			}
			assert.NoError(t, w.WriteMsg(resp))
		}
	}

	// First and second handlers signal continue, the last one responds normally.
	handler1.On("ServeDNS", mock.Anything, r).Run(respond("dnsRoute", dns.RcodeNameError, true)).Once()
	handler2.On("ServeDNS", mock.Anything, r).Run(respond("matchDomain", dns.RcodeNameError, true)).Once()
	handler3.On("ServeDNS", mock.Anything, r).Run(respond("default", dns.RcodeSuccess, false)).Once()

	// Execute
	w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
	chain.ServeDNS(w, r)

	// Verify every handler was called exactly once...
	handler1.AssertExpectations(t)
	handler2.AssertExpectations(t)
	handler3.AssertExpectations(t)

	// ...and that the calls happened from highest to lowest priority.
	assert.Equal(t, []string{"dnsRoute", "matchDomain", "default"}, callOrder,
		"handlers should be invoked in descending priority order")
}

// mockResponseWriter is a do-nothing dns.ResponseWriter for tests. It is used
// as the underlying writer wrapped by nbdns.ResponseWriterChain; every method
// returns a zero value and has no side effects.
type mockResponseWriter struct {
	mock.Mock
}

// LocalAddr always returns nil.
func (w *mockResponseWriter) LocalAddr() net.Addr {
	return nil
}

// RemoteAddr always returns nil.
func (w *mockResponseWriter) RemoteAddr() net.Addr {
	return nil
}

// WriteMsg discards the message and reports success.
func (w *mockResponseWriter) WriteMsg(_ *dns.Msg) error {
	return nil
}

// Write discards the bytes and reports zero bytes written with no error.
func (w *mockResponseWriter) Write(_ []byte) (int, error) {
	return 0, nil
}

// Close does nothing and reports success.
func (w *mockResponseWriter) Close() error {
	return nil
}

// TsigStatus always returns nil.
func (w *mockResponseWriter) TsigStatus() error {
	return nil
}

// TsigTimersOnly does nothing.
func (w *mockResponseWriter) TsigTimersOnly(_ bool) {
}

// Hijack does nothing.
func (w *mockResponseWriter) Hijack() {
}

// TestHandlerChain_PriorityDeregistration verifies that removing the handler
// registered at one priority leaves handlers registered for the same pattern
// at other priorities in place.
//
// Each case applies a sequence of add/remove operations, sends one query and
// then checks that exactly the handlers marked true in expectedCalls were
// called once and that all other added handlers were not called.
func TestHandlerChain_PriorityDeregistration(t *testing.T) {
	// op is a single AddHandler or RemoveHandler step applied to the chain.
	type op struct {
		action   string // "add" or "remove"
		pattern  string
		priority int
	}

	tests := []struct {
		name          string
		ops           []op
		query         string
		expectedCalls map[int]bool // map[priority]shouldBeCalled
	}{
		{
			name: "remove high priority keeps lower priority handler",
			ops: []op{
				{"add", "example.com.", nbdns.PriorityDNSRoute},
				{"add", "example.com.", nbdns.PriorityMatchDomain},
				{"remove", "example.com.", nbdns.PriorityDNSRoute},
			},
			query: "example.com.",
			expectedCalls: map[int]bool{
				nbdns.PriorityDNSRoute:    false,
				nbdns.PriorityMatchDomain: true,
			},
		},
		{
			name: "remove lower priority keeps high priority handler",
			ops: []op{
				{"add", "example.com.", nbdns.PriorityDNSRoute},
				{"add", "example.com.", nbdns.PriorityMatchDomain},
				{"remove", "example.com.", nbdns.PriorityMatchDomain},
			},
			query: "example.com.",
			expectedCalls: map[int]bool{
				nbdns.PriorityDNSRoute:    true,
				nbdns.PriorityMatchDomain: false,
			},
		},
		{
			name: "remove all handlers in order",
			ops: []op{
				{"add", "example.com.", nbdns.PriorityDNSRoute},
				{"add", "example.com.", nbdns.PriorityMatchDomain},
				{"add", "example.com.", nbdns.PriorityDefault},
				{"remove", "example.com.", nbdns.PriorityDNSRoute},
				{"remove", "example.com.", nbdns.PriorityMatchDomain},
			},
			query: "example.com.",
			expectedCalls: map[int]bool{
				nbdns.PriorityDNSRoute:    false,
				nbdns.PriorityMatchDomain: false,
				nbdns.PriorityDefault:     true,
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			chain := nbdns.NewHandlerChain()
			// handlers keeps every mock that was ever added, keyed by
			// priority, including those removed again by a later op.
			handlers := make(map[int]*nbdns.MockHandler)

			// Execute operations
			for _, step := range tt.ops {
				switch step.action {
				case "add":
					handler := &nbdns.MockHandler{}
					handlers[step.priority] = handler
					chain.AddHandler(step.pattern, handler, step.priority, nil)
				case "remove":
					chain.RemoveHandler(step.pattern, step.priority)
				default:
					t.Fatalf("unknown op action %q", step.action)
				}
			}

			// Create test request
			r := new(dns.Msg)
			r.SetQuestion(tt.query, dns.TypeA)
			w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}

			// Optional expectations let the mocks record every call; the
			// exact counts are asserted after the request.
			for _, handler := range handlers {
				handler.On("ServeDNS", mock.Anything, r).Maybe()
			}

			// Execute request
			chain.ServeDNS(w, r)

			// A handler must be called once if expected, and never otherwise
			// (in particular not after it has been removed).
			for priority, handler := range handlers {
				wantCalls := 0
				if tt.expectedCalls[priority] {
					wantCalls = 1
				}
				handler.AssertNumberOfCalls(t, "ServeDNS", wantCalls)
			}

			// A priority that is still expected to serve implies the pattern
			// is still registered. Every op in a case uses the same pattern,
			// so the first op's pattern is used for the lookup.
			for priority, shouldExist := range tt.expectedCalls {
				if shouldExist {
					assert.True(t, chain.HasHandlers(tt.ops[0].pattern),
						"Handler chain should still have handlers for pattern: priority %d was not removed", priority)
				}
			}
		})
	}
}

// TestHandlerChain_MultiPriorityHandling registers three subdomain-matching
// handlers for one domain at different priorities and removes them one by
// one, from highest to lowest priority. After each step it checks which
// handler serves a subdomain query and whether the chain still reports
// handlers for the domain.
func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
	const (
		testDomain = "example.com."
		testQuery  = "test.example.com."
	)

	chain := nbdns.NewHandlerChain()

	// All handlers have MatchSubdomains enabled so they match testQuery.
	routeHandler := &nbdns.MockSubdomainHandler{Subdomains: true}
	matchHandler := &nbdns.MockSubdomainHandler{Subdomains: true}
	defaultHandler := &nbdns.MockSubdomainHandler{Subdomains: true}

	// The same request is reused for every step.
	r := new(dns.Msg)
	r.SetQuestion(testQuery, dns.TypeA)

	// Registration order is deliberately not the priority order.
	chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault, nil)
	chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute, nil)
	chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain, nil)

	// serveExpecting sends the request through the chain with a fresh writer
	// and verifies that h was called exactly once for it.
	serveExpecting := func(h *nbdns.MockSubdomainHandler) {
		writer := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
		h.On("ServeDNS", mock.Anything, r).Return().Once()

		chain.ServeDNS(writer, r)
		h.AssertExpectations(t)
	}

	// Step 1: all three registered, the highest priority (route) serves.
	serveExpecting(routeHandler)

	// Step 2: without the route handler, the middle priority (match) serves.
	chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute)
	assert.True(t, chain.HasHandlers(testDomain))
	serveExpecting(matchHandler)

	// Step 3: without the match handler, the lowest priority (default) serves.
	chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain)
	assert.True(t, chain.HasHandlers(testDomain))
	serveExpecting(defaultHandler)

	// Step 4: removing the last handler leaves nothing for the domain.
	chain.RemoveHandler(testDomain, nbdns.PriorityDefault)
	assert.False(t, chain.HasHandlers(testDomain))
}