mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2024-11-07 16:44:13 +01:00
examples : small code cleanups (#322)
- remove unnecessary initialization of string to "" - use empty() instead of checking size() - use emplace_back instead of push_back - use nullptr instead of NULL - remove unnecessary call to .data() on string - use character overload of find_first_of() instead of passing a string
This commit is contained in:
parent
7282e2109e
commit
dc90efd504
@ -41,8 +41,8 @@ struct whisper_params {
|
||||
|
||||
std::string language = "en";
|
||||
std::string model = "models/ggml-base.en.bin";
|
||||
std::string fname_out = "";
|
||||
std::string commands = "";
|
||||
std::string fname_out;
|
||||
std::string commands;
|
||||
};
|
||||
|
||||
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
||||
@ -576,10 +576,10 @@ int main(int argc, char ** argv) {
|
||||
std::vector<std::string> allowed_commands;
|
||||
std::vector<std::vector<whisper_token>> allowed_tokens;
|
||||
|
||||
std::string k_prompt = "";
|
||||
std::string k_prompt;
|
||||
std::vector<whisper_token> k_tokens;
|
||||
|
||||
if (params.commands != "") {
|
||||
if (!params.commands.empty()) {
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s: guided mode\n", __func__);
|
||||
|
||||
@ -808,7 +808,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
double psum = 0.0;
|
||||
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
||||
probs_id.push_back(std::make_pair(probs[allowed_tokens[i][0]], i));
|
||||
probs_id.emplace_back(probs[allowed_tokens[i][0]], i);
|
||||
for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) {
|
||||
probs_id.back().first += probs[allowed_tokens[i][j]];
|
||||
}
|
||||
|
@ -75,7 +75,7 @@ struct whisper_params {
|
||||
bool no_timestamps = false;
|
||||
|
||||
std::string language = "en";
|
||||
std::string prompt = "";
|
||||
std::string prompt;
|
||||
std::string model = "models/ggml-base.en.bin";
|
||||
|
||||
std::vector<std::string> fname_inp = {};
|
||||
@ -118,7 +118,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
|
||||
else if ( arg == "--prompt") { params.prompt = argv[++i]; }
|
||||
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
|
||||
else if (arg == "-f" || arg == "--file") { params.fname_inp.push_back(argv[++i]); }
|
||||
else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); }
|
||||
else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
@ -206,7 +206,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
|
||||
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
||||
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
||||
|
||||
std::string speaker = "";
|
||||
std::string speaker;
|
||||
|
||||
if (params.diarize && pcmf32s.size() == 2) {
|
||||
const int64_t n_samples = pcmf32s[0].size();
|
||||
@ -468,7 +468,7 @@ int main(int argc, char ** argv) {
|
||||
// initial prompt
|
||||
std::vector<whisper_token> prompt_tokens;
|
||||
|
||||
if (params.prompt.size() > 0) {
|
||||
if (!params.prompt.empty()) {
|
||||
prompt_tokens.resize(1024);
|
||||
prompt_tokens.resize(whisper_tokenize(ctx, params.prompt.c_str(), prompt_tokens.data(), prompt_tokens.size()));
|
||||
|
||||
@ -505,14 +505,14 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
}
|
||||
|
||||
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), NULL) == false) {
|
||||
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
|
||||
fprintf(stderr, "error: failed to open WAV file from stdin\n");
|
||||
return 4;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
|
||||
}
|
||||
else if (drwav_init_file(&wav, fname_inp.c_str(), NULL) == false) {
|
||||
else if (drwav_init_file(&wav, fname_inp.c_str(), nullptr) == false) {
|
||||
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
|
||||
return 5;
|
||||
}
|
||||
@ -617,8 +617,8 @@ int main(int argc, char ** argv) {
|
||||
|
||||
wparams.speed_up = params.speed_up;
|
||||
|
||||
wparams.prompt_tokens = prompt_tokens.size() == 0 ? nullptr : prompt_tokens.data();
|
||||
wparams.prompt_n_tokens = prompt_tokens.size() == 0 ? 0 : prompt_tokens.size();
|
||||
wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
|
||||
wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
|
||||
|
||||
whisper_print_user_data user_data = { ¶ms, &pcmf32s };
|
||||
|
||||
|
@ -51,7 +51,7 @@ struct whisper_params {
|
||||
|
||||
std::string language = "en";
|
||||
std::string model = "models/ggml-base.en.bin";
|
||||
std::string fname_out = "";
|
||||
std::string fname_out;
|
||||
};
|
||||
|
||||
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
||||
|
@ -40,7 +40,7 @@ std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::stri
|
||||
// find the longest tokens that form the words:
|
||||
std::vector<gpt_vocab::id> tokens;
|
||||
for (const auto & word : words) {
|
||||
if (word.size() == 0) continue;
|
||||
if (word.empty()) continue;
|
||||
|
||||
int i = 0;
|
||||
int n = word.size();
|
||||
@ -86,7 +86,7 @@ gpt_vocab::id gpt_sample_top_k_top_p(
|
||||
logits_id.reserve(n_logits);
|
||||
|
||||
for (int i = 0; i < n_logits; i++) {
|
||||
logits_id.push_back(std::make_pair(logits[i], i));
|
||||
logits_id.emplace_back(logits[i], i);
|
||||
}
|
||||
|
||||
// find the top K tokens
|
||||
@ -327,7 +327,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
|
||||
{
|
||||
struct ggml_init_params params;
|
||||
params.mem_size = ctx_size;
|
||||
params.mem_buffer = NULL;
|
||||
params.mem_buffer = nullptr;
|
||||
|
||||
model.ctx = ggml_init(params);
|
||||
if (!model.ctx) {
|
||||
@ -448,7 +448,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
|
||||
std::string name(length, 0);
|
||||
fin.read(&name[0], length);
|
||||
|
||||
if (model.tensors.find(name.data()) == model.tensors.end()) {
|
||||
if (model.tensors.find(name) == model.tensors.end()) {
|
||||
fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
|
||||
return false;
|
||||
}
|
||||
@ -833,7 +833,7 @@ Me too.
|
||||
struct gpt2_context * gpt2_init(const char * path_model) {
|
||||
gpt2_context * ctx = new gpt2_context;
|
||||
|
||||
ctx->rng = std::mt19937(time(NULL));
|
||||
ctx->rng = std::mt19937(time(nullptr));
|
||||
|
||||
// load the model
|
||||
{
|
||||
@ -886,7 +886,7 @@ std::string gpt2_gen_text(gpt2_context * ctx, const char * text, int max_tokens)
|
||||
|
||||
for (int i = embd.size(); i < (int) embd_inp.size() + n_predict; i++) {
|
||||
// predict
|
||||
if (embd.size() > 0) {
|
||||
if (!embd.empty()) {
|
||||
if (!gpt2_eval(ctx->model, ctx->n_threads, n_past, embd, embd_w, mem_per_token)) {
|
||||
printf("gpt-2: failed to generate text\n");
|
||||
return "";
|
||||
|
@ -39,7 +39,7 @@ struct whisper_params {
|
||||
std::string model_wsp = "models/ggml-base.en.bin";
|
||||
std::string model_gpt = "models/ggml-gpt-2-117M.bin";
|
||||
std::string speak = "./examples/talk/speak.sh";
|
||||
std::string fname_out = "";
|
||||
std::string fname_out;
|
||||
};
|
||||
|
||||
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
||||
@ -588,7 +588,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
audio.get(params.voice_ms, pcmf32_cur);
|
||||
|
||||
std::string text_heard = "";
|
||||
std::string text_heard;
|
||||
|
||||
if (!force_speak) {
|
||||
text_heard = ::trim(::transcribe(ctx_wsp, params, pcmf32_cur, prob0, t_ms));
|
||||
@ -610,7 +610,7 @@ int main(int argc, char ** argv) {
|
||||
text_heard = std::regex_replace(text_heard, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
|
||||
|
||||
// take first line
|
||||
text_heard = text_heard.substr(0, text_heard.find_first_of("\n"));
|
||||
text_heard = text_heard.substr(0, text_heard.find_first_of('\n'));
|
||||
|
||||
// remove leading and trailing whitespace
|
||||
text_heard = std::regex_replace(text_heard, std::regex("^\\s+"), "");
|
||||
@ -640,18 +640,18 @@ int main(int argc, char ** argv) {
|
||||
|
||||
text_to_speak = gpt2_gen_text(ctx_gpt, prompt.c_str(), params.max_tokens);
|
||||
text_to_speak = std::regex_replace(text_to_speak, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
|
||||
text_to_speak = text_to_speak.substr(0, text_to_speak.find_first_of("\n"));
|
||||
text_to_speak = text_to_speak.substr(0, text_to_speak.find_first_of('\n'));
|
||||
|
||||
// remove first 2 lines of base prompt
|
||||
if (n_iter > 4) {
|
||||
{
|
||||
const size_t pos = prompt_base.find_first_of("\n");
|
||||
const size_t pos = prompt_base.find_first_of('\n');
|
||||
if (pos != std::string::npos) {
|
||||
prompt_base = prompt_base.substr(pos + 1);
|
||||
}
|
||||
}
|
||||
{
|
||||
const size_t pos = prompt_base.find_first_of("\n");
|
||||
const size_t pos = prompt_base.find_first_of('\n');
|
||||
if (pos != std::string::npos) {
|
||||
prompt_base = prompt_base.substr(pos + 1);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user