package internal

import (
	"context"
	"fmt"
	"github.com/golang-jwt/jwt"
	"github.com/stretchr/testify/require"
	"io"
	"net/http"
	"net/url"
	"strings"
	"testing"
	"time"
)

type mockHTTPClient struct {
	code         int
	resBody      string
	reqBody      string
	MaxReqs      int
	count        int
	countResBody string
	err          error
}

func (c *mockHTTPClient) Do(req *http.Request) (*http.Response, error) {
	body, err := io.ReadAll(req.Body)
	if err == nil {
		c.reqBody = string(body)
	}

	if c.MaxReqs > c.count {
		c.count++
		return &http.Response{
			StatusCode: c.code,
			Body:       io.NopCloser(strings.NewReader(c.countResBody)),
		}, c.err
	}

	return &http.Response{
		StatusCode: c.code,
		Body:       io.NopCloser(strings.NewReader(c.resBody)),
	}, c.err
}

func TestHosted_RequestDeviceCode(t *testing.T) {
	type test struct {
		name             string
		inputResBody     string
		inputReqCode     int
		inputReqError    error
		testingErrFunc   require.ErrorAssertionFunc
		expectedErrorMSG string
		testingFunc      require.ComparisonAssertionFunc
		expectedOut      DeviceAuthInfo
		expectedMSG      string
		expectPayload    string
	}

	expectedAudience := "ok"
	expectedClientID := "bla"
	expectedScope := "openid"
	form := url.Values{}
	form.Add("audience", expectedAudience)
	form.Add("client_id", expectedClientID)
	form.Add("scope", expectedScope)
	expectPayload := form.Encode()

	testCase1 := test{
		name:           "Payload Is Valid",
		expectPayload:  expectPayload,
		inputReqCode:   200,
		testingErrFunc: require.Error,
		testingFunc:    require.EqualValues,
	}

	testCase2 := test{
		name:             "Exit On Network Error",
		inputReqError:    fmt.Errorf("error"),
		testingErrFunc:   require.Error,
		expectedErrorMSG: "should return error",
		testingFunc:      require.EqualValues,
		expectPayload:    expectPayload,
	}

	testCase3 := test{
		name:             "Exit On Exit Code",
		inputReqCode:     400,
		testingErrFunc:   require.Error,
		expectedErrorMSG: "should return error",
		testingFunc:      require.EqualValues,
		expectPayload:    expectPayload,
	}
	testCase4Out := DeviceAuthInfo{ExpiresIn: 10}
	testCase4 := test{
		name:           "Got Device Code",
		inputResBody:   fmt.Sprintf("{\"expires_in\":%d}", testCase4Out.ExpiresIn),
		expectPayload:  expectPayload,
		inputReqCode:   200,
		testingErrFunc: require.NoError,
		testingFunc:    require.EqualValues,
		expectedOut:    testCase4Out,
		expectedMSG:    "out should match",
	}

	for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4} {
		t.Run(testCase.name, func(t *testing.T) {

			httpClient := mockHTTPClient{
				resBody: testCase.inputResBody,
				code:    testCase.inputReqCode,
				err:     testCase.inputReqError,
			}

			hosted := Hosted{
				Audience:           expectedAudience,
				ClientID:           expectedClientID,
				Scope:              expectedScope,
				TokenEndpoint:      "test.hosted.com/token",
				DeviceAuthEndpoint: "test.hosted.com/device/auth",
				HTTPClient:         &httpClient,
			}

			authInfo, err := hosted.RequestDeviceCode(context.TODO())
			testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)

			require.EqualValues(t, expectPayload, httpClient.reqBody, "payload should match")

			testCase.testingFunc(t, testCase.expectedOut, authInfo, testCase.expectedMSG)

		})
	}
}

