mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2024-11-07 08:34:37 +01:00
py : make convert-pt-to-ggml.py backwards compatible with older vocab.json tokenizer files (#1001)
* patch checkpoint convert script to keep compatibility with older hf_transformers whisper tokenizer * typo fix
This commit is contained in:
parent
a7f822ef59
commit
3ec7bfffe0
@ -224,16 +224,39 @@ with np.load(dir_whisper / "whisper" / "assets" / "mel_filters.npz") as f:
|
||||
|
||||
#code.interact(local=locals())
|
||||
|
||||
# load tokenizer
|
||||
# for backwards compatibility, also check for older hf_transformers format tokenizer files
|
||||
# old format: dir_whisper/whisper/assets/[multilingual/gpt2]/vocab.json
|
||||
# new format: dir_whisper/whisper/assets/[multilingual/gpt2].tiktoken
|
||||
multilingual = hparams["n_vocab"] == 51865
|
||||
tokenizer = dir_whisper / "whisper" / "assets" / (multilingual and "multilingual.tiktoken" or "gpt2.tiktoken")
|
||||
tokenizer_type = "tiktoken"
|
||||
if not tokenizer.is_file():
|
||||
tokenizer = dir_whisper / "whisper" / "assets" / (multilingual and "multilingual" or "gpt2") / "vocab.json"
|
||||
tokenizer_type = "hf_transformers"
|
||||
if not tokenizer.is_file():
|
||||
print("Error: failed to find either tiktoken or hf_transformers tokenizer file:", tokenizer)
|
||||
sys.exit(1)
|
||||
|
||||
byte_encoder = bytes_to_unicode()
|
||||
byte_decoder = {v:k for k, v in byte_encoder.items()}
|
||||
|
||||
if tokenizer_type == "tiktoken":
|
||||
with open(tokenizer, "rb") as f:
|
||||
contents = f.read()
|
||||
tokens = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in contents.splitlines() if line)}
|
||||
elif tokenizer_type == "hf_transformers":
|
||||
with open(tokenizer, "r", encoding="utf8") as f:
|
||||
_tokens_raw = json.load(f)
|
||||
if '<|endoftext|>' in _tokens_raw:
|
||||
# ensures exact same model as tokenizer_type == tiktoken
|
||||
# details: https://github.com/ggerganov/whisper.cpp/pull/725
|
||||
del _tokens_raw['<|endoftext|>']
|
||||
tokens = {bytes([byte_decoder[c] for c in token]): int(idx) for token, idx in _tokens_raw.items()}
|
||||
|
||||
# output in the same directory as the model
|
||||
fname_out = dir_out / "ggml-model.bin"
|
||||
|
||||
with open(tokenizer, "rb") as f:
|
||||
contents = f.read()
|
||||
tokens = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in contents.splitlines() if line)}
|
||||
|
||||
# use 16-bit or 32-bit floats
|
||||
use_f16 = True
|
||||
if len(sys.argv) > 4:
|
||||
@ -262,9 +285,7 @@ for i in range(filters.shape[0]):
|
||||
for j in range(filters.shape[1]):
|
||||
fout.write(struct.pack("f", filters[i][j]))
|
||||
|
||||
byte_encoder = bytes_to_unicode()
|
||||
byte_decoder = {v:k for k, v in byte_encoder.items()}
|
||||
|
||||
# write tokenizer
|
||||
fout.write(struct.pack("i", len(tokens)))
|
||||
|
||||
for key in tokens:
|
||||
|
Loading…
Reference in New Issue
Block a user