package auth

import (
	"context"
	"crypto/sha256"
	"crypto/subtle"
	"encoding/base64"
	"errors"
	"fmt"
	"html/template"
	"net"
	"net/http"
	"net/url"
	"strings"
	"time"

	log "github.com/sirupsen/logrus"
	"golang.org/x/oauth2"

	"github.com/netbirdio/netbird/client/internal"
	"github.com/netbirdio/netbird/client/internal/templates"
)

var _ OAuthFlow = &PKCEAuthorizationFlow{}

const (
	queryState                = "state"
	queryCode                 = "code"
	queryError                = "error"
	queryErrorDesc            = "error_description"
	defaultPKCETimeoutSeconds = 300
)

// PKCEAuthorizationFlow implements the OAuthFlow interface for
// the Authorization Code Flow with PKCE.
type PKCEAuthorizationFlow struct {
	providerConfig internal.PKCEAuthProviderConfig
	state          string
	codeVerifier   string
	oAuthConfig    *oauth2.Config
}

// NewPKCEAuthorizationFlow returns new PKCE authorization code flow.
func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
	var availableRedirectURL string

	// find the first available redirect URL
	for _, redirectURL := range config.RedirectURLs {
		if !isRedirectURLPortUsed(redirectURL) {
			availableRedirectURL = redirectURL
			break
		}
	}

	if availableRedirectURL == "" {
		return nil, fmt.Errorf("no available port found from configured redirect URLs: %q", config.RedirectURLs)
	}

	cfg := &oauth2.Config{
		ClientID:     config.ClientID,
		ClientSecret: config.ClientSecret,
		Endpoint: oauth2.Endpoint{
			AuthURL:  config.AuthorizationEndpoint,
			TokenURL: config.TokenEndpoint,
		},
		RedirectURL: availableRedirectURL,
		Scopes:      strings.Split(config.Scope, " "),
	}

	return &PKCEAuthorizationFlow{
		providerConfig: config,
		oAuthConfig:    cfg,
	}, nil
}

// GetClientID returns the provider client id
func (p *PKCEAuthorizationFlow) GetClientID(_ context.Context) string {
	return p.providerConfig.ClientID
}

// RequestAuthInfo requests a authorization code login flow information.
func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) {
	state, err := randomBytesInHex(24)
	if err != nil {
		return AuthFlowInfo{}, fmt.Errorf("could not generate random state: %v", err)
	}
	p.state = state

	codeVerifier, err := randomBytesInHex(64)
	if err != nil {
		return AuthFlowInfo{}, fmt.Errorf("could not create a code verifier: %v", err)
	}
	p.codeVerifier = codeVerifier

	codeChallenge := createCodeChallenge(codeVerifier)
	authURL := p.oAuthConfig.AuthCodeURL(
		state,
		oauth2.SetAuthURLParam("code_challenge_method", "S256"),
		oauth2.SetAuthURLParam("code_challenge", codeChallenge),
		oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
	)

	return AuthFlowInfo{
		VerificationURIComplete: authURL,
		ExpiresIn:               defaultPKCETimeoutSeconds,
	}, nil
}

// WaitToken waits for the OAuth token in the PKCE Authorization Flow.
// It starts an HTTP server to receive the OAuth token callback and waits for the token or an error.
// Once the token is received, it is converted to TokenInfo and validated before returning.
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (TokenInfo, error) {
	tokenChan := make(chan *oauth2.Token, 1)
	errChan := make(chan error, 1)

	parsedURL, err := url.Parse(p.oAuthConfig.RedirectURL)
	if err != nil {
		return TokenInfo{}, fmt.Errorf("failed to parse redirect URL: %v", err)
	}

	server := &http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())}
	defer func() {
		shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
		defer cancel()

		if err := server.Shutdown(shutdownCtx); err != nil {
			log.Errorf("failed to close the server: %v", err)
		}
	}()

	go p.startServer(server, tokenChan, errChan)

	select {
	case <-ctx.Done():
		return TokenInfo{}, ctx.Err()
	case token := <-tokenChan:
		return p.parseOAuthToken(token)
	case err := <-errChan:
		return TokenInfo{}, err
	}
}

func (p *PKCEAuthorizationFlow) startServer(server *http.Server, tokenChan chan<- *oauth2.Token, errChan chan<- error) {
	mux := http.NewServeMux()
	mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
		token, err := p.handleRequest(req)
		if err != nil {
			renderPKCEFlowTmpl(w, err)
			errChan <- fmt.Errorf("PKCE authorization flow failed: %v", err)
			return
		}

		renderPKCEFlowTmpl(w, nil)
		tokenChan <- token
	})

	server.Handler = mux
	if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
		errChan <- err
	}
}

func (p *PKCEAuthorizationFlow) handleRequest(req *http.Request) (*oauth2.Token, error) {
	query := req.URL.Query()

	if authError := query.Get(queryError); authError != "" {
		authErrorDesc := query.Get(queryErrorDesc)
		return nil, fmt.Errorf("%s.%s", authError, authErrorDesc)
	}

	// Prevent timing attacks on the state
	if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 {
		return nil, fmt.Errorf("invalid state")
	}

	code := query.Get(queryCode)
	if code == "" {
		return nil, fmt.Errorf("missing code")
	}

	return p.oAuthConfig.Exchange(
		req.Context(),
		code,
		oauth2.SetAuthURLParam("code_verifier", p.codeVerifier),
	)
}

func (p *PKCEAuthorizationFlow) parseOAuthToken(token *oauth2.Token) (TokenInfo, error) {
	tokenInfo := TokenInfo{
		AccessToken:  token.AccessToken,
		RefreshToken: token.RefreshToken,
		TokenType:    token.TokenType,
		ExpiresIn:    token.Expiry.Second(),
		UseIDToken:   p.providerConfig.UseIDToken,
	}
	if idToken, ok := token.Extra("id_token").(string); ok {
		tokenInfo.IDToken = idToken
	}

	// if a provider doesn't support an audience, use the Client ID for token verification
	audience := p.providerConfig.Audience
	if audience == "" {
		audience = p.providerConfig.ClientID
	}

	if err := isValidAccessToken(tokenInfo.GetTokenToUse(), audience); err != nil {
		return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
	}

	return tokenInfo, nil
}

func createCodeChallenge(codeVerifier string) string {
	sha2 := sha256.Sum256([]byte(codeVerifier))
	return base64.RawURLEncoding.EncodeToString(sha2[:])
}

// isRedirectURLPortUsed checks if the port used in the redirect URL is in use.
func isRedirectURLPortUsed(redirectURL string) bool {
	parsedURL, err := url.Parse(redirectURL)
	if err != nil {
		log.Errorf("failed to parse redirect URL: %v", err)
		return true
	}

	addr := fmt.Sprintf(":%s", parsedURL.Port())
	conn, err := net.DialTimeout("tcp", addr, 3*time.Second)
	if err != nil {
		return false
	}
	defer func() {
		if err := conn.Close(); err != nil {
			log.Errorf("error while closing the connection: %v", err)
		}
	}()

	return true
}

func renderPKCEFlowTmpl(w http.ResponseWriter, authError error) {
	tmpl, err := template.New("pkce-auth-flow").Parse(templates.PKCEAuthMsgTmpl)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	data := make(map[string]string)
	if authError != nil {
		data["Error"] = authError.Error()
	}

	if err := tmpl.Execute(w, data); err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
	}
}