package main

import (
	"fmt"
	"io"
	"os"
	"time"

	// Package imports
	whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
	wav "github.com/go-audio/wav"
)

// Process transcribes the WAV file at path with a fresh context created from
// model and configured through flags.SetParams.
//
// The WAV data must have exactly one channel and a sample rate equal to
// whisper.SampleRate; otherwise an error is returned. Progress messages and
// the context's system information are written to flags.Output(), so that
// stdout carries only the transcript. When flags.IsTokens() is set, each
// segment is printed token-by-token to flags.Output() as it is produced
// (colorized for text tokens when flags.IsColorize() is set).
//
// After processing, the transcript is written to os.Stdout in the format
// selected by flags.GetOut(): "srt" for SubRip, "none" for no output, and
// anything else for the plain (or colorized) terminal format.
func Process(model whisper.Model, path string, flags *Flags) error {
	// Create processing context
	context, err := model.NewContext()
	if err != nil {
		return err
	}

	// Set the parameters
	if err := flags.SetParams(context); err != nil {
		return err
	}

	// Diagnostics go to flags.Output() like the other progress messages;
	// printing them to stdout would corrupt a redirected transcript (e.g. SRT).
	fmt.Fprintf(flags.Output(), "\n%s\n", context.SystemInfo())

	// Open the file
	fmt.Fprintf(flags.Output(), "Loading %q\n", path)
	fh, err := os.Open(path)
	if err != nil {
		return err
	}
	defer fh.Close()

	// Decode the WAV file - load the full buffer, then validate the format
	// reported by the decoder.
	dec := wav.NewDecoder(fh)
	buf, err := dec.FullPCMBuffer()
	if err != nil {
		return fmt.Errorf("decoding %q: %w", path, err)
	}
	if dec.SampleRate != whisper.SampleRate {
		return fmt.Errorf("unsupported sample rate: %d", dec.SampleRate)
	}
	if dec.NumChans != 1 {
		return fmt.Errorf("unsupported number of channels: %d", dec.NumChans)
	}
	data := buf.AsFloat32Buffer().Data

	// Segment callback when -tokens is specified; nil otherwise.
	var cb whisper.SegmentCallback
	if flags.IsTokens() {
		cb = func(segment whisper.Segment) {
			fmt.Fprintf(flags.Output(), "%02d [%6s->%6s] ", segment.Num, segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond))
			for _, token := range segment.Tokens {
				if flags.IsColorize() && context.IsText(token) {
					// token.P is scaled by 24 to pick a colour bucket in Colorize.
					fmt.Fprint(flags.Output(), Colorize(token.Text, int(token.P*24.0)), " ")
				} else {
					fmt.Fprint(flags.Output(), token.Text, " ")
				}
			}
			fmt.Fprintln(flags.Output(), "")
			fmt.Fprintln(flags.Output(), "")
		}
	}

	// Process the data
	fmt.Fprintf(flags.Output(), "  ...processing %q\n", path)
	context.ResetTimings()
	if err := context.Process(data, cb, nil); err != nil {
		return err
	}

	context.PrintTimings()

	// Print out the results
	switch flags.GetOut() {
	case "srt":
		return OutputSRT(os.Stdout, context)
	case "none":
		return nil
	default:
		return Output(os.Stdout, context, flags.IsColorize())
	}
}

// Output text as SRT file
// OutputSRT writes every remaining segment of context to w in SubRip (SRT)
// format. Each cue is a 1-based sequence number, a "start --> end" timing
// line, the segment text, and a blank separator line.
//
// It returns nil once context.NextSegment reports io.EOF. It returns the
// first other error from NextSegment, or the first error from writing to w.
func OutputSRT(w io.Writer, context whisper.Context) error {
	for n := 1; ; n++ {
		segment, err := context.NextSegment()
		if err == io.EOF {
			return nil
		} else if err != nil {
			return err
		}
		// Fprintln always inserts spaces between operands, which would put two
		// spaces on each side of "-->". SRT expects exactly one, so format the
		// whole cue explicitly.
		if _, err := fmt.Fprintf(w, "%d\n%s --> %s\n%s\n\n",
			n, srtTimestamp(segment.Start), srtTimestamp(segment.End), segment.Text); err != nil {
			return err
		}
	}
}

// Output text to terminal
// Output writes each remaining segment of context to w, one per line. Each
// line is prefixed with "[start->end]", with both times truncated to whole
// milliseconds.
//
// When colorize is false, the segment's Text is printed after the prefix.
// When colorize is true, only tokens for which context.IsText reports true
// are printed, each preceded by a space and passed through Colorize with
// token.P multiplied by 24.
//
// It returns nil at io.EOF, or the first other error from NextSegment.
func Output(w io.Writer, context whisper.Context, colorize bool) error {
	for {
		seg, err := context.NextSegment()
		switch {
		case err == io.EOF:
			return nil
		case err != nil:
			return err
		}

		start := seg.Start.Truncate(time.Millisecond)
		end := seg.End.Truncate(time.Millisecond)
		fmt.Fprintf(w, "[%6s->%6s]", start, end)

		if !colorize {
			fmt.Fprintln(w, " ", seg.Text)
			continue
		}

		for _, tok := range seg.Tokens {
			if context.IsText(tok) {
				fmt.Fprint(w, " ", Colorize(tok.Text, int(tok.P*24.0)))
			}
		}
		fmt.Fprint(w, "\n")
	}
}

// Return srtTimestamp
// srtTimestamp formats t as an SRT timestamp "HH:MM:SS,mmm". Hours are not
// capped at 24 (they widen past two digits if needed), and any precision
// finer than one millisecond is truncated.
func srtTimestamp(t time.Duration) string {
	hours := t / time.Hour
	minutes := (t % time.Hour) / time.Minute
	seconds := (t % time.Minute) / time.Second
	millis := (t % time.Second) / time.Millisecond
	return fmt.Sprintf("%02d:%02d:%02d,%03d", hours, minutes, seconds, millis)
}