Fix a potential bug when reading model data into a small-size-optimized (SSO) std::string, which could lead to memory corruption. With an SSO string, you can't write data to &str[0] and expect it to work well.
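As an illustration of the pattern the change adopts, here is a minimal sketch; the helper names read_string_before/read_string_after and the standalone functions are made up for this note, they are not part of whisper.cpp:

#include <cstdint>
#include <fstream>
#include <string>
#include <vector>

// Pattern being removed: write through a pointer into the string's own storage.
static std::string read_string_before(std::ifstream & fin) {
    uint32_t len = 0;
    fin.read((char *) &len, sizeof(len));
    std::string word(len, 0);
    fin.read(&word[0], len);             // writes directly into the string
    return word;
}

// Pattern being introduced: read into a plain char buffer, then assign to the string.
static std::string read_string_after(std::ifstream & fin) {
    uint32_t len = 0;
    fin.read((char *) &len, sizeof(len));
    std::vector<char> tmp(len);          // temporary buffer
    fin.read(tmp.data(), tmp.size());    // read to buffer
    return std::string(tmp.begin(), tmp.end());
}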

Also added a small wrapper function to read model data more safely, without having to get the sizeof right. I tested this on the tiny, base and large models; there was no change in behaviour.
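To make the sizeof point concrete, a rough sketch (the int64_t field and the read_count helper are hypothetical; only the read_safe template itself comes from this change):

#include <cstdint>
#include <fstream>

// Wrapper added by this commit: the byte count is always sizeof(T), deduced from dest.
template<typename T>
static void read_safe(std::ifstream& fin, T& dest)
{
    fin.read((char*)& dest, sizeof(T));
}

// Hypothetical field that was widened from int32_t to int64_t at some point.
static int64_t read_count(std::ifstream& fin) {
    int64_t n = 0;
    // Raw read: a stale sizeof(int32_t) here still compiles but fills only half of n.
    // fin.read((char *) &n, sizeof(int32_t));
    read_safe(fin, n); // size is deduced from the destination, so it cannot drift
    return n;
}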
bert hubert 2022-12-10 13:09:31 +01:00 committed by Georgi Gerganov
parent 603f97ba11
commit d1da35de06


@@ -429,6 +429,12 @@ struct whisper_context {
     int32_t exp_n_audio_ctx; // 0 - use default
 };
 
+template<typename T>
+static void read_safe(std::ifstream& fin, T& dest)
+{
+    fin.read((char*)& dest, sizeof(T));
+}
+
 // load the model from a ggml file
 //
 // file format:
@@ -455,7 +461,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
     // verify magic
     {
         uint32_t magic;
-        fin.read((char *) &magic, sizeof(magic));
+        read_safe(fin, magic);
         if (magic != 0x67676d6c) {
             fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
             return false;
@@ -466,17 +472,17 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
     {
         auto & hparams = model.hparams;
 
-        fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
-        fin.read((char *) &hparams.n_audio_ctx, sizeof(hparams.n_audio_ctx));
-        fin.read((char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state));
-        fin.read((char *) &hparams.n_audio_head, sizeof(hparams.n_audio_head));
-        fin.read((char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer));
-        fin.read((char *) &hparams.n_text_ctx, sizeof(hparams.n_text_ctx));
-        fin.read((char *) &hparams.n_text_state, sizeof(hparams.n_text_state));
-        fin.read((char *) &hparams.n_text_head, sizeof(hparams.n_text_head));
-        fin.read((char *) &hparams.n_text_layer, sizeof(hparams.n_text_layer));
-        fin.read((char *) &hparams.n_mels, sizeof(hparams.n_mels));
-        fin.read((char *) &hparams.f16, sizeof(hparams.f16));
+        read_safe(fin, hparams.n_vocab);
+        read_safe(fin, hparams.n_audio_ctx);
+        read_safe(fin, hparams.n_audio_state);
+        read_safe(fin, hparams.n_audio_head);
+        read_safe(fin, hparams.n_audio_layer);
+        read_safe(fin, hparams.n_text_ctx);
+        read_safe(fin, hparams.n_text_state);
+        read_safe(fin, hparams.n_text_head);
+        read_safe(fin, hparams.n_text_layer);
+        read_safe(fin, hparams.n_mels);
+        read_safe(fin, hparams.f16);
 
         assert(hparams.n_text_state == hparams.n_audio_state);
@@ -524,8 +530,8 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
     {
         auto & filters = wctx.model.filters;
 
-        fin.read((char *) &filters.n_mel, sizeof(filters.n_mel));
-        fin.read((char *) &filters.n_fft, sizeof(filters.n_fft));
+        read_safe(fin, filters.n_mel);
+        read_safe(fin, filters.n_fft);
 
         filters.data.resize(filters.n_mel * filters.n_fft);
         fin.read((char *) filters.data.data(), filters.data.size() * sizeof(float));
@@ -534,7 +540,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
     // load vocab
     {
         int32_t n_vocab = 0;
-        fin.read((char *) &n_vocab, sizeof(n_vocab));
+        read_safe(fin, n_vocab);
 
         //if (n_vocab != model.hparams.n_vocab) {
         //    fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
@@ -545,10 +551,11 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
         std::string word;
         for (int i = 0; i < n_vocab; i++) {
             uint32_t len;
-            fin.read((char *) &len, sizeof(len));
+            read_safe(fin, len);
 
-            word.resize(len);
-            fin.read((char *) word.data(), len);
+            std::vector<char> tmp(len); // create a buffer
+            fin.read( &tmp[0], tmp.size() ); // read to buffer
+            word.assign(&tmp[0], tmp.size());
 
             vocab.token_to_id[word] = i;
             vocab.id_to_token[i] = word;
@@ -998,9 +1005,9 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
             int32_t length;
             int32_t ftype;
 
-            fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
-            fin.read(reinterpret_cast<char *>(&length), sizeof(length));
-            fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));
+            read_safe(fin, n_dims);
+            read_safe(fin, length);
+            read_safe(fin, ftype);
 
             if (fin.eof()) {
                 break;
@@ -1009,12 +1016,14 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
             int32_t nelements = 1;
             int32_t ne[3] = { 1, 1, 1 };
             for (int i = 0; i < n_dims; ++i) {
-                fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
+                read_safe(fin, ne[i]);
                 nelements *= ne[i];
            }
 
-            std::string name(length, 0);
-            fin.read(&name[0], length);
+            std::string name;
+            std::vector<char> tmp(length); // create a buffer
+            fin.read( &tmp[0], tmp.size() ); // read to buffer
+            name.assign(&tmp[0], tmp.size());
 
             if (model.tensors.find(name.data()) == model.tensors.end()) {
                 fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());