require_relative "helper"

# Tests for the callback hooks exposed by Whisper::Params, driven by a real
# transcription of the AUDIO sample with the "base.en" model:
# - the raw new-segment / progress / abort callbacks and their
#   *_user_data companions;
# - the block-based helpers on_progress, on_new_segment and abort_on.
class TestCallback < TestBase
  def setup
    # Force a GC before every test. This presumably helps surface missing GC
    # marks in the native bindings early; confirm.
    GC.start
    @params = Whisper::Params.new
    @whisper = Whisper::Context.new("base.en")
    @audio = File.join(AUDIO)
  end

  # The new-segment callback receives the context and the number of segments
  # produced by the latest decoding step. Those are always the LAST +n_new+
  # segments, i.e. indices (n_segments - n_new)...n_segments.
  def test_new_segment_callback
    @params.new_segment_callback = ->(context, state, n_new, user_data) {
      assert_kind_of Integer, n_new
      assert n_new > 0
      assert_same @whisper, context

      n_segments = context.full_n_segments
      first_new = n_segments - n_new
      n_new.times do |i|
        i_segment = first_new + i
        start_time = context.full_get_segment_t0(i_segment) * 10
        end_time = context.full_get_segment_t1(i_segment) * 10
        text = context.full_get_segment_text(i_segment)

        assert_kind_of Integer, start_time
        assert start_time >= 0
        assert_kind_of Integer, end_time
        assert end_time > 0
        # Parenthesised so the regex literal is not parsed ambiguously.
        assert_match(/ask not what your country can do for you, ask what you can do for your country/, text) if i_segment.zero?
      end
    }

    @whisper.transcribe(@audio, @params)
  end

  # A lambda closing over a local can raise out of the callback. The
  # exception must propagate out of #transcribe.
  def test_new_segment_callback_closure
    search_word = "what"
    @params.new_segment_callback = ->(context, state, n_new, user_data) {
      n_segments = context.full_n_segments
      first_new = n_segments - n_new
      n_new.times do |i|
        i_segment = first_new + i
        text = context.full_get_segment_text(i_segment)
        next unless text.include?(search_word)

        t0 = context.full_get_segment_t0(i_segment)
        t1 = context.full_get_segment_t1(i_segment)
        raise "search word '#{search_word}' found at between #{t0} and #{t1}"
      end
    }

    assert_raise RuntimeError do
      @whisper.transcribe(@audio, @params)
    end
  end

  # new_segment_callback_user_data is handed back to the callback unchanged
  # (the same object, not a copy).
  def test_new_segment_callback_user_data
    udata = Object.new
    @params.new_segment_callback_user_data = udata
    @params.new_segment_callback = ->(context, state, n_new, user_data) {
      assert_same udata, user_data
    }

    @whisper.transcribe(@audio, @params)
  end

  # user_data that is referenced only by the params object must survive a GC
  # run between assignment and transcription.
  def test_new_segment_callback_user_data_gc
    @params.new_segment_callback_user_data = "My user data"
    @params.new_segment_callback = ->(context, state, n_new, user_data) {
      assert_equal "My user data", user_data
    }
    GC.start

    assert_same @whisper, @whisper.transcribe(@audio, @params)
  end

  # Progress is reported as an Integer percentage. It starts at 0 and ends
  # at 100.
  def test_progress_callback
    first = nil
    last = nil
    @params.progress_callback = ->(context, state, progress, user_data) {
      assert_kind_of Integer, progress
      assert 0 <= progress && progress <= 100
      assert_same @whisper, context
      first = progress if first.nil?
      last = progress
    }
    @whisper.transcribe(@audio, @params)
    assert_equal 0, first
    assert_equal 100, last
  end

  # progress_callback_user_data is handed back to the callback unchanged.
  def test_progress_callback_user_data
    udata = Object.new
    @params.progress_callback_user_data = udata
    @params.progress_callback = ->(context, state, progress, user_data) {
      assert_same udata, user_data
    }

    @whisper.transcribe(@audio, @params)
  end

  # The block form on_progress yields only the progress value. It has the
  # same 0..100 contract as progress_callback.
  def test_on_progress
    first = nil
    last = nil
    @params.on_progress do |progress|
      assert_kind_of Integer, progress
      assert 0 <= progress && progress <= 100
      first = progress if first.nil?
      last = progress
    end
    @whisper.transcribe(@audio, @params)
    assert_equal 0, first
    assert_equal 100, last
  end

  # The abort callback is invoked during transcription. Returning false lets
  # transcription continue. Without abort_callback_user_data it receives nil.
  def test_abort_callback
    i = 0
    @params.abort_callback = ->(user_data) {
      assert_nil user_data
      i += 1
      false
    }
    @whisper.transcribe(@audio, @params)
    assert i > 0
  end

  # Returning true stops transcription. The callback is not invoked again
  # after it first returns true (on the third call).
  def test_abort_callback_abort
    i = 0
    @params.abort_callback = ->(user_data) {
      i += 1
      i == 3
    }
    @whisper.transcribe(@audio, @params)
    assert_equal 3, i
  end

  # abort_callback_user_data is handed back to the abort callback unchanged.
  def test_abort_callback_user_data
    udata = Object.new
    @params.abort_callback_user_data = udata
    yielded = nil
    @params.abort_callback = ->(user_data) {
      yielded = user_data
      # Return false explicitly. Otherwise the (truthy) user_data would become
      # the return value and abort transcription as a side effect.
      false
    }
    @whisper.transcribe(@audio, @params)
    assert_same udata, yielded
  end

  # The block form abort_on can be combined with on_new_segment. Here, once a
  # segment mentioning "ask" has been seen, abort_on starts returning true.
  def test_abort_on
    do_abort = false
    @params.on_new_segment do |segment|
      do_abort = true if segment.text.match?(/ask/)
    end
    i = 0
    @params.abort_on do
      i += 1
      do_abort
    end
    @whisper.transcribe(@audio, @params)
    assert i > 0
  end
end