func TestHosted_WaitToken(t *testing.T) {
	type test struct {
		name              string
		inputResBody      string
		inputReqCode      int
		inputReqError     error
		inputMaxReqs      int
		inputCountResBody string
		inputTimeout      time.Duration
		inputInfo         DeviceAuthInfo
		inputAudience     string
		testingErrFunc    require.ErrorAssertionFunc
		expectedErrorMSG  string
		testingFunc       require.ComparisonAssertionFunc
		expectedOut       TokenInfo
		expectedMSG       string
		expectPayload     string
	}

	defaultInfo := DeviceAuthInfo{
		DeviceCode: "test",
		ExpiresIn:  10,
		Interval:   1,
	}

	clientID := "test"

	form := url.Values{}
	form.Add("grant_type", HostedGrantType)
	form.Add("device_code", defaultInfo.DeviceCode)
	form.Add("client_id", clientID)
	tokenReqPayload := form.Encode()

	testCase1 := test{
		name:           "Payload Is Valid",
		inputInfo:      defaultInfo,
		inputTimeout:   time.Duration(defaultInfo.ExpiresIn) * time.Second,
		inputReqCode:   200,
		testingErrFunc: require.Error,
		testingFunc:    require.EqualValues,
		expectPayload:  tokenReqPayload,
	}

	testCase2 := test{
		name:             "Exit On Network Error",
		inputInfo:        defaultInfo,
		inputTimeout:     time.Duration(defaultInfo.ExpiresIn) * time.Second,
		expectPayload:    tokenReqPayload,
		inputReqError:    fmt.Errorf("error"),
		testingErrFunc:   require.Error,
		expectedErrorMSG: "should return error",
		testingFunc:      require.EqualValues,
	}

	testCase3 := test{
		name:             "Exit On 4XX When Not Pending",
		inputInfo:        defaultInfo,
		inputTimeout:     time.Duration(defaultInfo.ExpiresIn) * time.Second,
		inputReqCode:     400,
		expectPayload:    tokenReqPayload,
		testingErrFunc:   require.Error,
		expectedErrorMSG: "should return error",
		testingFunc:      require.EqualValues,
	}

	testCase4 := test{
		name:             "Exit On Exit Code 5XX",
		inputInfo:        defaultInfo,
		inputTimeout:     time.Duration(defaultInfo.ExpiresIn) * time.Second,
		inputReqCode:     500,
		expectPayload:    tokenReqPayload,
		testingErrFunc:   require.Error,
		expectedErrorMSG: "should return error",
		testingFunc:      require.EqualValues,
	}

	testCase5 := test{
		name:             "Exit On Content Timeout",
		inputInfo:        defaultInfo,
		inputTimeout:     0 * time.Second,
		testingErrFunc:   require.Error,
		expectedErrorMSG: "should return error",
		testingFunc:      require.EqualValues,
	}

	audience := "test"

	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{"aud": audience})
	var hmacSampleSecret []byte
	tokenString, _ := token.SignedString(hmacSampleSecret)

	testCase6 := test{
		name:           "Exit On Invalid Audience",
		inputInfo:      defaultInfo,
		inputResBody:   fmt.Sprintf("{\"access_token\":\"%s\"}", tokenString),
		inputTimeout:   time.Duration(defaultInfo.ExpiresIn) * time.Second,
		inputReqCode:   200,
		inputAudience:  "super test",
		testingErrFunc: require.Error,
		testingFunc:    require.EqualValues,
		expectPayload:  tokenReqPayload,
	}

	testCase7 := test{
		name:           "Received Token Info",
		inputInfo:      defaultInfo,
		inputResBody:   fmt.Sprintf("{\"access_token\":\"%s\"}", tokenString),
		inputTimeout:   time.Duration(defaultInfo.ExpiresIn) * time.Second,
		inputReqCode:   200,
		inputAudience:  audience,
		testingErrFunc: require.NoError,
		testingFunc:    require.EqualValues,
		expectPayload:  tokenReqPayload,
		expectedOut:    TokenInfo{AccessToken: tokenString},
	}

	testCase8 := test{
		name:              "Received Token Info after Multiple tries",
		inputInfo:         defaultInfo,
		inputResBody:      fmt.Sprintf("{\"access_token\":\"%s\"}", tokenString),
		inputTimeout:      time.Duration(defaultInfo.ExpiresIn) * time.Second,
		inputMaxReqs:      2,
		inputCountResBody: "{\"error\":\"authorization_pending\"}",
		inputReqCode:      200,
		inputAudience:     audience,
		testingErrFunc:    require.NoError,
		testingFunc:       require.EqualValues,
		expectPayload:     tokenReqPayload,
		expectedOut:       TokenInfo{AccessToken: tokenString},
	}

	for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5, testCase6, testCase7, testCase8} {
		t.Run(testCase.name, func(t *testing.T) {

			httpClient := mockHTTPClient{
				resBody:      testCase.inputResBody,
				code:         testCase.inputReqCode,
				err:          testCase.inputReqError,
				MaxReqs:      testCase.inputMaxReqs,
				countResBody: testCase.inputCountResBody,
			}

			hosted := Hosted{
				Audience:           testCase.inputAudience,
				ClientID:           clientID,
				TokenEndpoint:      "test.hosted.com/token",
				DeviceAuthEndpoint: "test.hosted.com/device/auth",
				HTTPClient:         &httpClient,
			}

			ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout)
			defer cancel()
			tokenInfo, err := hosted.WaitToken(ctx, testCase.inputInfo)
			testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)

			require.EqualValues(t, testCase.expectPayload, httpClient.reqBody, "payload should match")

			testCase.testingFunc(t, testCase.expectedOut, tokenInfo, testCase.expectedMSG)

			require.GreaterOrEqualf(t, testCase.inputMaxReqs, httpClient.count, "should run %d times", testCase.inputMaxReqs)

		})
	}
}