package sharedsock

import (
	"context"
	"errors"
	"fmt"
	"net"
	"net/netip"
	"os"
	"sync"
	"testing"
	"time"

	"github.com/pion/stun"
	log "github.com/sirupsen/logrus"
	"github.com/stretchr/testify/require"
	"golang.org/x/sync/errgroup"
)

func TestShouldReadSTUNOnReadFrom(t *testing.T) {

	// create raw socket on a port
	testingPort := 51821
	rawSock, err := Listen(testingPort, NewIncomingSTUNFilter())
	require.NoError(t, err, "received an error while creating STUN listener, error: %s", err)
	err = rawSock.SetReadDeadline(time.Now().Add(3 * time.Second))
	require.NoError(t, err, "unable to set deadline, error: %s", err)

	wg := sync.WaitGroup{}
	wg.Add(1)

	// when reading from the raw socket
	buf := make([]byte, 1500)
	rcvMSG := &stun.Message{
		Raw: buf,
	}
	ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second)
	defer cancel()

	go func() {
		select {
		case <-ctx.Done():
			return
		default:
			_, _, err := rawSock.ReadFrom(buf)
			if err != nil {
				log.Errorf("error while reading packet %s", err)
				return
			}

			err = rcvMSG.Decode()
			if err != nil {
				log.Warnf("error while parsing STUN message. The packet doesn't seem to be a STUN packet: %s", err)
				return
			}
			wg.Done()
		}

	}()

	// and sending STUN packet to the shared port, the packet has to be handled
	udpListener, err := net.ListenUDP("udp", &net.UDPAddr{Port: 12345, IP: net.ParseIP("127.0.0.1")})
	require.NoError(t, err, "received an error while creating regular listener, error: %s", err)
	defer udpListener.Close()
	stunMSG, err := stun.Build(stun.NewType(stun.MethodBinding, stun.ClassRequest), stun.TransactionID,
		stun.Fingerprint,
	)
	require.NoError(t, err, "unable to build stun msg, error: %s", err)
	_, err = udpListener.WriteTo(stunMSG.Raw, net.UDPAddrFromAddrPort(netip.MustParseAddrPort(fmt.Sprintf("127.0.0.1:%d", testingPort))))
	require.NoError(t, err, "received an error while writing the stun listener, error: %s", err)

	// the packet has to be handled and be a STUN packet
	wg.Wait()
	require.EqualValues(t, stunMSG.TransactionID, rcvMSG.TransactionID, "transaction id values did't match")
}

func TestShouldNotReadNonSTUNPackets(t *testing.T) {
	testingPort := 39439
	rawSock, err := Listen(testingPort, NewIncomingSTUNFilter())
	require.NoError(t, err, "received an error while creating STUN listener, error: %s", err)
	defer rawSock.Close()

	buf := make([]byte, 1500)
	err = rawSock.SetReadDeadline(time.Now().Add(time.Second))
	require.NoError(t, err, "unable to set deadline, error: %s", err)

	errGrp := errgroup.Group{}
	errGrp.Go(func() error {
		_, _, err := rawSock.ReadFrom(buf)
		return err
	})
	nonStun := []byte("netbird")
	udpListener, err := net.ListenUDP("udp", &net.UDPAddr{Port: 0, IP: net.ParseIP("127.0.0.1")})
	require.NoError(t, err, "received an error while creating regular listener, error: %s", err)
	defer udpListener.Close()
	remote := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(fmt.Sprintf("127.0.0.1:%d", testingPort)))
	_, err = udpListener.WriteTo(nonStun, remote)
	require.NoError(t, err, "received an error while writing the stun listener, error: %s", err)

	err = errGrp.Wait()
	require.Error(t, err, "should receive an error")
	if !errors.Is(err, os.ErrDeadlineExceeded) {
		t.Errorf("error should be I/O timeout, got: %s", err)
	}
}

func TestWriteTo(t *testing.T) {
	udpListener, err := net.ListenUDP("udp4", &net.UDPAddr{Port: 0, IP: net.ParseIP("127.0.0.1")})
	require.NoError(t, err, "received an error while creating regular listener, error: %s", err)
	defer udpListener.Close()

	testingPort := 39440
	rawSock, err := Listen(testingPort, NewIncomingSTUNFilter())
	require.NoError(t, err, "received an error while creating STUN listener, error: %s", err)
	defer rawSock.Close()

	buf := make([]byte, 1500)
	err = udpListener.SetReadDeadline(time.Now().Add(3 * time.Second))
	require.NoError(t, err, "unable to set deadline, error: %s", err)

	errGrp := errgroup.Group{}
	var remoteAdr net.Addr
	var rcvBytes int
	errGrp.Go(func() error {
		n, a, err := udpListener.ReadFrom(buf)
		remoteAdr = a
		rcvBytes = n
		return err
	})

	msg := []byte("netbird")
	_, err = rawSock.WriteTo(msg, udpListener.LocalAddr())
	require.NoError(t, err, "received an error while writing the stun listener, error: %s", err)

	err = errGrp.Wait()
	require.NoError(t, err, "received an error while reading the packet, error: %s", err)

	require.EqualValues(t, string(msg), string(buf[:rcvBytes]), "received message should match")

	udpRcv, ok := remoteAdr.(*net.UDPAddr)
	require.True(t, ok, "udp address conversion didn't work")

	require.EqualValues(t, testingPort, udpRcv.Port, "received address port didn't match")
}

func TestSharedSocket_Close(t *testing.T) {
	rawSock, err := Listen(39440, NewIncomingSTUNFilter())
	require.NoError(t, err, "received an error while creating STUN listener, error: %s", err)

	errGrp := errgroup.Group{}

	errGrp.Go(func() error {
		buf := make([]byte, 1500)
		_, _, err := rawSock.ReadFrom(buf)
		return err
	})
	_ = rawSock.Close()
	err = errGrp.Wait()
	if err != ErrSharedSockStopped {
		t.Errorf("invalid error response: %s", err)
	}
}