package client

import (
	"context"
	sigProto "github.com/netbirdio/netbird/signal/proto"
	"github.com/netbirdio/netbird/signal/server"
	. "github.com/onsi/ginkgo"
	. "github.com/onsi/gomega"
	log "github.com/sirupsen/logrus"
	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/keepalive"
	"google.golang.org/grpc/metadata"
	"net"
	"sync"
	"time"
)

// Specs for GrpcClient. Every spec runs against a fresh in-process Signal
// server that startSignal binds to an OS-assigned port; the server is stopped
// again once the spec has finished.
var _ = Describe("GrpcClient", func() {

	var (
		// addr is the address of the listener returned by startSignal.
		addr     string
		listener net.Listener
		// Named grpcServer (not server) so the imported server package is not shadowed.
		grpcServer *grpc.Server
	)

	BeforeEach(func() {
		grpcServer, listener = startSignal()
		addr = listener.Addr().String()
	})

	AfterEach(func() {
		grpcServer.Stop()
		// Best-effort cleanup: the error is ignored on purpose. Stop is documented
		// to close the server's listeners, so this may just report an already
		// closed listener.
		_ = listener.Close()
	})

	Describe("Exchanging messages", func() {
		Context("between connected peers", func() {
			It("should be successful", func() {

				// One Done per peer: PeerB after it has answered the ping,
				// PeerA after it has received the pong.
				var msgReceived sync.WaitGroup
				msgReceived.Add(2)

				var receivedOnA string
				var receivedOnB string

				// Ginkgo's Fail must not be called from the receive goroutines,
				// so a failed reply from PeerB is handed over to the spec
				// goroutine through this channel and reported there.
				replyFailed := make(chan error, 1)

				// connect PeerA to Signal
				keyA, err := wgtypes.GenerateKey()
				Expect(err).To(BeNil())
				clientA := createSignalClient(addr, keyA)
				go func() {
					// The result of Receive is intentionally dropped; the spec
					// only cares about the messages handed to the callback.
					_ = clientA.Receive(func(msg *sigProto.Message) error {
						receivedOnA = msg.GetBody().GetPayload()
						msgReceived.Done()
						return nil
					})
				}()
				clientA.WaitStreamConnected()

				// connect PeerB to Signal
				keyB, err := wgtypes.GenerateKey()
				Expect(err).To(BeNil())
				clientB := createSignalClient(addr, keyB)

				go func() {
					_ = clientB.Receive(func(msg *sigProto.Message) error {
						receivedOnB = msg.GetBody().GetPayload()
						// PeerB answers every received message with a "pong" for PeerA.
						sendErr := clientB.Send(&sigProto.Message{
							Key:       keyB.PublicKey().String(),
							RemoteKey: keyA.PublicKey().String(),
							Body:      &sigProto.Body{Payload: "pong"},
						})
						if sendErr != nil {
							// Non-blocking hand-over: only the first failure is kept.
							select {
							case replyFailed <- sendErr:
							default:
							}
							return sendErr
						}
						msgReceived.Done()
						return nil
					})
				}()

				clientB.WaitStreamConnected()

				// PeerA initiates ping-pong
				err = clientA.Send(&sigProto.Message{
					Key:       keyA.PublicKey().String(),
					RemoteKey: keyB.PublicKey().String(),
					Body:      &sigProto.Body{Payload: "ping"},
				})
				if err != nil {
					Fail("failed sending a message to PeerB")
				}

				if waitTimeout(&msgReceived, 3*time.Second) {
					// Distinguish a failed reply from a plain timeout so the
					// reported reason matches what actually went wrong.
					select {
					case <-replyFailed:
						Fail("failed sending a message to PeerA")
					default:
						Fail("test timed out on waiting for peers to exchange messages")
					}
				}

				// Both writes happened before the corresponding Done calls, so
				// reading the payloads after a successful wait is race-free.
				Expect(receivedOnA).To(BeEquivalentTo("pong"))
				Expect(receivedOnB).To(BeEquivalentTo("ping"))

			})
		})
	})

	Describe("Connecting to the Signal stream channel", func() {
		Context("with a signal client", func() {
			It("should be successful", func() {

				key, err := wgtypes.GenerateKey()
				Expect(err).To(BeNil())
				client := createSignalClient(addr, key)
				go func() {
					// Messages are irrelevant here; the spec only waits for the
					// stream to be reported as connected.
					_ = client.Receive(func(msg *sigProto.Message) error {
						return nil
					})
				}()
				client.WaitStreamConnected()
				Expect(client).NotTo(BeNil())
			})
		})

		Context("with a raw client and no Id header", func() {
			It("should fail", func() {

				client := createRawSignalClient(addr)
				stream, err := client.ConnectStream(context.Background())
				if err != nil {
					Fail("error connecting to stream")
				}

				// The spec expects opening the stream to succeed and the missing
				// header to surface as an error on the first Recv.
				_, err = stream.Recv()

				Expect(stream).NotTo(BeNil())
				Expect(err).NotTo(BeNil())
			})
		})

		Context("with a raw client and with an Id header", func() {
			It("should be successful", func() {

				// Attach the Id header as outgoing gRPC metadata.
				md := metadata.New(map[string]string{sigProto.HeaderId: "peer"})
				ctx := metadata.NewOutgoingContext(context.Background(), md)

				client := createRawSignalClient(addr)
				stream, err := client.ConnectStream(ctx)

				Expect(stream).NotTo(BeNil())
				Expect(err).To(BeNil())
			})
		})

	})

})

