whisper.cpp/models/convert-whisper-to-coreml.py
Daniel Bevenius 11688b262f
coreml: fix Whisper to CoreML conversion by disabling SDPA [no ci] (#2979)
* coreml: fix Whisper to CoreML conversion by disabling SDPA

This commit disables the use of PyTorch's
`scaled_dot_product_attention` in the Whisper model to avoid
compatibility issues during CoreML conversion.
The issue occurs because coremltools requires PyTorch 2.5.0, but the
Whisper implementation may expect behavior from newer PyTorch versions.

By setting `MultiHeadAttention.use_sdpa = False`, we force Whisper to
use its fallback manual attention implementation, which works correctly
with PyTorch 2.5.0 during the tracing process.

Refs: https://github.com/ggerganov/whisper.cpp/issues/2783

* coreml: fix audio shape in whisper decoder conversion

This commit fixes the audio shape in the whisper decoder conversion
script.

The motivation for this is that the audio shape was incorrect and
was causing the conversion to fail.

* coreml : set -e in generate-coreml-interface.sh

The commit sets the -e flag in the generate-coreml-interface.sh script
to make sure the script fails if any command fails.

* coreml : update generated encoder/decoder interfaces

This commit updates the generated encoder/decoder interfaces for the
whisper model which is the result of running the
generate-coreml-interface.sh script.
2025-04-01 18:01:23 +02:00

329 lines
13 KiB
Python

import argparse
import torch
import torch.nn.functional as F
import coremltools as ct
from torch import Tensor
from torch import nn
from typing import Dict
from typing import Optional
from ane_transformers.reference.layer_norm import LayerNormANE as LayerNormANEBase
from coremltools.models.neural_network.quantization_utils import quantize_weights
from whisper.model import Whisper, AudioEncoder, TextDecoder, ResidualAttentionBlock, MultiHeadAttention, ModelDimensions
from whisper import load_model
# Disable PyTorch Scaled Dot-Product Attention (SDPA) to avoid compatibility issues.
# The Whisper implementation expects a specific behavior from
# torch.nn.functional.scaled_dot_product_attention that differs between PyTorch
# versions. Setting use_sdpa=False forces Whisper to use its manual attention
# implementation instead, which is more stable across different PyTorch versions
# (2.5.0 required by coremltools vs newer versions).
# NOTE: the flag is set on the class itself at import time, i.e. before any
# model is loaded or traced further down in this script.
import whisper.model
whisper.model.MultiHeadAttention.use_sdpa = False
# Load-state-dict pre-hook that lets nn.Linear checkpoints feed 1x1 nn.Conv2d layers
def linear_to_conv2d_map(state_dict, prefix, local_metadata, strict,
                         missing_keys, unexpected_keys, error_msgs):
    """
    Expand 2-D attention and MLP weights of ``state_dict`` to 4-D, in place.

    An entry is reshaped from (out, in) to (out, in, 1, 1) when its key either
    contains both 'attn' and '.weight', or ends with 'mlp.0.weight' /
    'mlp.2.weight', and the stored tensor is 2-D. Everything else is left
    untouched. Only ``state_dict`` is used; the remaining parameters exist to
    satisfy the pre-hook signature. Returns None.
    """
    mlp_suffixes = ('mlp.0.weight', 'mlp.2.weight')
    for name, tensor in state_dict.items():
        attention_weight = 'attn' in name and '.weight' in name
        mlp_weight = name.endswith(mlp_suffixes)
        if not (attention_weight or mlp_weight):
            continue
        # 1-D tensors (e.g. layer-norm weights under an 'attn' key) are skipped
        if len(tensor.shape) == 2:
            state_dict[name] = tensor[:, :, None, None]
def correct_for_bias_scale_order_inversion(state_dict, prefix, local_metadata,
                                           strict, missing_keys,
                                           unexpected_keys, error_msgs):
    """
    Load-state-dict pre-hook that rescales a layer's bias by its weight.

    Replaces ``state_dict[prefix + 'bias']`` with that bias divided by
    ``state_dict[prefix + 'weight']`` and returns the same ``state_dict``
    object. The other parameters are unused and only match the hook signature.
    NOTE(review): presumably needed because the ANE layer norm applies bias and
    scale in the opposite order to nn.LayerNorm -- confirm against
    ane_transformers.
    """
    bias_key = prefix + 'bias'
    weight_key = prefix + 'weight'
    rescaled_bias = state_dict[bias_key] / state_dict[weight_key]
    state_dict[bias_key] = rescaled_bias
    return state_dict
class LayerNormANE(LayerNormANEBase):
    """
    LayerNormANEBase variant that fixes up the bias while loading a state dict.

    All constructor arguments are forwarded unchanged to the base class.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Rescale the incoming bias before it is copied into this module
        bias_fixup_hook = correct_for_bias_scale_order_inversion
        self._register_load_state_dict_pre_hook(bias_fixup_hook)
class MultiHeadAttentionANE(MultiHeadAttention):
    """
    Multi-head attention whose projections are 1x1 convolutions.

    The query/key/value/out layers created by the base class are replaced by
    nn.Conv2d layers; the key projection has no bias.
    """

    def __init__(self, n_state: int, n_head: int):
        super().__init__(n_state, n_head)
        self.query = nn.Conv2d(n_state, n_state, kernel_size=1)
        self.key = nn.Conv2d(n_state, n_state, kernel_size=1, bias=False)
        self.value = nn.Conv2d(n_state, n_state, kernel_size=1)
        self.out = nn.Conv2d(n_state, n_state, kernel_size=1)

    def forward(self,
                x: Tensor,
                xa: Optional[Tensor] = None,
                mask: Optional[Tensor] = None,
                kv_cache: Optional[dict] = None):
        """
        Project the inputs and run attention.

        Returns a tuple ``(self.out(attended), qk)`` where ``qk`` is the
        second value returned by ``qkv_attention_ane``.
        """
        q = self.query(x)

        reuse_cached = (
            kv_cache is not None
            and xa is not None
            and self.key in kv_cache
        )
        if reuse_cached:
            # cross-attention: keys and values were computed once already, reuse them.
            k = kv_cache[self.key]
            v = kv_cache[self.value]
        else:
            # hooks, if installed (i.e. kv_cache is not None), will prepend the cached
            # kv tensors; otherwise project for self- or cross-attention as usual.
            source = x if xa is None else xa
            k = self.key(source)
            v = self.value(source)

        attended, qk = self.qkv_attention_ane(q, k, v, mask)
        return self.out(attended), qk

    def qkv_attention_ane(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
        """
        Compute attention head by head on 4-D tensors.

        ``q`` is unpacked as (_, dim, _, seqlen); channels (axis 1) are split
        into chunks of ``dim // self.n_head``. Returns the concatenated
        per-head outputs and the concatenated pre-softmax scores (cast to
        float and detached).
        """
        _, dim, _, seqlen = q.size()

        head_dim = dim // self.n_head
        q = q * float(head_dim)**-0.5

        q_heads = q.split(head_dim, dim=1)
        k_heads = k.transpose(1, 3).split(head_dim, dim=3)
        v_heads = v.split(head_dim, dim=1)

        # (batch_size, max_seq_length, 1, max_seq_length) * n_heads
        scores = [
            torch.einsum('bchq,bkhc->bkhq', q_h, k_h)
            for q_h, k_h in zip(q_heads, k_heads)
        ]

        if mask is not None:
            mask_window = mask[:, :seqlen, :, :seqlen]
            for head_idx in range(self.n_head):
                scores[head_idx] = scores[head_idx] + mask_window

        # (batch_size, max_seq_length, 1, max_seq_length) * n_heads
        weights = [head_scores.softmax(dim=1) for head_scores in scores]
        # (batch_size, dim_per_head, 1, max_seq_length) * n_heads
        per_head = [
            torch.einsum('bkhq,bchk->bchq', w_h, v_h)
            for w_h, v_h in zip(weights, v_heads)
        ]

        # (batch_size, dim, 1, max_seq_length)
        attn = torch.cat(per_head, dim=1)
        qk = torch.cat(scores, dim=1).float().detach()
        return attn, qk
class ResidualAttentionBlockANE(ResidualAttentionBlock):
    """
    Residual attention block rebuilt from ANE-style sub-modules.

    After the base constructor runs, the attention, layer-norm and MLP members
    are replaced by MultiHeadAttentionANE, LayerNormANE and 1x1 Conv2d layers.
    """

    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        super().__init__(n_state, n_head, cross_attention)

        self.attn = MultiHeadAttentionANE(n_state, n_head)
        self.attn_ln = LayerNormANE(n_state)

        # Cross-attention members exist only when requested; otherwise None
        if cross_attention:
            self.cross_attn = MultiHeadAttentionANE(n_state, n_head)
            self.cross_attn_ln = LayerNormANE(n_state)
        else:
            self.cross_attn = None
            self.cross_attn_ln = None

        hidden_width = n_state * 4
        expand = nn.Conv2d(n_state, hidden_width, kernel_size=1)
        activation = nn.GELU()
        contract = nn.Conv2d(hidden_width, n_state, kernel_size=1)
        self.mlp = nn.Sequential(expand, activation, contract)
        self.mlp_ln = LayerNormANE(n_state)
class AudioEncoderANE(AudioEncoder):
    """
    Audio encoder whose blocks and final layer norm are the ANE variants.
    """

    def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
        super().__init__(n_mels, n_ctx, n_state, n_head, n_layer)

        ane_blocks = [ResidualAttentionBlockANE(n_state, n_head) for _ in range(n_layer)]
        self.blocks = nn.ModuleList(ane_blocks)
        self.ln_post = LayerNormANE(n_state)

    def forward(self, x: Tensor):
        """
        x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
            the mel spectrogram of the audio

        Returns the encoded features with the dummy axis removed and the last
        two axes swapped.
        """
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))

        assert x.shape[1:] == self.positional_embedding.shape[::-1], "incorrect audio shape"

        # Add positional embedding and add dummy dim for ANE
        positions = self.positional_embedding.transpose(0, 1)
        x = (x + positions).to(x.dtype).unsqueeze(2)

        for block in self.blocks:
            x = block(x)

        x = self.ln_post(x)
        return x.squeeze(2).transpose(1, 2)
class TextDecoderANE(TextDecoder):
    """
    Text decoder whose blocks and final layer norm are the ANE variants.

    The output projection against the token embedding is computed in chunks
    (see ``forward``).
    """

    def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
        super().__init__(n_vocab, n_ctx, n_state, n_head, n_layer)

        self.blocks = nn.ModuleList(
            [ResidualAttentionBlockANE(n_state, n_head, cross_attention=True) for _ in range(n_layer)]
        )
        self.ln = LayerNormANE(n_state)

    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
        """
        x : torch.LongTensor, shape = (batch_size, <= n_ctx)
            the text tokens
        xa : torch.Tensor
            the encoded audio features to be attended on
            (NOTE(review): expected layout is whatever the ANE cross-attention
            blocks accept -- confirm against the caller)
        kv_cache : optional dict of cached key/value tensors; when non-empty,
            axis 3 of its first value gives the positional offset.

        Returns the logits over the vocabulary, with the vocabulary on the
        last axis.
        """
        offset = next(iter(kv_cache.values())).shape[3] if kv_cache else 0
        x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
        x = x.to(xa.dtype)

        # Reformat for ANE
        mask = self.mask[None, None, :, :].permute(0, 3, 1, 2)
        x = x.transpose(1, 2).unsqueeze(2)

        for block in self.blocks:
            x = block(x, xa, mask=mask, kv_cache=kv_cache)

        x = self.ln(x)

        # Reformat back from ANE
        x = x.permute(0, 2, 3, 1).squeeze(0)

        # ANE can only load tensors with dim size of at most 16,384 - whisper uses
        # 51,864 (en) or 51,865 (multi-lang) tokens so we need to compute in chunks
        embedding = self.token_embedding.weight
        vocab_size = embedding.shape[0]
        if vocab_size >= 51865:
            # split in 11 chunks - 4715 each (plus a small remainder chunk for
            # larger vocabularies such as 51,866)
            n_chunks = 11
        else:
            # split in 12 chunks - 4322 each
            assert(vocab_size == 51864)
            n_chunks = 12
        splits = embedding.split(vocab_size // n_chunks, dim=0)

        # Concatenate along the vocabulary axis. Stacking along dim 0 and
        # reshaping afterwards only works for a single token position with
        # equally sized chunks: it raises when the last chunk is smaller and
        # interleaves values when more than one position is present.
        logits = torch.cat(
            [torch.einsum('bid,jd->bij', x, split) for split in splits],
            dim=-1,
        )

        return logits
class WhisperANE(Whisper):
    """
    Whisper model whose encoder and decoder are replaced by the ANE variants.

    A load-state-dict pre-hook (``linear_to_conv2d_map``) is registered so that
    weights stored for nn.Linear layers can be loaded into the Conv2d layers.
    """

    def __init__(self, dims: ModelDimensions):
        super().__init__(dims)

        d = self.dims
        self.encoder = AudioEncoderANE(
            d.n_mels,
            d.n_audio_ctx,
            d.n_audio_state,
            d.n_audio_head,
            d.n_audio_layer,
        )
        self.decoder = TextDecoderANE(
            d.n_vocab,
            d.n_text_ctx,
            d.n_text_state,
            d.n_text_head,
            d.n_text_layer,
        )

        self._register_load_state_dict_pre_hook(linear_to_conv2d_map)

    def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Encode ``mel`` and decode ``tokens`` against the encoded audio."""
        audio_features = self.encoder(mel)
        return self.decoder(tokens, audio_features)

    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
        """
        Attach forward hooks that cache key/value projections of the decoder.

        ``cache`` is shallow-copied (or created empty). Returns the cache dict
        and the list of hook handles registered on the key and value layers of
        every MultiHeadAttentionANE module in the decoder.
        """
        cache = {} if cache is None else dict(cache)
        hooks = []

        def save_to_cache(module, _, output):
            if module in cache and output.shape[3] <= self.decoder.positional_embedding.shape[0]:
                # later tokens: append along axis 3
                cache[module] = torch.cat([cache[module], output], dim=3).detach()
            else:
                cache[module] = output  # save as-is, for the first token or cross attention
            return cache[module]

        def install_hooks(layer: nn.Module):
            if not isinstance(layer, MultiHeadAttentionANE):
                return
            for projection in (layer.key, layer.value):
                hooks.append(projection.register_forward_hook(save_to_cache))

        self.decoder.apply(install_hooks)
        return cache, hooks
def convert_encoder(hparams, model, quantize=False):
    """
    Trace the encoder with a random mel input and convert it with coremltools.

    hparams  : object providing ``n_mels``
    model    : encoder module; switched to eval mode before tracing
    quantize : when true, the converted model is passed through
               ``quantize_weights(..., nbits=16)``

    Returns the converted (and optionally quantized) model.
    """
    model.eval()

    mel_shape = (1, hparams.n_mels, 3000)
    example_mel = torch.randn(mel_shape)
    traced = torch.jit.trace(model, example_mel)

    mel_input = ct.TensorType(name="logmel_data", shape=mel_shape)
    encoded_output = ct.TensorType(name="output")
    converted = ct.convert(
        traced,
        convert_to="neuralnetwork",
        inputs=[mel_input],
        outputs=[encoded_output],
        compute_units=ct.ComputeUnit.ALL
    )

    if not quantize:
        return converted
    return quantize_weights(converted, nbits=16)
def convert_decoder(hparams, model, quantize=False):
    """
    Trace the decoder with random inputs and convert it with coremltools.

    hparams  : object providing ``n_audio_ctx``, ``n_audio_state`` and ``n_vocab``
    model    : decoder module; switched to eval mode before tracing
    quantize : when true, the converted model is passed through
               ``quantize_weights(..., nbits=16)``

    Returns the converted (and optionally quantized) model.
    """
    model.eval()

    tokens_shape = (1, 1)
    audio_shape = (1, hparams.n_audio_ctx, hparams.n_audio_state)

    # Example inputs for tracing: random audio features and one random token id
    example_audio = torch.randn(audio_shape)
    example_tokens = torch.randint(hparams.n_vocab, tokens_shape).long()
    traced = torch.jit.trace(model, (example_tokens, example_audio))

    token_input = ct.TensorType(name="token_data", shape=tokens_shape, dtype=int)
    audio_input = ct.TensorType(name="audio_data", shape=audio_shape)
    converted = ct.convert(
        traced,
        convert_to="neuralnetwork",
        inputs=[token_input, audio_input]
    )

    if not quantize:
        return converted
    return quantize_weights(converted, nbits=16)
def str2bool(value):
    """
    Parse a command-line boolean such as "True", "false", "1" or "no".

    argparse's ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value -- including "False" -- would be treated as True. This
    parser maps the usual spellings explicitly; an empty string is False.

    Raises argparse.ArgumentTypeError for unrecognized values.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, help="model to convert (e.g. tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large-v1, large-v2, large-v3, large-v3-turbo)", required=True)
    parser.add_argument("--encoder-only", type=str2bool, help="only convert encoder", default=False)
    parser.add_argument("--quantize", type=str2bool, help="quantize weights to F16", default=False)
    parser.add_argument("--optimize-ane", type=str2bool, help="optimize for ANE execution (currently broken)", default=False)
    args = parser.parse_args()

    if args.model not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "small.en-tdrz", "medium", "medium.en", "large-v1", "large-v2", "large-v3", "large-v3-turbo"]:
        raise ValueError("Invalid model name")

    # Named whisper_model so the imported `whisper` package is not shadowed
    whisper_model = load_model(args.model).cpu()
    hparams = whisper_model.dims
    print(hparams)

    if args.optimize_ane:
        # Rebuild the model from ANE modules and copy the loaded weights over
        whisperANE = WhisperANE(hparams).eval()
        whisperANE.load_state_dict(whisper_model.state_dict())

        encoder = whisperANE.encoder
        decoder = whisperANE.decoder
    else:
        encoder = whisper_model.encoder
        decoder = whisper_model.decoder

    # Convert encoder
    encoder = convert_encoder(hparams, encoder, quantize=args.quantize)
    encoder.save(f"models/coreml-encoder-{args.model}.mlpackage")

    if not args.encoder_only:
        # Convert decoder
        decoder = convert_decoder(hparams, decoder, quantize=args.quantize)
        decoder.save(f"models/coreml-decoder-{args.model}.mlpackage")

    print("done converting")