package relay

import (
	"context"
	"fmt"
	"net"
	"sync"
	"time"

	"github.com/pion/stun/v2"
	"github.com/pion/turn/v3"
	log "github.com/sirupsen/logrus"
)

// ProbeResult holds the outcome of probing a single relay (STUN or TURN)
// server. ProbeAll returns one ProbeResult per probed URI.
type ProbeResult struct {
	// URI is the server that was probed.
	URI  *stun.URI
	// Err is the error returned by the probe function, or nil on success.
	Err  error
	// Addr is the address string returned by the probe function (the
	// mapped address for ProbeSTUN, the relayed address for ProbeTURN).
	Addr string
}

// ProbeSTUN sends a single STUN Binding request to the server at uri and
// returns the XOR-mapped address the server reports, formatted with
// XORMappedAddress.String.
//
// The probe gives up when ctx is done. The STUN client is always closed
// before returning. A close failure is reported as the probe error only if no
// earlier error occurred. Errors are also logged at debug level.
func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) {
	defer func() {
		if probeErr != nil {
			log.Debugf("stun probe error from %s: %s", uri, probeErr)
		}
	}()

	client, err := stun.DialURI(uri, &stun.DialConfig{})
	if err != nil {
		return "", fmt.Errorf("dial: %w", err)
	}

	defer func() {
		if err := client.Close(); err != nil && probeErr == nil {
			probeErr = fmt.Errorf("close: %w", err)
		}
	}()

	// The callback runs on the STUN client's own goroutine, so it must not
	// touch the named results directly; it hands its outcome over instead.
	type stunResult struct {
		addr string
		err  error
	}
	// Buffered and written with a non-blocking send so the callback never
	// blocks the client's read loop, even if we already gave up on ctx.Done
	// (a blocked callback would leak and could stall client.Close).
	resultCh := make(chan stunResult, 1)
	report := func(r stunResult) {
		select {
		case resultCh <- r:
		default:
		}
	}

	if err := client.Start(stun.MustBuild(stun.TransactionID, stun.BindingRequest), func(res stun.Event) {
		if res.Error != nil {
			report(stunResult{err: fmt.Errorf("request: %w", res.Error)})
			return
		}

		var xorAddr stun.XORMappedAddress
		if getErr := xorAddr.GetFrom(res.Message); getErr != nil {
			report(stunResult{err: fmt.Errorf("get xor addr: %w", getErr)})
			return
		}

		log.Debugf("stun probe received address from %s: %s", uri, xorAddr)
		report(stunResult{addr: xorAddr.String()})
	}); err != nil {
		return "", fmt.Errorf("client: %w", err)
	}

	select {
	case <-ctx.Done():
		return "", fmt.Errorf("stun request: %w", ctx.Err())
	case res := <-resultCh:
		// Report request/decoding failures immediately instead of waiting
		// for ctx to expire.
		if res.err != nil {
			return "", res.err
		}
		return res.addr, nil
	}
}

// ProbeTURN tries to allocate a relay session on the TURN server described
// by uri and returns the relayed address the server allocated.
//
// The transport is chosen from uri.Proto:
//   - UDP uses a fresh local packet socket.
//   - TCP dials the server (honoring ctx) and wraps the stream with
//     turn.NewSTUNConn.
//   - Any other protocol is rejected.
//
// uri.Username and uri.Password are passed to the TURN client as
// credentials.
//
// The allocation, the TURN client and the connection are released before
// returning. A failure to close the allocation or the connection is reported
// as the probe error only when no earlier error occurred. Errors are also
// logged at debug level.
//
// NOTE(review): ctx currently only bounds the TCP dial; the UDP path and the
// Allocate round trip rely on the TURN client's own timeouts — confirm this
// is acceptable for callers that pass short deadlines.
func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) {
	defer func() {
		if probeErr != nil {
			log.Debugf("turn probe error from %s: %s", uri, probeErr)
		}
	}()

	// JoinHostPort brackets IPv6 literal hosts; a plain "%s:%d" would yield
	// an unparsable address such as "::1:3478".
	turnServerAddr := net.JoinHostPort(uri.Host, fmt.Sprintf("%d", uri.Port))

	var conn net.PacketConn
	switch uri.Proto {
	case stun.ProtoTypeUDP:
		udpConn, err := net.ListenPacket("udp", "")
		if err != nil {
			return "", fmt.Errorf("listen: %w", err)
		}
		conn = udpConn
	case stun.ProtoTypeTCP:
		var dialer net.Dialer
		tcpConn, err := dialer.DialContext(ctx, "tcp", turnServerAddr)
		if err != nil {
			return "", fmt.Errorf("dial: %w", err)
		}
		conn = turn.NewSTUNConn(tcpConn)
	default:
		return "", fmt.Errorf("conn: unknown proto: %s", uri.Proto)
	}

	defer func() {
		if err := conn.Close(); err != nil && probeErr == nil {
			probeErr = fmt.Errorf("conn close: %w", err)
		}
	}()

	cfg := &turn.ClientConfig{
		STUNServerAddr: turnServerAddr,
		TURNServerAddr: turnServerAddr,
		Conn:           conn,
		Username:       uri.Username,
		Password:       uri.Password,
	}
	client, err := turn.NewClient(cfg)
	if err != nil {
		return "", fmt.Errorf("create client: %w", err)
	}
	defer client.Close()

	if err := client.Listen(); err != nil {
		return "", fmt.Errorf("client listen: %w", err)
	}

	relayConn, err := client.Allocate()
	if err != nil {
		return "", fmt.Errorf("allocate: %w", err)
	}
	defer func() {
		if err := relayConn.Close(); err != nil && probeErr == nil {
			probeErr = fmt.Errorf("close relay conn: %w", err)
		}
	}()

	relayAddr := relayConn.LocalAddr()
	log.Debugf("turn probe relay address from %s: %s", uri, relayAddr)

	return relayAddr.String(), nil
}

// ProbeAll probes every server in relays concurrently using fn and returns
// one ProbeResult per entry, in the same order as relays.
//
// Each probe runs in its own goroutine with a context derived from ctx that
// times out after one second. ProbeAll blocks until every probe has returned.
func ProbeAll(
	ctx context.Context,
	fn func(ctx context.Context, uri *stun.URI) (addr string, probeErr error),
	relays []*stun.URI,
) []ProbeResult {
	results := make([]ProbeResult, len(relays))

	var wg sync.WaitGroup
	for i, uri := range relays {
		wg.Add(1)
		go func(res *ProbeResult, stunURI *stun.URI) {
			defer wg.Done()

			// Scope the timeout context to this probe so its timer is
			// released as soon as the probe returns, instead of deferring
			// every cancel in the loop until ProbeAll itself returns.
			probeCtx, cancel := context.WithTimeout(ctx, 1*time.Second)
			defer cancel()

			res.URI = stunURI
			res.Addr, res.Err = fn(probeCtx, stunURI)
		}(&results[i], uri)
	}

	wg.Wait()

	return results
}