mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-04-22 18:29:00 +02:00
* coreml: fix Whisper to CoreML conversion by disabling SDPA This commit disables the use of PyTorch's `scaled_dot_product_attention` in the Whisper model to avoid compatibility issues during CoreML conversion. The issue occurs because coremltools requires PyTorch 2.5.0, but the Whisper implementation may expect behavior from newer PyTorch versions. By setting `MultiHeadAttention.use_sdpa = False`, we force Whisper to use its fallback manual attention implementation, which works correctly with PyTorch 2.5.0 during the tracing process. Refs: https://github.com/ggerganov/whisper.cpp/issues/2783 * coreml: fix audio shape in whisper decoder conversion This commit fixes the audio shape in the whisper decoder conversion script. The motivation for this is that the audio shape was incorrect and was causing the conversion to fail. * coreml : set -e in generate-coreml-interface.sh The commit sets the -e flag in the generate-coreml-interface.sh script to make sure the script fails if any command fails. * coreml : update generated encoder/decoder interfaces This commit updates the generated encoder/decoder interfaces for the whisper model which is the result of running the generate-coreml-interface.sh script.
329 lines
13 KiB
Python
329 lines
13 KiB
Python
import argparse
|
|
import torch
|
|
import torch.nn.functional as F
|
|
import coremltools as ct
|
|
|
|
from torch import Tensor
|
|
from torch import nn
|
|
from typing import Dict
|
|
from typing import Optional
|
|
from ane_transformers.reference.layer_norm import LayerNormANE as LayerNormANEBase
|
|
from coremltools.models.neural_network.quantization_utils import quantize_weights
|
|
from whisper.model import Whisper, AudioEncoder, TextDecoder, ResidualAttentionBlock, MultiHeadAttention, ModelDimensions
|
|
from whisper import load_model
|
|
|
|
# Disable PyTorch Scaled Dot-Product Attention (SDPA) to avoid compatibility issues.
# The Whisper implementation expects a specific behavior from
# torch.nn.functional.scaled_dot_product_attention that differs between PyTorch
# versions. Setting use_sdpa=False forces Whisper to use its manual attention
# implementation instead, which is more stable across different PyTorch versions
# (2.5.0 required by coremltools vs newer versions).
#
# NOTE: the flag is assigned on the MultiHeadAttention class itself, so it is
# seen by every instance (and subclass instance) that does not override it.
import whisper.model
whisper.model.MultiHeadAttention.use_sdpa = False
|
|
|
|
# Use for changing dim of input in encoder and decoder embeddings
|
|
def linear_to_conv2d_map(state_dict, prefix, local_metadata, strict,
                         missing_keys, unexpected_keys, error_msgs):
    """
    Append two singleton axes to 2-D attention/MLP weights so that
    nn.Linear weights fit nn.Conv2d layers.

    The argument list matches the call made for a load_state_dict pre-hook;
    only ``state_dict`` is read. Matching entries are replaced in place and
    nothing is returned.
    """
    mlp_suffixes = ('mlp.0.weight', 'mlp.2.weight')

    # Snapshot the keys; values are reassigned while walking them.
    for name in list(state_dict):
        in_attention = 'attn' in name and '.weight' in name
        in_mlp = name.endswith(mlp_suffixes)
        if not (in_attention or in_mlp):
            continue

        weight = state_dict[name]
        # Names such as 'attn_ln.weight' also match but are not 2-D, so they are left alone.
        if len(weight.shape) == 2:
            state_dict[name] = weight[:, :, None, None]
|
|
|
|
|
|
def correct_for_bias_scale_order_inversion(state_dict, prefix, local_metadata,
                                           strict, missing_keys,
                                           unexpected_keys, error_msgs):
    """
    Replace the ``bias`` entry under ``prefix`` with ``bias / weight``.

    The argument list matches the call made for a load_state_dict pre-hook;
    only ``state_dict`` and ``prefix`` are used. The same (mutated) mapping
    is returned.
    """
    # NOTE(review): presumably compensates for the ANE layer norm applying bias
    # and scale in the opposite order - confirm against LayerNormANEBase.
    bias_key = prefix + 'bias'
    weight_key = prefix + 'weight'
    rescaled_bias = state_dict[bias_key] / state_dict[weight_key]
    state_dict[bias_key] = rescaled_bias
    return state_dict
