Expose more ctx->vocab interfaces.

I need these functions to implement a kind of weighting coefficient
logits_filter_callback like:

```

void filter_callback(
            struct whisper_context * ctx,
              struct whisper_state * state,
          const whisper_token_data * tokens,
                               int   n_tokens,
                             float * logits,
                              void * user_data
) {
    const static std::vector<std::string> good_words = {
        "音声", "認識"
    };
    std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> conv;
    auto prev = n_tokens > 0 ? std::string(whisper_token_to_str(ctx, tokens[n_tokens - 1].id)) : "";

    for (const std::string & token : good_words) {
        auto s32 = conv.from_bytes(token);
        auto s0 = conv.to_bytes(s32[0]);
        auto s1 = conv.to_bytes(s32[1]);

        if (whisper_token_exists(ctx, token.c_str())) {
            logits[whisper_str_to_token(ctx, token.c_str())] *= 2;
        } else if (
            prev.size() >= s0.size()
            && prev.compare(prev.size() - s0.size(), s0.size(), s0) == 0
            && whisper_token_exists(ctx, s1.c_str())
        ) {
            logits[whisper_str_to_token(ctx, s1.c_str())] *= 1.6;
        } else if (whisper_token_exists(ctx, s0.c_str())) {
            logits[whisper_str_to_token(ctx, s0.c_str())] *= 1.2;
        }
    }
}
```
This commit is contained in:
Tamotsu Takahashi 2025-01-01 23:44:29 +09:00
parent e4e05981d6
commit d0f38def08
2 changed files with 11 additions and 0 deletions

View File

@ -408,6 +408,9 @@ extern "C" {
WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token);
WHISPER_API const char * whisper_model_type_readable(struct whisper_context * ctx);
// String -> Token Id. Uses the vocabulary in the provided context
WHISPER_API bool whisper_token_exists(struct whisper_context * ctx, const char * str);
WHISPER_API whisper_token whisper_str_to_token(struct whisper_context * ctx, const char * str);
// Special tokens
WHISPER_API whisper_token whisper_token_eot (struct whisper_context * ctx);

View File

@ -4068,6 +4068,14 @@ const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token to
return ctx->vocab.id_to_token.at(token).c_str();
}
whisper_token whisper_str_to_token(struct whisper_context * ctx, const char * str) {
return ctx->vocab.token_to_id.at(str);
}
bool whisper_token_exists(struct whisper_context * ctx, const char * str) {
return ctx->vocab.token_to_id.find(str) != ctx->vocab.token_to_id.end();
}
whisper_token whisper_token_eot(struct whisper_context * ctx) {
return ctx->vocab.token_eot;
}