// createSignalClient builds a GrpcClient for the Signal server at addr using
// the given key, with TLS disabled. If NewClient reports an error the spec is
// failed through Ginkgo's Fail.
func createSignalClient(addr string, key wgtypes.Key) *GrpcClient {
	// The test server is started without TLS, so the client must not use it either.
	tlsEnabled := false

	signalClient, err := NewClient(context.Background(), addr, key, tlsEnabled)
	if err == nil {
		return signalClient
	}

	Fail("failed creating signal client")
	return signalClient
}

// createRawSignalClient dials the Signal server at addr without TLS and returns
// the generated SignalExchangeClient on top of that connection, bypassing
// GrpcClient. The spec is failed through Ginkgo's Fail if dialing fails.
//
// The dial is blocking (grpc.WithBlock) and therefore bounded by a timeout so
// that an unreachable server fails the spec instead of hanging the suite.
// The underlying connection is not closed by this helper.
func createRawSignalClient(addr string) sigProto.SignalExchangeClient {
	// Upper bound for establishing the connection to the local test server.
	const dialTimeout = 5 * time.Second

	// The context only governs the blocking dial; cancelling it after
	// DialContext has returned does not affect the established connection.
	ctx, cancel := context.WithTimeout(context.Background(), dialTimeout)
	defer cancel()

	conn, err := grpc.DialContext(ctx, addr,
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithBlock(),
		grpc.WithKeepaliveParams(keepalive.ClientParameters{
			Time:    3 * time.Second,
			Timeout: 2 * time.Second,
		}))
	if err != nil {
		Fail("failed creating raw signal client")
	}

	return sigProto.NewSignalExchangeClient(conn)
}

// startSignal starts a Signal gRPC server on an OS-assigned TCP port and
// returns the server together with its listener. It panics if the port cannot
// be opened. Serving happens on a background goroutine; a Serve error
// terminates the process through log.Fatalf.
func startSignal() (*grpc.Server, net.Listener) {
	// Port 0 lets the operating system pick a free port.
	ln, listenErr := net.Listen("tcp", ":0")
	if listenErr != nil {
		panic(listenErr)
	}

	grpcSrv := grpc.NewServer()
	sigProto.RegisterSignalExchangeServer(grpcSrv, server.NewServer())

	go func() {
		serveErr := grpcSrv.Serve(ln)
		if serveErr == nil {
			return
		}
		log.Fatalf("failed to serve: %v", serveErr)
	}()

	return grpcSrv, ln
}

// waitTimeout waits for the wait group for at most the given timeout.
// It returns false if the wait group completed in time and true if the
// timeout elapsed first.
//
// On timeout the helper goroutine stays blocked in wg.Wait until the wait
// group eventually completes.
func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool {
	done := make(chan struct{})
	go func() {
		defer close(done)
		wg.Wait()
	}()

	// An explicit timer (instead of time.After) can be stopped as soon as the
	// wait group completes, so it does not stay pending for the full timeout.
	timer := time.NewTimer(timeout)
	defer timer.Stop()

	select {
	case <-done:
		return false // completed normally
	case <-timer.C:
		return true // timed out
	}
}