mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-04-22 10:18:49 +02:00
models : minor changes to the HF convert script (#157)
This commit is contained in:
parent
93482d0373
commit
d64d6ca3fd
@ -28,6 +28,7 @@ conv_map = {'self_attn_layer_norm': 'attn_ln',
|
|||||||
'decoder.layer_norm.weight': 'decoder.ln.weight',
|
'decoder.layer_norm.weight': 'decoder.ln.weight',
|
||||||
'decoder.embed_positions.weight': 'decoder.positional_embedding',
|
'decoder.embed_positions.weight': 'decoder.positional_embedding',
|
||||||
'decoder.embed_tokens.weight': 'decoder.token_embedding.weight',
|
'decoder.embed_tokens.weight': 'decoder.token_embedding.weight',
|
||||||
|
'proj_out.weight': 'decoder.proj.weight',
|
||||||
}
|
}
|
||||||
|
|
||||||
# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
|
# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
|
||||||
@ -82,8 +83,11 @@ fname_out = dir_out + "/ggml-model.bin"
|
|||||||
with open(dir_tokenizer + "/vocab.json", "r", encoding="utf8") as f:
|
with open(dir_tokenizer + "/vocab.json", "r", encoding="utf8") as f:
|
||||||
tokens = json.load(f)
|
tokens = json.load(f)
|
||||||
|
|
||||||
|
# use 16-bit or 32-bit floats
|
||||||
use_f16 = True
|
use_f16 = True
|
||||||
|
if len(sys.argv) > 4:
|
||||||
|
use_f16 = False
|
||||||
|
fname_out = dir_out + "/ggml-model-f32.bin"
|
||||||
|
|
||||||
fout = open(fname_out, "wb")
|
fout = open(fname_out, "wb")
|
||||||
|
|
||||||
@ -119,6 +123,8 @@ for key in tokens:
|
|||||||
|
|
||||||
list_vars = model.state_dict()
|
list_vars = model.state_dict()
|
||||||
for name in list_vars.keys():
|
for name in list_vars.keys():
|
||||||
|
# this seems to not be used
|
||||||
|
# ref: https://github.com/huggingface/transformers/blob/9a5b84a0076a04fe9596da72e8668069d4f09ea0/src/transformers/models/whisper/modeling_whisper.py#L1099-L1106
|
||||||
if name == "proj_out.weight":
|
if name == "proj_out.weight":
|
||||||
print('Skipping', name)
|
print('Skipping', name)
|
||||||
continue
|
continue
|
||||||
@ -126,7 +132,11 @@ for name in list_vars.keys():
|
|||||||
src = name
|
src = name
|
||||||
|
|
||||||
nn = name
|
nn = name
|
||||||
nn = nn.split(".")[1:]
|
if name != "proj_out.weight":
|
||||||
|
nn = nn.split(".")[1:]
|
||||||
|
else:
|
||||||
|
nn = nn.split(".")
|
||||||
|
|
||||||
if nn[1] == "layers":
|
if nn[1] == "layers":
|
||||||
nn[1] = "blocks"
|
nn[1] = "blocks"
|
||||||
if ".".join(nn[3:-1]) == "self_attn.k_proj":
|
if ".".join(nn[3:-1]) == "self_attn.k_proj":
|
||||||
|
Loading…
Reference in New Issue
Block a user