|
|
|
|
class LayerNormANE(LayerNormANEBase):
    """
    LayerNormANEBase that rewrites its bias entry whenever a state dict is loaded.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # The hook divides the incoming bias by the incoming weight.
        pre_hook = correct_for_bias_scale_order_inversion
        self._register_load_state_dict_pre_hook(pre_hook)
|
|
|
|
class MultiHeadAttentionANE(MultiHeadAttention):
    """
    MultiHeadAttention whose query/key/value/out projections are 1x1 nn.Conv2d layers.
    """

    def __init__(self, n_state: int, n_head: int):
        super().__init__(n_state, n_head)

        def pointwise(bias: bool = True):
            return nn.Conv2d(n_state, n_state, kernel_size=1, bias=bias)

        # Same creation order as the attribute names suggest; only the key projection has no bias.
        self.query = pointwise()
        self.key = pointwise(bias=False)
        self.value = pointwise()
        self.out = pointwise()

    def forward(self,
                x: Tensor,
                xa: Optional[Tensor] = None,
                mask: Optional[Tensor] = None,
                kv_cache: Optional[dict] = None):
        """
        Project ``x`` to queries, obtain keys/values, and run attention.

        Returns a tuple ``(self.out(attention_output), qk)`` where ``qk`` comes
        from ``qkv_attention_ane``.
        """
        q = self.query(x)

        reuse_cached = (
            kv_cache is not None
            and xa is not None
            and self.key in kv_cache
        )
        if reuse_cached:
            # for cross-attention, keys and values are calculated once and reused in subsequent calls.
            k = kv_cache[self.key]
            v = kv_cache[self.value]
        else:
            # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
            # otherwise, perform key/value projections for self- or cross-attention as usual.
            source = x if xa is None else xa
            k = self.key(source)
            v = self.value(source)

        wv, qk = self.qkv_attention_ane(q, k, v, mask)
        return self.out(wv), qk

    def qkv_attention_ane(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
        """
        Per-head attention over 4-D tensors.

        ``q`` is unpacked as four axes with the channel count on axis 1 and the
        sequence length on axis 3. Returns ``(attn, qk)``: the head outputs
        concatenated on axis 1, and the (masked) pre-softmax scores concatenated
        on axis 1, cast to float and detached.
        """
        _, dim, _, seqlen = q.size()
        dim_per_head = dim // self.n_head

        scale = float(dim_per_head)**-0.5
        q = q * scale

        q_heads = q.split(dim_per_head, dim=1)
        k_heads = k.transpose(1, 3).split(dim_per_head, dim=3)
        v_heads = v.split(dim_per_head, dim=1)

        # (batch_size, max_seq_length, 1, max_seq_length) * n_heads
        scores = [
            torch.einsum('bchq,bkhc->bkhq', [q_h, k_h])
            for q_h, k_h in zip(q_heads, k_heads)
        ]

        if mask is not None:
            for head_idx in range(self.n_head):
                # Out-of-place add: each head gets a new tensor.
                scores[head_idx] = scores[head_idx] + mask[:, :seqlen, :, :seqlen]

        # (batch_size, max_seq_length, 1, max_seq_length) * n_heads
        weights = [head_scores.softmax(dim=1) for head_scores in scores]

        # (batch_size, dim_per_head, 1, max_seq_length) * n_heads
        head_outputs = [
            torch.einsum('bkhq,bchk->bchq', w_h, v_h)
            for w_h, v_h in zip(weights, v_heads)
        ]

        # (batch_size, dim, 1, max_seq_length)
        attn = torch.cat(head_outputs, dim=1)
        qk = torch.cat(scores, dim=1).float().detach()
        return attn, qk
|
|
|
|
|
|
class ResidualAttentionBlockANE(ResidualAttentionBlock):
    """
    ResidualAttentionBlock with its attention, layer-norm and MLP submodules
    replaced by the ANE variants defined in this file (Conv2d-based MLP).
    """

    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        super().__init__(n_state, n_head, cross_attention)

        self.attn = MultiHeadAttentionANE(n_state, n_head)
        self.attn_ln = LayerNormANE(n_state)

        # Cross-attention modules exist only when requested; otherwise both are None.
        if cross_attention:
            self.cross_attn = MultiHeadAttentionANE(n_state, n_head)
            self.cross_attn_ln = LayerNormANE(n_state)
        else:
            self.cross_attn = None
            self.cross_attn_ln = None

        hidden_width = n_state * 4
        expand = nn.Conv2d(n_state, hidden_width, kernel_size=1)
        activation = nn.GELU()
        contract = nn.Conv2d(hidden_width, n_state, kernel_size=1)
        self.mlp = nn.Sequential(expand, activation, contract)
        self.mlp_ln = LayerNormANE(n_state)
|
|
|
|
|
|
class AudioEncoderANE(AudioEncoder):
    """
    AudioEncoder whose residual blocks and final layer norm are the ANE variants.
    """

    def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
        super().__init__(n_mels, n_ctx, n_state, n_head, n_layer)

        ane_blocks = [ResidualAttentionBlockANE(n_state, n_head) for _ in range(n_layer)]
        self.blocks = nn.ModuleList(ane_blocks)
        self.ln_post = LayerNormANE(n_state)

    def forward(self, x: Tensor):
        """
        x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
            the mel spectrogram of the audio
        """
        for conv in (self.conv1, self.conv2):
            x = F.gelu(conv(x))

        assert x.shape[1:] == self.positional_embedding.shape[::-1], "incorrect audio shape"

        # Add positional embedding and add dummy dim for ANE
        positions = self.positional_embedding.transpose(0, 1)
        x = (x + positions).to(x.dtype)
        x = x.unsqueeze(2)

        for block in self.blocks:
            x = block(x)

        # Drop the dummy dim again and swap the last two axes before returning.
        normed = self.ln_post(x)
        return normed.squeeze(2).transpose(1, 2)
|
|
|
|
class TextDecoderANE(TextDecoder):
    """
    TextDecoder whose residual blocks and final layer norm are the ANE variants
    and whose output projection is computed in vocabulary chunks.
    """

    def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
        super().__init__(n_vocab, n_ctx, n_state, n_head, n_layer)

        self.blocks = nn.ModuleList(
            [ResidualAttentionBlockANE(n_state, n_head, cross_attention=True) for _ in range(n_layer)]
        )
        self.ln = LayerNormANE(n_state)

    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
        """
        x : torch.LongTensor, shape = (batch_size, <= n_ctx)
            the text tokens
        xa : torch.Tensor
            the encoded audio features to be attended on
            (NOTE(review): expected layout is not verifiable from this file - confirm
            against the encoder output / conversion script)
        kv_cache : optional dict of cached key/value tensors; when non-empty, the
            size of axis 3 of its first value is used as the positional offset

        Returns the logits over the vocabulary, shape (batch_size, n_tokens, n_vocab).
        """
        offset = next(iter(kv_cache.values())).shape[3] if kv_cache else 0
        x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
        x = x.to(xa.dtype)

        # Reformat for ANE: (batch, seq, state) -> (batch, state, 1, seq)
        mask = self.mask[None, None, :, :].permute(0, 3, 1, 2)
        x = x.transpose(1, 2).unsqueeze(2)

        for block in self.blocks:
            x = block(x, xa, mask=mask, kv_cache=kv_cache)

        x = self.ln(x)

        # Reformat back from ANE: (batch, state, 1, seq) -> (batch, seq, state).
        # The dummy axis sits at index 1 after the permute; squeezing it (rather than
        # axis 0) gives the same result for batch_size == 1 and also works for larger batches.
        x = x.permute(0, 2, 3, 1).squeeze(1)

        # ANE can only load tensors with dim size of at most 16,384 - whisper uses 51,864 (en)
        # or 51,865 (multi-lang) tokens so we need to compute in chunks
        embedding = self.token_embedding.weight
        n_vocab = embedding.shape[0]
        if n_vocab >= 51865:
            # 11 chunks of 4715 rows for 51,865; a vocabulary that is not a multiple
            # of 11 gets one extra, smaller trailing chunk.
            chunk_rows = n_vocab // 11
        else:
            # split in 12 chunks - 4322 each
            assert(n_vocab == 51864)
            chunk_rows = n_vocab // 12

        # Each chunk yields (batch, seq, chunk_rows) logits. They must be joined along the
        # vocabulary axis: joining along the batch axis and reshaping interleaves the
        # logits of different positions when seq > 1 and fails for uneven chunks.
        logits = torch.cat(
            [torch.einsum('bid,jd->bij', x, chunk) for chunk in embedding.split(chunk_rows, dim=0)],
            dim=-1,
        )

        return logits
|
|
|
|
class WhisperANE(Whisper):
    """
    Whisper model whose encoder and decoder are replaced by the ANE variants.

    A load_state_dict pre-hook reshapes incoming Linear weights to Conv2d layout.
    """

    def __init__(self, dims: ModelDimensions):
        super().__init__(dims)

        d = self.dims
        self.encoder = AudioEncoderANE(
            d.n_mels, d.n_audio_ctx, d.n_audio_state, d.n_audio_head, d.n_audio_layer,
        )
        self.decoder = TextDecoderANE(
            d.n_vocab, d.n_text_ctx, d.n_text_state, d.n_text_head, d.n_text_layer,
        )

        self._register_load_state_dict_pre_hook(linear_to_conv2d_map)

    def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Encode ``mel`` and return the decoder output for ``tokens``."""
        audio_features = self.encoder(mel)
        return self.decoder(tokens, audio_features)

    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
        """
        Register forward hooks on the key/value projections of every
        MultiHeadAttentionANE in the decoder so their outputs are cached.

        Returns ``(cache, hooks)``: a dict keyed by projection module (a shallow
        copy of ``cache`` when one is given) and the list of hook handles.
        """
        cache = dict(cache) if cache is not None else {}
        hooks = []

        def save_to_cache(module, _, output):
            store_as_is = (
                module not in cache
                or output.shape[3] > self.decoder.positional_embedding.shape[0]
            )
            if store_as_is:
                # save as-is, for the first token or cross attention
                cache[module] = output
            else:
                # Otherwise append along axis 3.
                cache[module] = torch.cat([cache[module], output], dim=3).detach()
            return cache[module]

        def install_hooks(layer: nn.Module):
            if not isinstance(layer, MultiHeadAttentionANE):
                return
            for projection in (layer.key, layer.value):
                hooks.append(projection.register_forward_hook(save_to_cache))

        self.decoder.apply(install_hooks)
        return cache, hooks
