mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-04-24 19:29:18 +02:00
minor : improve C++ and Python style (#768)
* use some STL functions * use self.field than setattr, use pathlib.Path * recover some format * const some iter * Keep the original * 2 space
This commit is contained in:
parent
4d89ee2e59
commit
94aa56f19e
@ -23,6 +23,7 @@ import json
|
|||||||
import code
|
import code
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from transformers import WhisperForConditionalGeneration
|
from transformers import WhisperForConditionalGeneration
|
||||||
|
|
||||||
@ -75,16 +76,13 @@ if len(sys.argv) < 4:
|
|||||||
print("Usage: convert-h5-to-ggml.py dir_model path-to-whisper-repo dir-output [use-f32]\n")
|
print("Usage: convert-h5-to-ggml.py dir_model path-to-whisper-repo dir-output [use-f32]\n")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
dir_model = sys.argv[1]
|
dir_model = Path(sys.argv[1])
|
||||||
dir_whisper = sys.argv[2]
|
dir_whisper = Path(sys.argv[2])
|
||||||
dir_out = sys.argv[3]
|
dir_out = Path(sys.argv[3])
|
||||||
|
|
||||||
with open(dir_model + "/vocab.json", "r", encoding="utf8") as f:
|
encoder = json.load((dir_model / "vocab.json").open("r", encoding="utf8"))
|
||||||
encoder = json.load(f)
|
encoder_added = json.load((dir_model / "added_tokens.json").open( "r", encoding="utf8"))
|
||||||
with open(dir_model + "/added_tokens.json", "r", encoding="utf8") as f:
|
hparams = json.load((dir_model / "config.json").open("r", encoding="utf8") )
|
||||||
encoder_added = json.load(f)
|
|
||||||
with open(dir_model + "/config.json", "r", encoding="utf8") as f:
|
|
||||||
hparams = json.load(f)
|
|
||||||
|
|
||||||
model = WhisperForConditionalGeneration.from_pretrained(dir_model)
|
model = WhisperForConditionalGeneration.from_pretrained(dir_model)
|
||||||
|
|
||||||
@ -96,16 +94,15 @@ with np.load(os.path.join(dir_whisper, "whisper/assets", "mel_filters.npz")) as
|
|||||||
|
|
||||||
dir_tokenizer = dir_model
|
dir_tokenizer = dir_model
|
||||||
|
|
||||||
fname_out = dir_out + "/ggml-model.bin"
|
fname_out = dir_out / "ggml-model.bin"
|
||||||
|
|
||||||
with open(dir_tokenizer + "/vocab.json", "r", encoding="utf8") as f:
|
tokens = json.load(open(dir_tokenizer / "vocab.json", "r", encoding="utf8"))
|
||||||
tokens = json.load(f)
|
|
||||||
|
|
||||||
# use 16-bit or 32-bit floats
|
# use 16-bit or 32-bit floats
|
||||||
use_f16 = True
|
use_f16 = True
|
||||||
if len(sys.argv) > 4:
|
if len(sys.argv) > 4:
|
||||||
use_f16 = False
|
use_f16 = False
|
||||||
fname_out = dir_out + "/ggml-model-f32.bin"
|
fname_out = dir_out / "ggml-model-f32.bin"
|
||||||
|
|
||||||
fout = open(fname_out, "wb")
|
fout = open(fname_out, "wb")
|
||||||
|
|
||||||
@ -171,10 +168,9 @@ for name in list_vars.keys():
|
|||||||
data = data.astype(np.float16)
|
data = data.astype(np.float16)
|
||||||
|
|
||||||
# reshape conv bias from [n] to [n, 1]
|
# reshape conv bias from [n] to [n, 1]
|
||||||
if name == "encoder.conv1.bias" or \
|
if name in ["encoder.conv1.bias", "encoder.conv2.bias"]:
|
||||||
name == "encoder.conv2.bias":
|
|
||||||
data = data.reshape(data.shape[0], 1)
|
data = data.reshape(data.shape[0], 1)
|
||||||
print(" Reshaped variable: " + name + " to shape: ", data.shape)
|
print(" Reshaped variable: " , name , " to shape: ", data.shape)
|
||||||
|
|
||||||
n_dims = len(data.shape)
|
n_dims = len(data.shape)
|
||||||
print(name, n_dims, data.shape)
|
print(name, n_dims, data.shape)
|
||||||
@ -182,7 +178,7 @@ for name in list_vars.keys():
|
|||||||
# looks like the whisper models are in f16 by default
|
# looks like the whisper models are in f16 by default
|
||||||
# so we need to convert the small tensors to f32 until we fully support f16 in ggml
|
# so we need to convert the small tensors to f32 until we fully support f16 in ggml
|
||||||
# ftype == 0 -> float32, ftype == 1 -> float16
|
# ftype == 0 -> float32, ftype == 1 -> float16
|
||||||
ftype = 1;
|
ftype = 1
|
||||||
if use_f16:
|
if use_f16:
|
||||||
if n_dims < 2 or \
|
if n_dims < 2 or \
|
||||||
name == "encoder.conv1.bias" or \
|
name == "encoder.conv1.bias" or \
|
||||||
@ -197,16 +193,16 @@ for name in list_vars.keys():
|
|||||||
ftype = 0
|
ftype = 0
|
||||||
|
|
||||||
# header
|
# header
|
||||||
str = name.encode('utf-8')
|
str_ = name.encode('utf-8')
|
||||||
fout.write(struct.pack("iii", n_dims, len(str), ftype))
|
fout.write(struct.pack("iii", n_dims, len(str_), ftype))
|
||||||
for i in range(n_dims):
|
for i in range(n_dims):
|
||||||
fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
|
fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
|
||||||
fout.write(str);
|
fout.write(str_)
|
||||||
|
|
||||||
# data
|
# data
|
||||||
data.tofile(fout)
|
data.tofile(fout)
|
||||||
|
|
||||||
fout.close()
|
fout.close()
|
||||||
|
|
||||||
print("Done. Output file: " + fname_out)
|
print("Done. Output file: " , fname_out)
|
||||||
print("")
|
print("")
|
||||||
|
@ -40,7 +40,7 @@ import code
|
|||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import base64
|
import base64
|
||||||
|
from pathlib import Path
|
||||||
#from transformers import GPTJForCausalLM
|
#from transformers import GPTJForCausalLM
|
||||||
#from transformers import GPT2TokenizerFast
|
#from transformers import GPT2TokenizerFast
|
||||||
|
|
||||||
@ -194,17 +194,17 @@ if len(sys.argv) < 4:
|
|||||||
print("Usage: convert-pt-to-ggml.py model.pt path-to-whisper-repo dir-output [use-f32]\n")
|
print("Usage: convert-pt-to-ggml.py model.pt path-to-whisper-repo dir-output [use-f32]\n")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
fname_inp = sys.argv[1]
|
fname_inp = Path(sys.argv[1])
|
||||||
dir_whisper = sys.argv[2]
|
dir_whisper = Path(sys.argv[2])
|
||||||
dir_out = sys.argv[3]
|
dir_out = Path(sys.argv[3])
|
||||||
|
|
||||||
# try to load PyTorch binary data
|
# try to load PyTorch binary data
|
||||||
try:
|
try:
|
||||||
model_bytes = open(fname_inp, "rb").read()
|
model_bytes = open(fname_inp, "rb").read()
|
||||||
with io.BytesIO(model_bytes) as fp:
|
with io.BytesIO(model_bytes) as fp:
|
||||||
checkpoint = torch.load(fp, map_location="cpu")
|
checkpoint = torch.load(fp, map_location="cpu")
|
||||||
except:
|
except Exception:
|
||||||
print("Error: failed to load PyTorch model file: %s" % fname_inp)
|
print("Error: failed to load PyTorch model file:" , fname_inp)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
hparams = checkpoint["dims"]
|
hparams = checkpoint["dims"]
|
||||||
@ -218,17 +218,17 @@ list_vars = checkpoint["model_state_dict"]
|
|||||||
|
|
||||||
# load mel filters
|
# load mel filters
|
||||||
n_mels = hparams["n_mels"]
|
n_mels = hparams["n_mels"]
|
||||||
with np.load(os.path.join(dir_whisper, "whisper/assets", "mel_filters.npz")) as f:
|
with np.load(dir_whisper / "whisper" / "assets" / "mel_filters.npz") as f:
|
||||||
filters = torch.from_numpy(f[f"mel_{n_mels}"])
|
filters = torch.from_numpy(f[f"mel_{n_mels}"])
|
||||||
#print (filters)
|
#print (filters)
|
||||||
|
|
||||||
#code.interact(local=locals())
|
#code.interact(local=locals())
|
||||||
|
|
||||||
multilingual = hparams["n_vocab"] == 51865
|
multilingual = hparams["n_vocab"] == 51865
|
||||||
tokenizer = os.path.join(dir_whisper, "whisper/assets", multilingual and "multilingual.tiktoken" or "gpt2.tiktoken")
|
tokenizer = dir_whisper / "whisper" / "assets" / (multilingual and "multilingual.tiktoken" or "gpt2.tiktoken")
|
||||||
|
|
||||||
# output in the same directory as the model
|
# output in the same directory as the model
|
||||||
fname_out = dir_out + "/ggml-model.bin"
|
fname_out = dir_out / "ggml-model.bin"
|
||||||
|
|
||||||
with open(tokenizer, "rb") as f:
|
with open(tokenizer, "rb") as f:
|
||||||
contents = f.read()
|
contents = f.read()
|
||||||
@ -238,9 +238,9 @@ with open(tokenizer, "rb") as f:
|
|||||||
use_f16 = True
|
use_f16 = True
|
||||||
if len(sys.argv) > 4:
|
if len(sys.argv) > 4:
|
||||||
use_f16 = False
|
use_f16 = False
|
||||||
fname_out = dir_out + "/ggml-model-f32.bin"
|
fname_out = dir_out / "ggml-model-f32.bin"
|
||||||
|
|
||||||
fout = open(fname_out, "wb")
|
fout = fname_out.open("wb")
|
||||||
|
|
||||||
fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
|
fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
|
||||||
fout.write(struct.pack("i", hparams["n_vocab"]))
|
fout.write(struct.pack("i", hparams["n_vocab"]))
|
||||||
@ -273,20 +273,19 @@ for key in tokens:
|
|||||||
|
|
||||||
for name in list_vars.keys():
|
for name in list_vars.keys():
|
||||||
data = list_vars[name].squeeze().numpy()
|
data = list_vars[name].squeeze().numpy()
|
||||||
print("Processing variable: " + name + " with shape: ", data.shape)
|
print("Processing variable: " , name , " with shape: ", data.shape)
|
||||||
|
|
||||||
# reshape conv bias from [n] to [n, 1]
|
# reshape conv bias from [n] to [n, 1]
|
||||||
if name == "encoder.conv1.bias" or \
|
if name in ["encoder.conv1.bias", "encoder.conv2.bias"]:
|
||||||
name == "encoder.conv2.bias":
|
|
||||||
data = data.reshape(data.shape[0], 1)
|
data = data.reshape(data.shape[0], 1)
|
||||||
print(" Reshaped variable: " + name + " to shape: ", data.shape)
|
print(f" Reshaped variable: {name} to shape: ", data.shape)
|
||||||
|
|
||||||
n_dims = len(data.shape);
|
n_dims = len(data.shape)
|
||||||
|
|
||||||
# looks like the whisper models are in f16 by default
|
# looks like the whisper models are in f16 by default
|
||||||
# so we need to convert the small tensors to f32 until we fully support f16 in ggml
|
# so we need to convert the small tensors to f32 until we fully support f16 in ggml
|
||||||
# ftype == 0 -> float32, ftype == 1 -> float16
|
# ftype == 0 -> float32, ftype == 1 -> float16
|
||||||
ftype = 1;
|
ftype = 1
|
||||||
if use_f16:
|
if use_f16:
|
||||||
if n_dims < 2 or \
|
if n_dims < 2 or \
|
||||||
name == "encoder.conv1.bias" or \
|
name == "encoder.conv1.bias" or \
|
||||||
@ -307,16 +306,16 @@ for name in list_vars.keys():
|
|||||||
# data = data.transpose()
|
# data = data.transpose()
|
||||||
|
|
||||||
# header
|
# header
|
||||||
str = name.encode('utf-8')
|
str_ = name.encode('utf-8')
|
||||||
fout.write(struct.pack("iii", n_dims, len(str), ftype))
|
fout.write(struct.pack("iii", n_dims, len(str_), ftype))
|
||||||
for i in range(n_dims):
|
for i in range(n_dims):
|
||||||
fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
|
fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
|
||||||
fout.write(str);
|
fout.write(str_)
|
||||||
|
|
||||||
# data
|
# data
|
||||||
data.tofile(fout)
|
data.tofile(fout)
|
||||||
|
|
||||||
fout.close()
|
fout.close()
|
||||||
|
|
||||||
print("Done. Output file: " + fname_out)
|
print("Done. Output file: " , fname_out)
|
||||||
print("")
|
print("")
|
||||||
|
@ -20,7 +20,7 @@ def linear_to_conv2d_map(state_dict, prefix, local_metadata, strict,
|
|||||||
"""
|
"""
|
||||||
for k in state_dict:
|
for k in state_dict:
|
||||||
is_attention = all(substr in k for substr in ['attn', '.weight'])
|
is_attention = all(substr in k for substr in ['attn', '.weight'])
|
||||||
is_mlp = any([k.endswith(s) for s in ['mlp.0.weight', 'mlp.2.weight']])
|
is_mlp = any(k.endswith(s) for s in ['mlp.0.weight', 'mlp.2.weight'])
|
||||||
|
|
||||||
if (is_attention or is_mlp) and len(state_dict[k].shape) == 2:
|
if (is_attention or is_mlp) and len(state_dict[k].shape) == 2:
|
||||||
state_dict[k] = state_dict[k][:, :, None, None]
|
state_dict[k] = state_dict[k][:, :, None, None]
|
||||||
@ -42,11 +42,10 @@ class LayerNormANE(LayerNormANEBase):
|
|||||||
class MultiHeadAttentionANE(MultiHeadAttention):
|
class MultiHeadAttentionANE(MultiHeadAttention):
|
||||||
def __init__(self, n_state: int, n_head: int):
|
def __init__(self, n_state: int, n_head: int):
|
||||||
super().__init__(n_state, n_head)
|
super().__init__(n_state, n_head)
|
||||||
|
self.query = nn.Conv2d(n_state, n_state, kernel_size=1)
|
||||||
setattr(self, 'query', nn.Conv2d(n_state, n_state, kernel_size=1))
|
self.key = nn.Conv2d(n_state, n_state, kernel_size=1, bias=False)
|
||||||
setattr(self, 'key', nn.Conv2d(n_state, n_state, kernel_size=1, bias=False))
|
self.value = nn.Conv2d(n_state, n_state, kernel_size=1)
|
||||||
setattr(self, 'value', nn.Conv2d(n_state, n_state, kernel_size=1))
|
self.out = nn.Conv2d(n_state, n_state, kernel_size=1)
|
||||||
setattr(self, 'out', nn.Conv2d(n_state, n_state, kernel_size=1))
|
|
||||||
|
|
||||||
def forward(self,
|
def forward(self,
|
||||||
x: Tensor,
|
x: Tensor,
|
||||||
@ -104,30 +103,28 @@ class MultiHeadAttentionANE(MultiHeadAttention):
|
|||||||
class ResidualAttentionBlockANE(ResidualAttentionBlock):
|
class ResidualAttentionBlockANE(ResidualAttentionBlock):
|
||||||
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
|
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
|
||||||
super().__init__(n_state, n_head, cross_attention)
|
super().__init__(n_state, n_head, cross_attention)
|
||||||
|
self.attn = MultiHeadAttentionANE(n_state, n_head)
|
||||||
setattr(self, 'attn', MultiHeadAttentionANE(n_state, n_head))
|
self.attn_ln = LayerNormANE(n_state)
|
||||||
setattr(self, 'attn_ln', LayerNormANE(n_state))
|
self.cross_attn = MultiHeadAttentionANE(n_state, n_head) if cross_attention else None
|
||||||
|
self.cross_attn_ln = LayerNormANE(n_state) if cross_attention else None
|
||||||
setattr(self, 'cross_attn', MultiHeadAttentionANE(n_state, n_head) if cross_attention else None)
|
|
||||||
setattr(self, 'cross_attn_ln', LayerNormANE(n_state) if cross_attention else None)
|
|
||||||
|
|
||||||
n_mlp = n_state * 4
|
n_mlp = n_state * 4
|
||||||
setattr(self, 'mlp', nn.Sequential(
|
self.mlp = nn.Sequential(
|
||||||
nn.Conv2d(n_state, n_mlp, kernel_size=1),
|
nn.Conv2d(n_state, n_mlp, kernel_size=1),
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
nn.Conv2d(n_mlp, n_state, kernel_size=1)
|
nn.Conv2d(n_mlp, n_state, kernel_size=1)
|
||||||
))
|
)
|
||||||
setattr(self, 'mlp_ln', LayerNormANE(n_state))
|
self.mlp_ln = LayerNormANE(n_state)
|
||||||
|
|
||||||
|
|
||||||
class AudioEncoderANE(AudioEncoder):
|
class AudioEncoderANE(AudioEncoder):
|
||||||
def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
|
def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
|
||||||
super().__init__(n_mels, n_ctx, n_state, n_head, n_layer)
|
super().__init__(n_mels, n_ctx, n_state, n_head, n_layer)
|
||||||
|
|
||||||
setattr(self, 'blocks', nn.ModuleList(
|
self.blocks = nn.ModuleList(
|
||||||
[ResidualAttentionBlockANE(n_state, n_head) for _ in range(n_layer)]
|
[ResidualAttentionBlockANE(n_state, n_head) for _ in range(n_layer)]
|
||||||
))
|
)
|
||||||
setattr(self, 'ln_post', LayerNormANE(n_state))
|
self.ln_post = LayerNormANE(n_state)
|
||||||
|
|
||||||
def forward(self, x: Tensor):
|
def forward(self, x: Tensor):
|
||||||
"""
|
"""
|
||||||
@ -168,10 +165,10 @@ class TextDecoderANE(TextDecoder):
|
|||||||
def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
|
def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
|
||||||
super().__init__(n_vocab, n_ctx, n_state, n_head, n_layer)
|
super().__init__(n_vocab, n_ctx, n_state, n_head, n_layer)
|
||||||
|
|
||||||
setattr(self, 'blocks', nn.ModuleList(
|
self.blocks= nn.ModuleList(
|
||||||
[ResidualAttentionBlockANE(n_state, n_head, cross_attention=True) for _ in range(n_layer)]
|
[ResidualAttentionBlockANE(n_state, n_head, cross_attention=True) for _ in range(n_layer)]
|
||||||
))
|
)
|
||||||
setattr(self, 'ln', LayerNormANE(n_state))
|
self.ln= LayerNormANE(n_state)
|
||||||
|
|
||||||
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
|
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
|
||||||
"""
|
"""
|
||||||
@ -213,20 +210,20 @@ class WhisperANE(Whisper):
|
|||||||
def __init__(self, dims: ModelDimensions):
|
def __init__(self, dims: ModelDimensions):
|
||||||
super().__init__(dims)
|
super().__init__(dims)
|
||||||
|
|
||||||
setattr(self, 'encoder', AudioEncoderANE(
|
self.encoder = AudioEncoderANE(
|
||||||
self.dims.n_mels,
|
self.dims.n_mels,
|
||||||
self.dims.n_audio_ctx,
|
self.dims.n_audio_ctx,
|
||||||
self.dims.n_audio_state,
|
self.dims.n_audio_state,
|
||||||
self.dims.n_audio_head,
|
self.dims.n_audio_head,
|
||||||
self.dims.n_audio_layer,
|
self.dims.n_audio_layer,
|
||||||
))
|
)
|
||||||
setattr(self, 'decoder', TextDecoderANE(
|
self.decoder = TextDecoderANE(
|
||||||
self.dims.n_vocab,
|
self.dims.n_vocab,
|
||||||
self.dims.n_text_ctx,
|
self.dims.n_text_ctx,
|
||||||
self.dims.n_text_state,
|
self.dims.n_text_state,
|
||||||
self.dims.n_text_head,
|
self.dims.n_text_head,
|
||||||
self.dims.n_text_layer,
|
self.dims.n_text_layer,
|
||||||
))
|
)
|
||||||
|
|
||||||
self._register_load_state_dict_pre_hook(linear_to_conv2d_map)
|
self._register_load_state_dict_pre_hook(linear_to_conv2d_map)
|
||||||
|
|
||||||
|
79
whisper.cpp
79
whisper.cpp
@ -2356,11 +2356,7 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
|
|||||||
sum += fft_out[k] * filters.data[j * n_fft + k];
|
sum += fft_out[k] * filters.data[j * n_fft + k];
|
||||||
}
|
}
|
||||||
|
|
||||||
if (sum < 1e-10) {
|
sum = log10(std::max(sum, 1e-10));
|
||||||
sum = 1e-10;
|
|
||||||
}
|
|
||||||
|
|
||||||
sum = log10(sum);
|
|
||||||
|
|
||||||
mel.data[j * mel.n_len + i] = sum;
|
mel.data[j * mel.n_len + i] = sum;
|
||||||
}
|
}
|
||||||
@ -2602,7 +2598,6 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
struct whisper_context * whisper_init_from_file_no_state(const char * path_model) {
|
struct whisper_context * whisper_init_from_file_no_state(const char * path_model) {
|
||||||
whisper_model_loader loader = {};
|
|
||||||
|
|
||||||
fprintf(stderr, "%s: loading model from '%s'\n", __func__, path_model);
|
fprintf(stderr, "%s: loading model from '%s'\n", __func__, path_model);
|
||||||
|
|
||||||
@ -2612,22 +2607,27 @@ struct whisper_context * whisper_init_from_file_no_state(const char * path_model
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
loader.context = &fin;
|
whisper_model_loader loader = {
|
||||||
|
.context = &fin,
|
||||||
|
|
||||||
loader.read = [](void * ctx, void * output, size_t read_size) {
|
.read =
|
||||||
std::ifstream * fin = (std::ifstream*)ctx;
|
[](void *ctx, void *output, size_t read_size) {
|
||||||
fin->read((char *)output, read_size);
|
std::ifstream *fin = (std::ifstream *)ctx;
|
||||||
return read_size;
|
fin->read((char *)output, read_size);
|
||||||
};
|
return read_size;
|
||||||
|
},
|
||||||
|
|
||||||
loader.eof = [](void * ctx) {
|
.eof =
|
||||||
std::ifstream * fin = (std::ifstream*)ctx;
|
[](void *ctx) {
|
||||||
return fin->eof();
|
std::ifstream *fin = (std::ifstream *)ctx;
|
||||||
};
|
return fin->eof();
|
||||||
|
},
|
||||||
|
|
||||||
loader.close = [](void * ctx) {
|
.close =
|
||||||
std::ifstream * fin = (std::ifstream*)ctx;
|
[](void *ctx) {
|
||||||
fin->close();
|
std::ifstream *fin = (std::ifstream *)ctx;
|
||||||
|
fin->close();
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
auto ctx = whisper_init_no_state(&loader);
|
auto ctx = whisper_init_no_state(&loader);
|
||||||
@ -2647,30 +2647,34 @@ struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t
|
|||||||
};
|
};
|
||||||
|
|
||||||
buf_context ctx = { reinterpret_cast<uint8_t*>(buffer), buffer_size, 0 };
|
buf_context ctx = { reinterpret_cast<uint8_t*>(buffer), buffer_size, 0 };
|
||||||
whisper_model_loader loader = {};
|
|
||||||
|
|
||||||
fprintf(stderr, "%s: loading model from buffer\n", __func__);
|
fprintf(stderr, "%s: loading model from buffer\n", __func__);
|
||||||
|
|
||||||
loader.context = &ctx;
|
whisper_model_loader loader = {
|
||||||
|
.context = &ctx,
|
||||||
|
|
||||||
loader.read = [](void * ctx, void * output, size_t read_size) {
|
.read =
|
||||||
buf_context * buf = reinterpret_cast<buf_context *>(ctx);
|
[](void *ctx, void *output, size_t read_size) {
|
||||||
|
buf_context *buf = reinterpret_cast<buf_context *>(ctx);
|
||||||
|
|
||||||
size_t size_to_copy = buf->current_offset + read_size < buf->size ? read_size : buf->size - buf->current_offset;
|
size_t size_to_copy = buf->current_offset + read_size < buf->size
|
||||||
|
? read_size
|
||||||
|
: buf->size - buf->current_offset;
|
||||||
|
|
||||||
memcpy(output, buf->buffer + buf->current_offset, size_to_copy);
|
memcpy(output, buf->buffer + buf->current_offset, size_to_copy);
|
||||||
buf->current_offset += size_to_copy;
|
buf->current_offset += size_to_copy;
|
||||||
|
|
||||||
return size_to_copy;
|
return size_to_copy;
|
||||||
};
|
},
|
||||||
|
|
||||||
loader.eof = [](void * ctx) {
|
.eof =
|
||||||
buf_context * buf = reinterpret_cast<buf_context *>(ctx);
|
[](void *ctx) {
|
||||||
|
buf_context *buf = reinterpret_cast<buf_context *>(ctx);
|
||||||
|
|
||||||
return buf->current_offset >= buf->size;
|
return buf->current_offset >= buf->size;
|
||||||
};
|
},
|
||||||
|
|
||||||
loader.close = [](void * /*ctx*/) { };
|
.close = [](void * /*ctx*/) {}};
|
||||||
|
|
||||||
return whisper_init_no_state(&loader);
|
return whisper_init_no_state(&loader);
|
||||||
}
|
}
|
||||||
@ -2909,7 +2913,6 @@ int whisper_lang_id(const char * lang) {
|
|||||||
fprintf(stderr, "%s: unknown language '%s'\n", __func__, lang);
|
fprintf(stderr, "%s: unknown language '%s'\n", __func__, lang);
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
return g_lang.at(lang).first;
|
return g_lang.at(lang).first;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -3303,15 +3306,15 @@ static void whisper_exp_compute_token_level_timestamps(
|
|||||||
|
|
||||||
// trim from start (in place)
|
// trim from start (in place)
|
||||||
static inline void ltrim(std::string &s) {
|
static inline void ltrim(std::string &s) {
|
||||||
s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) {
|
s.erase(s.begin(), std::find_if_not(s.begin(), s.end(), [](unsigned char ch) {
|
||||||
return !std::isspace(ch);
|
return std::isspace(ch);
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
// trim from end (in place)
|
// trim from end (in place)
|
||||||
static inline void rtrim(std::string &s) {
|
static inline void rtrim(std::string &s) {
|
||||||
s.erase(std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) {
|
s.erase(std::find_if_not(s.rbegin(), s.rend(), [](unsigned char ch) {
|
||||||
return !std::isspace(ch);
|
return std::isspace(ch);
|
||||||
}).base(), s.end());
|
}).base(), s.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user