2023-04-15 12:21:27 +02:00
|
|
|
import argparse
|
|
|
|
import torch
|
|
|
|
import torch.nn.functional as F
|
|
|
|
import coremltools as ct
|
|
|
|
|
|
|
|
from torch import Tensor
|
|
|
|
from torch import nn
|
|
|
|
from typing import Dict
|
|
|
|
from typing import Optional
|
|
|
|
from ane_transformers.reference.layer_norm import LayerNormANE as LayerNormANEBase
|
|
|
|
from coremltools.models.neural_network.quantization_utils import quantize_weights
|
|
|
|
from whisper.model import Whisper, AudioEncoder, TextDecoder, ResidualAttentionBlock, MultiHeadAttention, ModelDimensions
|
|
|
|
from whisper import load_model
|
|
|
|
|
|
|
|
# Use for changing dim of input in encoder and decoder embeddings
|
|
|
|
def linear_to_conv2d_map(state_dict, prefix, local_metadata, strict,
                         missing_keys, unexpected_keys, error_msgs):
    """
    Load-state-dict pre-hook: reshape 2-D projection weights to 4-D in place.

    Entries whose key contains both 'attn' and '.weight', or whose key ends
    with 'mlp.0.weight' / 'mlp.2.weight', and whose value is 2-D, are indexed
    with ``[:, :, None, None]`` so an (out, in) nn.Linear weight becomes an
    (out, in, 1, 1) nn.Conv2d weight. All other entries are left untouched.

    Only ``state_dict`` is used; the remaining parameters exist to satisfy the
    pre-hook call signature. Returns None.
    """
    mlp_suffixes = ('mlp.0.weight', 'mlp.2.weight')

    # Only existing keys are reassigned, so the dict never changes size
    # while it is being iterated.
    for name, tensor in state_dict.items():
        wanted = ('attn' in name and '.weight' in name) or name.endswith(mlp_suffixes)
        # 1-D entries that match by name (e.g. '...attn_ln.weight') are skipped here.
        if wanted and len(tensor.shape) == 2:
            state_dict[name] = tensor[:, :, None, None]
|
|
|
|
|
|
|
|
|
|
|
|
def correct_for_bias_scale_order_inversion(state_dict, prefix, local_metadata,
                                           strict, missing_keys,
                                           unexpected_keys, error_msgs):
    """
    Load-state-dict pre-hook: divide the stored bias by the stored weight.

    Replaces ``state_dict[prefix + 'bias']`` with ``bias / weight`` in place
    and returns the same ``state_dict`` object. Only ``state_dict`` and
    ``prefix`` are used; the other parameters satisfy the pre-hook signature.

    NOTE(review): presumably needed because the target layer norm applies the
    bias before the scale -- confirm against LayerNormANEBase.
    """
    bias_key = prefix + 'bias'
    weight_key = prefix + 'weight'
    rescaled_bias = state_dict[bias_key] / state_dict[weight_key]
    state_dict[bias_key] = rescaled_bias
    return state_dict
|
|
|
|
|
|
|
|
class LayerNormANE(LayerNormANEBase):
    """LayerNormANEBase that rescales the incoming bias when a state dict is loaded."""

    def __init__(self, *args, **kwargs):
        # All constructor arguments are forwarded unchanged to the base class.
        super().__init__(*args, **kwargs)
        # The pre-hook rewrites `<prefix>bias` as bias / weight before loading.
        pre_hook = correct_for_bias_scale_order_inversion
        self._register_load_state_dict_pre_hook(pre_hook)