|
|
|
|
def convert_encoder(hparams, model, quantize=False):
    """
    Trace the encoder with a random (1, n_mels, 3000) input and convert it to a
    Core ML "neuralnetwork" model; optionally quantize its weights to 16 bits.
    """
    model.eval()

    mel_shape = (1, hparams.n_mels, 3000)
    example_mel = torch.randn(mel_shape)
    traced_encoder = torch.jit.trace(model, example_mel)

    mel_input = ct.TensorType(name="logmel_data", shape=mel_shape)
    encoded_output = ct.TensorType(name="output")
    coreml_model = ct.convert(
        traced_encoder,
        convert_to="neuralnetwork",
        inputs=[mel_input],
        outputs=[encoded_output],
        compute_units=ct.ComputeUnit.ALL,
    )

    if not quantize:
        return coreml_model
    return quantize_weights(coreml_model, nbits=16)
|
|
|
|
def convert_decoder(hparams, model, quantize=False):
    """
    Trace the decoder with one random token and random audio features of shape
    (1, n_audio_ctx, n_audio_state), then convert it to a Core ML "neuralnetwork"
    model; optionally quantize its weights to 16 bits.
    """
    model.eval()

    tokens_shape = (1, 1)
    audio_shape = (1, hparams.n_audio_ctx, hparams.n_audio_state)

    # Audio is sampled before the token so the order of RNG draws is fixed.
    example_audio = torch.randn(audio_shape)
    example_tokens = torch.randint(hparams.n_vocab, tokens_shape).long()

    traced_decoder = torch.jit.trace(model, (example_tokens, example_audio))

    coreml_inputs = [
        ct.TensorType(name="token_data", shape=tokens_shape, dtype=int),
        ct.TensorType(name="audio_data", shape=audio_shape),
    ]
    coreml_model = ct.convert(
        traced_decoder,
        convert_to="neuralnetwork",
        inputs=coreml_inputs,
    )

    if not quantize:
        return coreml_model
    return quantize_weights(coreml_model, nbits=16)
