ref #17 : add options to output result to file

Support for:

- plain text
- VTT
- SRT
This commit is contained in:
Georgi Gerganov 2022-10-08 17:22:22 +03:00
parent 4c4ab71d4d
commit 8c7c018893
2 changed files with 92 additions and 8 deletions

View File

@ -5,6 +5,7 @@
#define DR_WAV_IMPLEMENTATION #define DR_WAV_IMPLEMENTATION
#include "dr_wav.h" #include "dr_wav.h"
#include <fstream>
#include <cstdio> #include <cstdio>
#include <string> #include <string>
#include <thread> #include <thread>
@ -32,6 +33,9 @@ struct whisper_params {
bool verbose = false; bool verbose = false;
bool translate = false; bool translate = false;
bool output_txt = false;
bool output_vtt = false;
bool output_srt = false;
bool print_special_tokens = false; bool print_special_tokens = false;
bool no_timestamps = false; bool no_timestamps = false;
@ -69,6 +73,12 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
whisper_print_usage(argc, argv, params); whisper_print_usage(argc, argv, params);
exit(0); exit(0);
} }
} else if (arg == "-otxt" || arg == "--output-txt") {
params.output_txt = true;
} else if (arg == "-ovtt" || arg == "--output-vtt") {
params.output_vtt = true;
} else if (arg == "-osrt" || arg == "--output-srt") {
params.output_srt = true;
} else if (arg == "-ps" || arg == "--print_special") { } else if (arg == "-ps" || arg == "--print_special") {
params.print_special_tokens = true; params.print_special_tokens = true;
} else if (arg == "-nt" || arg == "--no_timestamps") { } else if (arg == "-nt" || arg == "--no_timestamps") {
@ -101,6 +111,8 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
fprintf(stderr, " -o N, --offset N offset in milliseconds (default: %d)\n", params.offset_ms); fprintf(stderr, " -o N, --offset N offset in milliseconds (default: %d)\n", params.offset_ms);
fprintf(stderr, " -v, --verbose verbose output\n"); fprintf(stderr, " -v, --verbose verbose output\n");
fprintf(stderr, " --translate translate from source language to english\n"); fprintf(stderr, " --translate translate from source language to english\n");
fprintf(stderr, " -otxt, --output-txt output result in a text file\n");
fprintf(stderr, " -ovtt, --output-vtt output result in a vtt file\n");
fprintf(stderr, " -ps, --print_special print special tokens\n"); fprintf(stderr, " -ps, --print_special print special tokens\n");
fprintf(stderr, " -nt, --no_timestamps do not print timestamps\n"); fprintf(stderr, " -nt, --no_timestamps do not print timestamps\n");
fprintf(stderr, " -l LANG, --language LANG spoken language (default: %s)\n", params.language.c_str()); fprintf(stderr, " -l LANG, --language LANG spoken language (default: %s)\n", params.language.c_str());
@ -123,7 +135,7 @@ int main(int argc, char ** argv) {
if (params.fname_inp.empty()) { if (params.fname_inp.empty()) {
fprintf(stderr, "error: no input files specified\n"); fprintf(stderr, "error: no input files specified\n");
whisper_print_usage(argc, argv, params); whisper_print_usage(argc, argv, params);
return 1; return 2;
} }
// whisper init // whisper init
@ -140,22 +152,22 @@ int main(int argc, char ** argv) {
if (!drwav_init_file(&wav, fname_inp.c_str(), NULL)) { if (!drwav_init_file(&wav, fname_inp.c_str(), NULL)) {
fprintf(stderr, "%s: failed to open WAV file '%s' - check your input\n", argv[0], fname_inp.c_str()); fprintf(stderr, "%s: failed to open WAV file '%s' - check your input\n", argv[0], fname_inp.c_str());
whisper_print_usage(argc, argv, {}); whisper_print_usage(argc, argv, {});
return 2; return 3;
} }
if (wav.channels != 1 && wav.channels != 2) { if (wav.channels != 1 && wav.channels != 2) {
fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], fname_inp.c_str()); fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], fname_inp.c_str());
return 3; return 4;
} }
if (wav.sampleRate != WHISPER_SAMPLE_RATE) { if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], fname_inp.c_str()); fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], fname_inp.c_str());
return 4; return 5;
} }
if (wav.bitsPerSample != 16) { if (wav.bitsPerSample != 16) {
fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", argv[0], fname_inp.c_str()); fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", argv[0], fname_inp.c_str());
return 5; return 6;
} }
int n = wav.totalPCMFrameCount; int n = wav.totalPCMFrameCount;
@ -193,9 +205,11 @@ int main(int argc, char ** argv) {
params.language.c_str(), params.language.c_str(),
params.translate ? "translate" : "transcribe", params.translate ? "translate" : "transcribe",
params.no_timestamps ? 0 : 1); params.no_timestamps ? 0 : 1);
printf("\n"); printf("\n");
} }
// run the inference // run the inference
{ {
whisper_full_params wparams = whisper_full_default_params(WHISPER_DECODE_GREEDY); whisper_full_params wparams = whisper_full_default_params(WHISPER_DECODE_GREEDY);
@ -211,10 +225,10 @@ int main(int argc, char ** argv) {
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
fprintf(stderr, "%s: failed to process audio\n", argv[0]); fprintf(stderr, "%s: failed to process audio\n", argv[0]);
return 6; return 7;
} }
// print result; // print result
if (!wparams.print_realtime) { if (!wparams.print_realtime) {
printf("\n"); printf("\n");
@ -233,6 +247,76 @@ int main(int argc, char ** argv) {
} }
} }
} }
printf("\n");
// output to text file
if (params.output_txt) {
const auto fname_txt = fname_inp + ".txt";
std::ofstream fout_txt(fname_txt);
if (!fout_txt.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_txt.c_str());
return 8;
}
printf("%s: saving output to '%s.txt'\n", __func__, fname_inp.c_str());
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
fout_txt << text;
}
}
// output to VTT file
if (params.output_vtt) {
const auto fname_vtt = fname_inp + ".vtt";
std::ofstream fout_vtt(fname_vtt);
if (!fout_vtt.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_vtt.c_str());
return 9;
}
printf("%s: saving output to '%s.vtt'\n", __func__, fname_inp.c_str());
fout_vtt << "WEBVTT\n\n";
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
fout_vtt << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n";
fout_vtt << text << "\n\n";
}
}
// output to SRT file
if (params.output_srt) {
const auto fname_srt = fname_inp + ".srt";
std::ofstream fout_srt(fname_srt);
if (!fout_srt.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_srt.c_str());
return 10;
}
printf("%s: saving output to '%s.srt'\n", __func__, fname_inp.c_str());
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
fout_srt << i + 1 << "\n";
fout_srt << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n";
fout_srt << text << "\n\n";
}
}
} }
} }

View File

@ -2242,7 +2242,7 @@ whisper_token whisper_token_transcribe() {
void whisper_print_timings(struct whisper_context * ctx) { void whisper_print_timings(struct whisper_context * ctx) {
const int64_t t_end_us = ggml_time_us(); const int64_t t_end_us = ggml_time_us();
printf("\n\n"); printf("\n");
printf("%s: load time = %8.2f ms\n", __func__, ctx->t_load_us/1000.0f); printf("%s: load time = %8.2f ms\n", __func__, ctx->t_load_us/1000.0f);
printf("%s: mel time = %8.2f ms\n", __func__, ctx->t_mel_us/1000.0f); printf("%s: mel time = %8.2f ms\n", __func__, ctx->t_mel_us/1000.0f);
printf("%s: sample time = %8.2f ms\n", __func__, ctx->t_sample_us/1000.0f); printf("%s: sample time = %8.2f ms\n", __func__, ctx->t_sample_us/1000.0f);