|
|
|
|
|
|
|
|
class MultiHeadAttentionANE(MultiHeadAttention):
    """Multi-head attention whose projections are 1x1 convolutions over 4-D tensors."""

    def __init__(self, n_state: int, n_head: int):
        super().__init__(n_state, n_head)

        # Replace the parent's projections with 1x1 convolutions.
        # The key projection is created without a bias term.
        self.query = nn.Conv2d(n_state, n_state, kernel_size=1)
        self.key = nn.Conv2d(n_state, n_state, kernel_size=1, bias=False)
        self.value = nn.Conv2d(n_state, n_state, kernel_size=1)
        self.out = nn.Conv2d(n_state, n_state, kernel_size=1)

    def forward(self,
                x: Tensor,
                xa: Optional[Tensor] = None,
                mask: Optional[Tensor] = None,
                kv_cache: Optional[dict] = None):
        """
        Attend from `x` to itself (xa is None) or to `xa` (cross-attention).

        Returns a tuple ``(self.out(attention_output), qk)`` where ``qk`` is
        the detached float tensor produced by `qkv_attention_ane`.
        """
        q = self.query(x)

        reuse_cached = (
            kv_cache is not None
            and xa is not None
            and self.key in kv_cache
        )
        if reuse_cached:
            # for cross-attention, calculate keys and values once and reuse in subsequent calls.
            k = kv_cache[self.key]
            v = kv_cache[self.value]
        else:
            # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
            # otherwise, perform key/value projections for self- or cross-attention as usual.
            source = x if xa is None else xa
            k = self.key(source)
            v = self.value(source)

        wv, qk = self.qkv_attention_ane(q, k, v, mask)
        return self.out(wv), qk

    def qkv_attention_ane(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
        """
        Per-head scaled dot-product attention on 4-D tensors.

        `q` is unpacked as (_, dim, _, seqlen); channels (axis 1) are split
        into `self.n_head` groups of ``dim // n_head``. Returns the
        concatenated per-head outputs and the concatenated pre-softmax scores
        (after the mask is added), the latter as a detached float tensor.
        """
        _, dim, _, seqlen = q.size()
        head_dim = dim // self.n_head

        # Scale the queries once rather than scaling every score matrix.
        q = q * float(head_dim)**-0.5

        q_heads = q.split(head_dim, dim=1)
        k_heads = k.transpose(1, 3).split(head_dim, dim=3)
        v_heads = v.split(head_dim, dim=1)

        # (batch_size, max_seq_length, 1, max_seq_length) * n_heads
        mh_qk = [
            torch.einsum('bchq,bkhc->bkhq', q_head, k_head)
            for q_head, k_head in zip(q_heads, k_heads)
        ]

        if mask is not None:
            # The same mask window is added to the scores of every head.
            mask_window = mask[:, :seqlen, :, :seqlen]
            for head_idx in range(self.n_head):
                mh_qk[head_idx] = mh_qk[head_idx] + mask_window

        # Softmax over axis 1, the key axis of the 'bkhq' score layout.
        # (batch_size, max_seq_length, 1, max_seq_length) * n_heads
        attn_weights = [scores.softmax(dim=1) for scores in mh_qk]

        # (batch_size, dim_per_head, 1, max_seq_length) * n_heads
        per_head = [
            torch.einsum('bkhq,bchk->bchq', weights, v_head)
            for weights, v_head in zip(attn_weights, v_heads)
        ]

        # (batch_size, dim, 1, max_seq_length)
        attn = torch.cat(per_head, dim=1)
        qk = torch.cat(mh_qk, dim=1).float().detach()
        return attn, qk
|
|
|
|
|
|
|
|
|
|
|
|
class ResidualAttentionBlockANE(ResidualAttentionBlock):
    """ResidualAttentionBlock whose sub-modules are replaced by ANE variants."""

    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        super().__init__(n_state, n_head, cross_attention)

        # Self-attention and its layer norm.
        self.attn = MultiHeadAttentionANE(n_state, n_head)
        self.attn_ln = LayerNormANE(n_state)

        # Cross-attention exists only when requested (decoder blocks).
        if cross_attention:
            self.cross_attn = MultiHeadAttentionANE(n_state, n_head)
            self.cross_attn_ln = LayerNormANE(n_state)
        else:
            self.cross_attn = None
            self.cross_attn_ln = None

        # Feed-forward part: 1x1 conv -> GELU -> 1x1 conv, with a 4x wider hidden layer.
        n_mlp = n_state * 4
        mlp_layers = [
            nn.Conv2d(n_state, n_mlp, kernel_size=1),
            nn.GELU(),
            nn.Conv2d(n_mlp, n_state, kernel_size=1),
        ]
        self.mlp = nn.Sequential(*mlp_layers)
        self.mlp_ln = LayerNormANE(n_state)
|
2023-04-15 12:21:27 +02:00
|
|
|
|
|
|
|
|
|
|
|
class AudioEncoderANE(AudioEncoder):
    """AudioEncoder whose attention blocks and final layer norm are ANE variants."""

    def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
        super().__init__(n_mels, n_ctx, n_state, n_head, n_layer)

        ane_blocks = [ResidualAttentionBlockANE(n_state, n_head) for _ in range(n_layer)]
        self.blocks = nn.ModuleList(ane_blocks)
        self.ln_post = LayerNormANE(n_state)

    def forward(self, x: Tensor):
        """
        x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
            the mel spectrogram of the audio

        Returns the output of `ln_post`; per the note below this is expected
        to be (batch_size, n_state, 1, n_ctx).
        """
        for conv in (self.conv1, self.conv2):
            x = F.gelu(conv(x))

        assert x.shape[1:] == self.positional_embedding.shape[::-1], "incorrect audio shape"

        # Add positional embedding and add dummy dim for ANE
        pos_emb = self.positional_embedding.transpose(0, 1)
        x = (x + pos_emb).to(x.dtype).unsqueeze(2)

        for block in self.blocks:
            x = block(x)

        x = self.ln_post(x)

        # TODO:
        # I think we need to transpose the result here to make it fit whisper.cpp memory order.
        # However, even doing this, the results are still wrong. Kind of less wrong compared to
        # not transposing, but still wrong.
        #
        # Also, I don't know why the original OpenAI implementation does not need to transpose
        #
        # transpose to (batch_size, n_ctx, n_state)
        # x : torch.Tensor, shape = (batch_size, n_state, 1, n_ctx)
        #
        # x = x.transpose(1,3)

        return x
|
|
|
|
|
|
|
|
class TextDecoderANE(TextDecoder):
    """TextDecoder whose attention blocks and final layer norm are ANE variants."""

    def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
        super().__init__(n_vocab, n_ctx, n_state, n_head, n_layer)

        self.blocks = nn.ModuleList(
            [ResidualAttentionBlockANE(n_state, n_head, cross_attention=True) for _ in range(n_layer)]
        )
        self.ln = LayerNormANE(n_state)

    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
        """
        x : torch.LongTensor, shape = (batch_size, <= n_ctx)
            the text tokens
        xa : torch.Tensor
            the encoded audio features to be attended on
            (assumed to be in the 4-D ANE layout (batch_size, n_audio_state, 1,
            n_audio_ctx), since it is fed to the Conv2d key/value projections
            -- TODO confirm against the encoder output)

        Returns the logits, shape = (batch_size, n_tokens, n_vocab).
        """
        # The cached tensors keep the sequence on axis 3 in the ANE layout.
        offset = next(iter(kv_cache.values())).shape[3] if kv_cache else 0
        x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
        x = x.to(xa.dtype)

        # Reformat for ANE: mask -> (1, n_ctx, 1, n_ctx), x -> (batch, n_state, 1, n_tokens)
        mask = self.mask[None, None, :, :].permute(0, 3, 1, 2)
        x = x.transpose(1, 2).unsqueeze(2)

        for block in self.blocks:
            x = block(x, xa, mask=mask, kv_cache=kv_cache)

        x = self.ln(x)

        # Reformat back from ANE: (batch, n_state, 1, n_tokens) -> (batch, n_tokens, n_state).
        # Squeeze the dummy axis (1), not the batch axis, so batch sizes > 1 also work;
        # for batch size 1 the result is identical to squeezing axis 0.
        x = x.permute(0, 2, 3, 1).squeeze(1)

        # ANE can only load tensors with dim size of at most 16,384 - whisper uses 51,864 (en)
        # or 51,865 (multi-lang) tokens so we need to compute in chunks
        embedding = self.token_embedding.weight
        n_vocab = embedding.shape[0]
        if n_vocab >= 51865:
            # split in 11 chunks - 4715 each
            n_chunks = 11
        else:
            # split in 12 chunks - 4322 each
            assert(n_vocab == 51864)
            n_chunks = 12
        splits = embedding.split(n_vocab // n_chunks, dim=0)

        # Each chunk yields (batch, n_tokens, chunk_size). Joining on the vocabulary axis
        # keeps every position's logits contiguous; concatenating on axis 0 and reshaping
        # is only correct for a single token and fails when chunk sizes differ.
        logits = torch.cat(
            [torch.einsum('bid,jd->bij', x, split) for split in splits],
            dim=-1,
        )

        return logits
|
|
|
|
|
|
|
|
class WhisperANE(Whisper):
    """Whisper model whose encoder and decoder are replaced by their ANE variants."""

    def __init__(self, dims: ModelDimensions):
        super().__init__(dims)

        hp = self.dims
        self.encoder = AudioEncoderANE(
            hp.n_mels,
            hp.n_audio_ctx,
            hp.n_audio_state,
            hp.n_audio_head,
            hp.n_audio_layer,
        )
        self.decoder = TextDecoderANE(
            hp.n_vocab,
            hp.n_text_ctx,
            hp.n_text_state,
            hp.n_text_head,
            hp.n_text_layer,
        )

        # Reshape 2-D attention/MLP weights to 4-D when a state dict is loaded.
        self._register_load_state_dict_pre_hook(linear_to_conv2d_map)

    def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Encode `mel`, then decode `tokens` against the encoded audio."""
        # NOTE(review): the decoder's return value is passed through unchanged;
        # the Dict return annotation may not match it -- confirm.
        audio_features = self.encoder(mel)
        return self.decoder(tokens, audio_features)

    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
        """
        Register forward hooks that cache key/value projection outputs.

        A hook is attached to the `key` and `value` modules of every
        MultiHeadAttentionANE found in the decoder. Returns ``(cache, hooks)``:
        a dict (a shallow copy of `cache` when one is given) keyed by module,
        and the list of hook handles returned by `register_forward_hook`.
        """
        cache = {} if cache is None else dict(cache)
        hooks = []

        def save_to_cache(module, _, output):
            can_append = (
                module in cache
                and output.shape[3] <= self.decoder.positional_embedding.shape[0]
            )
            if can_append:
                # Grow the cached tensor along axis 3.
                cache[module] = torch.cat([cache[module], output], dim=3).detach()
            else:
                cache[module] = output  # save as-is, for the first token or cross attention
            return cache[module]

        def install_hooks(layer: nn.Module):
            if not isinstance(layer, MultiHeadAttentionANE):
                return
            for projection in (layer.key, layer.value):
                hooks.append(projection.register_forward_hook(save_to_cache))

        self.decoder.apply(install_hooks)
        return cache, hooks
|
|
|
|
|
|
|
|
def convert_encoder(hparams, model, quantize=False):
    """
    Trace the audio encoder and convert it to a Core ML model.

    hparams  : model dimensions; `n_mels` sets the number of mel bins of the
               example input.
    model    : encoder module to trace (switched to eval mode here).
    quantize : when True, convert with ``convert_to=None`` and quantize the
               weights with ``quantize_weights(..., nbits=16)``.

    Returns the converted (and optionally quantized) Core ML model.
    """
    model.eval()

    # Take the mel-bin count from the model dimensions instead of hard-coding 80,
    # so checkpoints with a different n_mels trace with the right input shape.
    input_shape = (1, hparams.n_mels, 3000)
    input_data = torch.randn(input_shape)
    traced_model = torch.jit.trace(model, input_data)

    model = ct.convert(
        traced_model,
        convert_to=None if quantize else "mlprogram",  # convert will fail if weights are quantized, not sure why
        inputs=[ct.TensorType(name="logmel_data", shape=input_shape)],
        outputs=[ct.TensorType(name="output")],
        compute_units=ct.ComputeUnit.ALL
    )

    if quantize:
        model = quantize_weights(model, nbits=16)

    return model
|
|
|
|
|
|
|
|
def convert_decoder(hparams, model, quantize=False):
    """
    Trace the text decoder and convert it to a Core ML model.

    hparams  : model dimensions; `n_audio_state`, `n_audio_ctx` and `n_vocab`
               size the example inputs used for tracing.
    model    : decoder module to trace (switched to eval mode here).
    quantize : when True, convert with ``convert_to=None`` and quantize the
               weights with ``quantize_weights(..., nbits=16)``.

    Returns the converted (and optionally quantized) Core ML model.
    """
    model.eval()

    tokens_shape = (1, 1)
    # Example audio features in the 4-D (batch, n_audio_state, 1, n_audio_ctx) layout.
    # The context length comes from the model dimensions instead of a literal 1500.
    audio_shape = (1, hparams.n_audio_state, 1, hparams.n_audio_ctx)

    audio_data = torch.randn(audio_shape)
    # Draw the example token from the model's own vocabulary range instead of a literal 50257.
    token_data = torch.randint(hparams.n_vocab, tokens_shape).long()
    traced_model = torch.jit.trace(model, (token_data, audio_data))

    model = ct.convert(
        traced_model,
        convert_to=None if quantize else "mlprogram",  # convert will fail if weights are quantized, not sure why
        inputs=[
            ct.TensorType(name="token_data", shape=tokens_shape, dtype=int),
            ct.TensorType(name="audio_data", shape=audio_shape)
        ]
    )

    if quantize:
        model = quantize_weights(model, nbits=16)

    return model
|
|
|
|
|
|
|
|
|
|
|
|
def str2bool(value):
    """
    Parse a command-line boolean such as "True", "false", "1" or "0".

    argparse's ``type=bool`` treats every non-empty string (including
    "False") as True; this converter maps the text to the intended value.
    Raises argparse.ArgumentTypeError for unrecognised text.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string stays False, matching bool("").
    if text in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    # Command-line entry point: load a Whisper checkpoint and export Core ML packages.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, help="model to convert (e.g. tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large, large-v1, large-v2)", required=True)
    parser.add_argument("--encoder-only", type=str2bool, help="only convert encoder", default=False)
    parser.add_argument("--quantize", type=str2bool, help="quantize weights to F16", default=False)
    parser.add_argument("--optimize-ane", type=str2bool, help="optimize for ANE execution (currently broken)", default=False)
    args = parser.parse_args()

    if args.model not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large", "large-v1", "large-v2"]:
        raise ValueError("Invalid model name")

    whisper = load_model(args.model).cpu()
    hparams = whisper.dims
    print(hparams)

    if args.optimize_ane:
        # Build the ANE variant and copy the weights; the load hooks reshape them.
        whisperANE = WhisperANE(hparams).eval()
        whisperANE.load_state_dict(whisper.state_dict())

        encoder = whisperANE.encoder
        decoder = whisperANE.decoder
    else:
        encoder = whisper.encoder
        decoder = whisper.decoder

    # Convert encoder
    encoder = convert_encoder(hparams, encoder, quantize=args.quantize)
    encoder.save(f"models/coreml-encoder-{args.model}.mlpackage")

    if not args.encoder_only:
        # Convert decoder
        decoder = convert_decoder(hparams, decoder, quantize=args.quantize)
        decoder.save(f"models/coreml-decoder-{args.model}.mlpackage")

    print("done converting")
|