|
|
|
|
|
|
def str2bool(value):
    """
    Parse a command-line boolean.

    argparse's ``type=bool`` calls ``bool()`` on the raw string, so any non-empty
    value - including "False" and "0" - becomes True. This parser maps the usual
    spellings explicitly and keeps the empty string as False.

    Raises argparse.ArgumentTypeError for anything it does not recognise.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, help="model to convert (e.g. tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large-v1, large-v2, large-v3, large-v3-turbo)", required=True)
    # nargs="?" with const=True keeps "--flag True" working and also allows a bare "--flag".
    parser.add_argument("--encoder-only", type=str2bool, nargs="?", const=True, help="only convert encoder", default=False)
    parser.add_argument("--quantize", type=str2bool, nargs="?", const=True, help="quantize weights to F16", default=False)
    parser.add_argument("--optimize-ane", type=str2bool, nargs="?", const=True, help="optimize for ANE execution (currently broken)", default=False)
    args = parser.parse_args()

    if args.model not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "small.en-tdrz", "medium", "medium.en", "large-v1", "large-v2", "large-v3", "large-v3-turbo"]:
        raise ValueError("Invalid model name")

    # Named whisper_model so the imported `whisper` package is not shadowed.
    whisper_model = load_model(args.model).cpu()
    hparams = whisper_model.dims
    print(hparams)

    if args.optimize_ane:
        # Rebuild the model with the ANE modules and load the original weights into it.
        whisperANE = WhisperANE(hparams).eval()
        whisperANE.load_state_dict(whisper_model.state_dict())

        encoder = whisperANE.encoder
        decoder = whisperANE.decoder
    else:
        encoder = whisper_model.encoder
        decoder = whisper_model.decoder

    # Convert encoder
    encoder = convert_encoder(hparams, encoder, quantize=args.quantize)
    encoder.save(f"models/coreml-encoder-{args.model}.mlpackage")

    if not args.encoder_only:
        # Convert decoder
        decoder = convert_decoder(hparams, decoder, quantize=args.quantize)
        decoder.save(f"models/coreml-decoder-{args.model}.mlpackage")

    print("done converting")
|