Compare commits

..

44 Commits

Author SHA1 Message Date
3ac0558009 ios : update SPM package 2023-09-15 12:13:33 +03:00
a1664574fe bench : variable n_past 2023-09-14 22:41:41 +03:00
bfcb2a2ab9 metal : remove the "concurrent" flag 2023-09-14 18:04:42 +03:00
0d5e4cdc36 whisper : clean-up ggml_mul_mat_pad 2023-09-14 17:28:13 +03:00
2b4160af29 whisper : add description of ggml_mul_mat_pad 2023-09-14 15:37:10 +03:00
f36554382a whisper : add comment for disabling mul-mat padding 2023-09-14 15:25:19 +03:00
c46167f8c5 bench : fix uninitialized vars 2023-09-14 15:19:27 +03:00
af947cb72e whisper : add ggml_mul_mat_pad 2023-09-14 15:16:22 +03:00
e81c67a125 bench : start benching the decoder 2023-09-14 10:06:14 +03:00
f408c64564 bench : fix timings by running a pre-heat 2023-09-13 23:03:25 +03:00
d863f725a1 coreml : add code to toggle Core ML config (CPU, ANE, GPU) 2023-09-13 22:51:10 +03:00
d37f56e7a9 ios : update submodule 2023-09-13 21:31:29 +03:00
23277d21ce readme : add Metal info 2023-09-13 20:54:03 +03:00
ecb23fb1eb metal : sync latest llama.cpp kernels 2023-09-13 20:44:05 +03:00
8e8daa8451 metal : speed-up KQ multiplication 2023-09-13 19:59:16 +03:00
16db4da3f1 swiftui : fix build 2023-09-13 19:49:11 +03:00
257d7942af ios : add Metal support 2023-09-13 19:45:12 +03:00
181bb8cb28 objc : fix build (no Metal yet) 2023-09-13 18:54:41 +03:00
796f84cd95 whisper : add <functional> header 2023-09-13 13:35:42 +03:00
77f4bf49c8 cmake : update to support Metal build 2023-09-13 13:34:51 +03:00
b6f09669a2 whisper : factor out alloc init in a function 2023-09-13 12:51:52 +03:00
254b687239 whisper : add whisper_allocr to wrap ggml_allocr 2023-09-13 11:58:19 +03:00
b19888cfb4 ggml-alloc : try to make CI happy by reducing vram to 128GB 2023-09-13 11:57:46 +03:00
905c944143 ggml : use simpler ggml_bytes() implementation 2023-09-13 11:39:09 +03:00
3074a7ff14 whisper : offload the Encoder to Metal 2023-09-13 00:09:44 +03:00
ec9a7db74c whisper : remove ggml_repeat in the encoder 2023-09-12 20:34:32 +03:00
cd476375b4 metal : run "cross" step on the GPU 2023-09-12 20:11:13 +03:00
9fdd415367 ggml : fix ggml_nbytes (probably temp solution) 2023-09-12 20:10:53 +03:00
79a88057bd metal : add multi-decoder support 2023-09-12 19:33:29 +03:00
fbc9ddc582 metal : decoder works on GPU! 2023-09-12 19:23:30 +03:00
3b9979a373 ci : try to debug vmem issue 2023-09-12 14:08:48 +03:00
de94c783ee Merge branch 'master' into metal-and-alloc 2023-09-12 14:02:43 +03:00
d3b2dd4955 whisper : initial Metal version 2023-09-11 16:23:31 +03:00
4845b9ed09 whisper.android : try to fix build 2023-09-11 15:19:21 +03:00
2770d46ef5 whisper : refactor ggml-alloc init 2023-09-11 15:04:33 +03:00
4d9acc60c3 ci : see if this is causing the crash 2023-09-11 14:42:25 +03:00
06d1d2836b extra : update sync-ggml.sh script to also sync ggml-alloc 2023-09-10 22:45:38 +03:00
9a78b72246 ios : update submodule 2023-09-10 22:36:50 +03:00
794e8fe0ea build : fix ggml-alloc 2023-09-10 22:19:39 +03:00
fa672b46e6 whisper : CoreML support ggml-alloc 2023-09-10 21:57:04 +03:00
af6f67b251 whisper : ggml-alloc is now supported 2023-09-10 20:09:17 +03:00
bed5ad69dd whisper : allocate encoder and decoder using ggml-alloc 2023-09-10 19:50:34 +03:00
949ab6328d whisper : factor out graph builds 2023-09-10 19:23:06 +03:00
fbc3f8033e metal : init 2023-09-10 18:38:34 +03:00
35 changed files with 2222 additions and 6675 deletions

View File

@ -1,28 +0,0 @@
ARG UBUNTU_VERSION=22.04
# This needs to generally match the container host's environment.
ARG CUDA_VERSION=11.7.1
# Target the CUDA build image
ARG BASE_CUDA_DEV_CONTAINER=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION}
FROM ${BASE_CUDA_DEV_CONTAINER} as build
# Unless otherwise specified, we make a fat build.
ARG CUDA_DOCKER_ARCH=all
RUN apt-get update && \
apt-get install -y build-essential git cmake
WORKDIR /app
COPY . .
# Set nvcc architecture
ENV CUDA_DOCKER_ARCH=${CUDA_DOCKER_ARCH}
# Enable cuBLAS
ENV WHISPER_CUBLAS=1
RUN make
ENTRYPOINT ["/app/main"]

2
.gitignore vendored
View File

@ -46,5 +46,3 @@ models/*.mlpackage
bindings/java/.gradle/
bindings/java/.idea/
.idea/
benchmark_results.csv

View File

@ -117,7 +117,7 @@ if (APPLE)
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ${ACCELERATE_FRAMEWORK})
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_USE_ACCELERATE)
else()
message(FATAL_ERROR "Accelerate framework not found")
message(WARNING "Accelerate framework not found")
endif()
endif()
@ -140,7 +140,7 @@ if (APPLE)
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_METAL_NDEBUG)
endif()
else()
message(FATAL_ERROR "Metal framework not found")
message(WARNING "Metal framework not found")
endif()
set(GGML_SOURCES_METAL ggml-metal.m ggml-metal.h)
@ -158,7 +158,7 @@ if (APPLE)
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DWHISPER_USE_COREML)
else()
message(FATAL_ERROR "CoreML framework not found")
message(WARNING "CoreML framework not found")
endif()
if (WHISPER_COREML_ALLOW_FALLBACK)
@ -181,13 +181,13 @@ if (WHISPER_BLAS)
include_directories($ENV{OPENBLAS_PATH}/include)
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ${BLAS_LIBRARIES})
else ()
message(FATAL_ERROR "BLAS library was not found. Environment variable OPENBLAS_PATH not defined.")
message(WARNING "BLAS library was not found. Environment variable OPENBLAS_PATH not defined.")
endif ()
else ()
set(BLA_STATIC 1)
set(BLA_VENDOR ${WHISPER_BLAS_VENDOR})
# set(BLA_PREFER_PKGCONFIG 1)
set(BLA_SIZEOF_INTEGER 8)
set(BLA_PREFER_PKGCONFIG 1)
find_package(BLAS)
if(BLAS_FOUND)
@ -198,7 +198,7 @@ if (WHISPER_BLAS)
include_directories(${BLAS_INCLUDE_DIRS})
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ${BLAS_LIBRARIES})
else()
message(FATAL_ERROR "BLAS library was not found")
message(WARNING "BLAS library was not found")
endif()
endif ()
endif ()
@ -224,7 +224,7 @@ if (WHISPER_CUBLAS)
endif()
else()
message(FATAL_ERROR "cuBLAS not found")
message(WARNING "cuBLAS not found")
endif()
endif()
@ -255,7 +255,7 @@ if (WHISPER_HIPBLAS)
endif()
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ggml-rocm)
else()
message(FATAL_ERROR "hipBLAS or HIP not found. Try setting CMAKE_PREFIX_PATH=/opt/rocm")
message(WARNING "hipBLAS or HIP not found. Try setting CMAKE_PREFIX_PATH=/opt/rocm")
endif()
endif()
@ -270,7 +270,7 @@ if (WHISPER_CLBLAST)
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} clblast)
else()
message(FATAL_ERROR "CLBlast not found")
message(WARNING "CLBlast not found")
endif()
endif()

View File

@ -186,7 +186,6 @@ ifndef WHISPER_NO_METAL
ifeq ($(UNAME_S),Darwin)
WHISPER_METAL := 1
CFLAGS += -DGGML_USE_METAL
CXXFLAGS += -DGGML_USE_METAL
LDFLAGS += -framework Foundation -framework Metal -framework MetalKit
endif

View File

@ -50,7 +50,7 @@ You can also easily make your own offline voice assistant application: [command]
https://user-images.githubusercontent.com/1991296/204038393-2f846eae-c255-4099-a76d-5735c25c49da.mp4
On Apple Silicon, the inference runs fully on the GPU via Metal:
On Apply Silicon, the inference runs fully on the GPU via Metal:
https://github.com/ggerganov/whisper.cpp/assets/1991296/c82e8f86-60dc-49f2-b048-d2fdbd6b5225
@ -113,37 +113,30 @@ options:
-d N, --duration N [0 ] duration of audio to process in milliseconds
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
-ml N, --max-len N [0 ] maximum segment length in characters
-sow, --split-on-word [false ] split on word rather than on token
-bo N, --best-of N [2 ] number of best candidates to keep
-bo N, --best-of N [5 ] number of best candidates to keep
-bs N, --beam-size N [-1 ] beam size for beam search
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
-debug, --debug-mode [false ] enable debug mode (eg. dump log_mel)
-su, --speed-up [false ] speed up audio by x2 (reduced accuracy)
-tr, --translate [false ] translate from source language to english
-di, --diarize [false ] stereo audio diarization
-tdrz, --tinydiarize [false ] enable tinydiarize (requires a tdrz model)
-di, --diarize [false ] stereo audio diarization
-nf, --no-fallback [false ] do not use temperature fallback while decoding
-otxt, --output-txt [false ] output result in a text file
-ovtt, --output-vtt [false ] output result in a vtt file
-osrt, --output-srt [false ] output result in a srt file
-olrc, --output-lrc [false ] output result in a lrc file
-owts, --output-words [false ] output script for generating karaoke video
-fp, --font-path [/System/Library/Fonts/Supplemental/Courier New Bold.ttf] path to a monospace font for karaoke video
-ocsv, --output-csv [false ] output result in a CSV file
-oj, --output-json [false ] output result in a JSON file
-of FNAME, --output-file FNAME [ ] output file path (without file extension)
-ps, --print-special [false ] print special tokens
-pc, --print-colors [false ] print colors
-pp, --print-progress [false ] print progress
-nt, --no-timestamps [false ] do not print timestamps
-nt, --no-timestamps [true ] do not print timestamps
-l LANG, --language LANG [en ] spoken language ('auto' for auto-detect)
-dl, --detect-language [false ] exit after automatically detecting language
--prompt PROMPT [ ] initial prompt
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
-f FNAME, --file FNAME [ ] input WAV file path
-oved D, --ov-e-device DNAME [CPU ] the OpenVINO device used for encode inference
-ls, --log-score [false ] log best decoder scores of token
bash ./models/download-ggml-model.sh base.en
@ -709,19 +702,6 @@ took to execute it. The results are summarized in the following Github issue:
[Benchmark results](https://github.com/ggerganov/whisper.cpp/issues/89)
Additionally a script to run whisper.cpp with different models and audio files is provided [bench.py](bench.py).
You can run it with the following command, by default it will run against any standard model in the models folder.
```bash
python3 extra/bench.py -f samples/jfk.wav -t 2,4,8 -p 1,2
```
It is written in python with the intention of being easy to modify and extend for your benchmarking use case.
It outputs a csv file with the results of the benchmarking.
## ggml format
The original models are converted to a custom binary format. This allows to pack everything needed into a single file:
@ -743,7 +723,7 @@ in [models](models).
## [Bindings](https://github.com/ggerganov/whisper.cpp/discussions/categories/bindings)
- [X] Rust: [tazz4843/whisper-rs](https://github.com/tazz4843/whisper-rs) | [#310](https://github.com/ggerganov/whisper.cpp/discussions/310)
- [X] JavaScript: [bindings/javascript](bindings/javascript) | [#309](https://github.com/ggerganov/whisper.cpp/discussions/309)
- [X] Javascript: [bindings/javascript](bindings/javascript) | [#309](https://github.com/ggerganov/whisper.cpp/discussions/309)
- React Native (iOS / Android): [whisper.rn](https://github.com/mybigday/whisper.rn)
- [X] Go: [bindings/go](bindings/go) | [#312](https://github.com/ggerganov/whisper.cpp/discussions/312)
- [X] Java:

View File

@ -118,11 +118,6 @@ func (p *Params) SetMaxTokensPerSegment(n int) {
p.max_tokens = C.int(n)
}
// Set audio encoder context
func (p *Params) SetAudioCtx(n int) {
p.audio_ctx = C.int(n)
}
///////////////////////////////////////////////////////////////////////////////
// PRIVATE METHODS
@ -146,7 +141,6 @@ func (p *Params) String() string {
str += fmt.Sprintf(" n_max_text_ctx=%d", p.n_max_text_ctx)
str += fmt.Sprintf(" offset_ms=%d", p.offset_ms)
str += fmt.Sprintf(" duration_ms=%d", p.duration_ms)
str += fmt.Sprintf(" audio_ctx=%d", p.audio_ctx)
if p.translate {
str += " translate"
}

View File

@ -82,7 +82,7 @@ func (context *context) SetSpeedup(v bool) {
}
func (context *context) SetSplitOnWord(v bool) {
context.params.SetSplitOnWord(v)
context.params.SetSplitOnWord(v)
}
// Set number of threads to use
@ -125,11 +125,6 @@ func (context *context) SetMaxTokensPerSegment(n uint) {
context.params.SetMaxTokensPerSegment(int(n))
}
// Set audio encoder context
func (context *context) SetAudioCtx(n uint) {
context.params.SetAudioCtx(int(n))
}
// ResetTimings resets the mode timings. Should be called before processing
func (context *context) ResetTimings() {
context.model.ctx.Whisper_reset_timings()

View File

@ -48,7 +48,6 @@ type Context interface {
SetMaxSegmentLength(uint) // Set max segment length in characters
SetTokenTimestamps(bool) // Set token timestamps flag
SetMaxTokensPerSegment(uint) // Set max tokens per segment (0 = no limit)
SetAudioCtx(uint) // Set audio encoder context
// Process mono audio data and return any errors.
// If defined, newly generated segments are passed to the

View File

@ -1,8 +1,6 @@
Makefile
ggml.c
ggml.h
ggml-alloc.c
ggml-alloc.h
whisper.bundle
whisper.cpp
whisper.h

View File

@ -3,8 +3,6 @@ system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.cpp')} .")
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.h')} .")
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.h')} .")
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.c')} .")
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml-alloc.h')} .")
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml-alloc.c')} .")
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','examples','dr_wav.h')} .")

View File

@ -1,7 +1,6 @@
#include "whisper.h"
#include <cstdio>
#include <cstring>
#include <string>
#include <thread>

View File

@ -7,8 +7,6 @@
#include <vector>
#include <random>
#include <thread>
#include <ctime>
#include <fstream>
#define COMMON_SAMPLE_RATE 16000
@ -141,104 +139,6 @@ bool read_wav(
std::vector<std::vector<float>> & pcmf32s,
bool stereo);
// Write PCM data into WAV audio file
class wav_writer {
private:
std::ofstream file;
uint32_t dataSize = 0;
std::string wav_filename;
bool write_header(const uint32_t sample_rate,
const uint16_t bits_per_sample,
const uint16_t channels) {
file.write("RIFF", 4);
file.write("\0\0\0\0", 4); // Placeholder for file size
file.write("WAVE", 4);
file.write("fmt ", 4);
const uint32_t sub_chunk_size = 16;
const uint16_t audio_format = 1; // PCM format
const uint32_t byte_rate = sample_rate * channels * bits_per_sample / 8;
const uint16_t block_align = channels * bits_per_sample / 8;
file.write(reinterpret_cast<const char *>(&sub_chunk_size), 4);
file.write(reinterpret_cast<const char *>(&audio_format), 2);
file.write(reinterpret_cast<const char *>(&channels), 2);
file.write(reinterpret_cast<const char *>(&sample_rate), 4);
file.write(reinterpret_cast<const char *>(&byte_rate), 4);
file.write(reinterpret_cast<const char *>(&block_align), 2);
file.write(reinterpret_cast<const char *>(&bits_per_sample), 2);
file.write("data", 4);
file.write("\0\0\0\0", 4); // Placeholder for data size
return true;
}
// It is assumed that PCM data is normalized to a range from -1 to 1
bool write_audio(const float * data, size_t length) {
for (size_t i = 0; i < length; ++i) {
const auto intSample = static_cast<const int16_t>(data[i] * 32767);
file.write(reinterpret_cast<const char *>(&intSample), sizeof(int16_t));
dataSize += sizeof(int16_t);
}
if (file.is_open()) {
file.seekp(4, std::ios::beg);
uint32_t fileSize = 36 + dataSize;
file.write(reinterpret_cast<char *>(&fileSize), 4);
file.seekp(40, std::ios::beg);
file.write(reinterpret_cast<char *>(&dataSize), 4);
file.seekp(0, std::ios::end);
}
return true;
}
bool open_wav(const std::string & filename) {
if (filename != wav_filename) {
if (file.is_open()) {
file.close();
}
}
if (!file.is_open()) {
file.open(filename, std::ios::binary);
wav_filename = filename;
dataSize = 0;
}
return file.is_open();
}
public:
bool open(const std::string & filename,
const uint32_t sample_rate,
const uint16_t bits_per_sample,
const uint16_t channels) {
if (open_wav(filename)) {
write_header(sample_rate, bits_per_sample, channels);
} else {
return false;
}
return true;
}
bool close() {
file.close();
return true;
}
bool write(const float * data, size_t length) {
return write_audio(data, length);
}
~wav_writer() {
if (file.is_open()) {
file.close();
}
}
};
// Apply a high-pass frequency filter to PCM audio
// Suppresses frequencies below cutoff Hz
void high_pass_filter(

View File

@ -83,7 +83,6 @@ struct whisper_params {
bool output_wts = false;
bool output_csv = false;
bool output_jsn = false;
bool output_jsn_full = false;
bool output_lrc = false;
bool print_special = false;
bool print_colors = false;
@ -152,7 +151,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; }
else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; }
else if (arg == "-ojf" || arg == "--output-json-full"){ params.output_jsn_full = params.output_jsn = true; }
else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(argv[++i]); }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
@ -208,7 +206,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -fp, --font-path [%-7s] path to a monospace font for karaoke video\n", params.font_path.c_str());
fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false");
fprintf(stderr, " -oj, --output-json [%-7s] output result in a JSON file\n", params.output_jsn ? "true" : "false");
fprintf(stderr, " -ojf, --output-json-full [%-7s] include more information in the JSON file\n", params.output_jsn_full ? "true" : "false");
fprintf(stderr, " -of FNAME, --output-file FNAME [%-7s] output file path (without file extension)\n", "");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
@ -514,12 +511,7 @@ bool output_score(struct whisper_context * ctx, const char * fname, const whispe
return true;
}
bool output_json(
struct whisper_context * ctx,
const char * fname,
const whisper_params & params,
std::vector<std::vector<float>> pcmf32s,
bool full) {
bool output_json(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
std::ofstream fout(fname);
int indent = 0;
@ -536,7 +528,7 @@ bool output_json(
auto end_arr = [&](bool end) {
indent--;
doindent();
fout << (end ? "]\n" : "],\n");
fout << (end ? "]\n" : "},\n");
};
auto start_obj = [&](const char *name) {
@ -577,29 +569,12 @@ bool output_json(
end_value(end);
};
auto value_f = [&](const char *name, const float val, bool end) {
start_value(name);
fout << val;
end_value(end);
};
auto value_b = [&](const char *name, const bool val, bool end) {
start_value(name);
fout << (val ? "true" : "false");
end_value(end);
};
auto times_o = [&](int64_t t0, int64_t t1, bool end) {
start_obj("timestamps");
value_s("from", to_timestamp(t0, true).c_str(), false);
value_s("to", to_timestamp(t1, true).c_str(), true);
end_obj(false);
start_obj("offsets");
value_i("from", t0 * 10, false);
value_i("to", t1 * 10, true);
end_obj(end);
};
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
return false;
@ -645,26 +620,15 @@ bool output_json(
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
start_obj(nullptr);
times_o(t0, t1, false);
value_s("text", text, !params.diarize && !params.tinydiarize && !full);
if (full) {
start_arr("tokens");
const int n = whisper_full_n_tokens(ctx, i);
for (int j = 0; j < n; ++j) {
auto token = whisper_full_get_token_data(ctx, i, j);
start_obj(nullptr);
value_s("text", whisper_token_to_str(ctx, token.id), false);
if(token.t0 > -1 && token.t1 > -1) {
// If we have per-token timestamps, write them out
times_o(token.t0, token.t1, false);
}
value_i("id", token.id, false);
value_f("p", token.p, true);
end_obj(j == (n - 1));
}
end_arr(!params.diarize && !params.tinydiarize);
}
start_obj("timestamps");
value_s("from", to_timestamp(t0, true).c_str(), false);
value_s("to", to_timestamp(t1, true).c_str(), true);
end_obj(false);
start_obj("offsets");
value_i("from", t0 * 10, false);
value_i("to", t1 * 10, true);
end_obj(false);
value_s("text", text, !params.diarize && !params.tinydiarize);
if (params.diarize && pcmf32s.size() == 2) {
value_s("speaker", estimate_diarization_speaker(pcmf32s, t0, t1, true).c_str(), true);
@ -948,7 +912,7 @@ int main(int argc, char ** argv) {
wparams.offset_ms = params.offset_t_ms;
wparams.duration_ms = params.duration_ms;
wparams.token_timestamps = params.output_wts || params.output_jsn_full || params.max_len > 0;
wparams.token_timestamps = params.output_wts || params.max_len > 0;
wparams.thold_pt = params.word_thold;
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
wparams.split_on_word = params.split_on_word;
@ -980,9 +944,8 @@ int main(int argc, char ** argv) {
wparams.progress_callback_user_data = &user_data;
}
// examples for abort mechanism
// in examples below, we do not abort the processing, but we could if the flag is set to true
// example for abort mechanism
// in this example, we do not abort the processing, but we could if the flag is set to true
// the callback is called before every encoder run - if it returns false, the processing is aborted
{
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
@ -994,17 +957,6 @@ int main(int argc, char ** argv) {
wparams.encoder_begin_callback_user_data = &is_aborted;
}
// the callback is called before every computation - if it returns true, the computation is aborted
{
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
wparams.abort_callback = [](void * user_data) {
bool is_aborted = *(bool*)user_data;
return is_aborted;
};
wparams.abort_callback_user_data = &is_aborted;
}
if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
return 10;
@ -1048,7 +1000,7 @@ int main(int argc, char ** argv) {
// output to JSON file
if (params.output_jsn) {
const auto fname_jsn = fname_out + ".json";
output_json(ctx, fname_jsn.c_str(), params, pcmf32s, params.output_jsn_full);
output_json(ctx, fname_jsn.c_str(), params, pcmf32s);
}
// output to LRC file

View File

@ -39,20 +39,6 @@ brew install sdl2
make stream
```
Ensure you are at the root of the repo when running `make stream`. Not within the `examples/stream` dir
as the libraries needed like `common-sdl.h` are located within `examples`. Attempting to compile within
`examples/steam` means your compiler cannot find them and it gives an error it cannot find the file.
```bash
whisper.cpp/examples/stream$ make stream
g++ stream.cpp -o stream
stream.cpp:6:10: fatal error: common/sdl.h: No such file or directory
6 | #include "common/sdl.h"
| ^~~~~~~~~~~~~~
compilation terminated.
make: *** [<builtin>: stream] Error 1
```
## Web version
This tool can also run in the browser: [examples/stream.wasm](/examples/stream.wasm)

View File

@ -2,6 +2,7 @@
//
// A very quick-n-dirty implementation serving mainly as a proof of concept.
//
#include "common-sdl.h"
#include "common.h"
#include "whisper.h"
@ -13,7 +14,6 @@
#include <vector>
#include <fstream>
// 500 -> 00:05.000
// 6000 -> 01:00.000
std::string to_timestamp(int64_t t) {
@ -52,7 +52,6 @@ struct whisper_params {
std::string language = "en";
std::string model = "models/ggml-base.en.bin";
std::string fname_out;
bool save_audio = false; // save audio to wav file
};
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@ -83,7 +82,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; }
else if (arg == "-sa" || arg == "--save-audio") { params.save_audio = true; }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
@ -119,7 +117,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false");
fprintf(stderr, " -sa, --save-audio [%-7s] save the recorded audio to a file\n", params.save_audio ? "true" : "false");
fprintf(stderr, "\n");
}
@ -157,6 +154,7 @@ int main(int argc, char ** argv) {
audio.resume();
// whisper init
if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1){
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
whisper_print_usage(argc, argv, params);
@ -214,28 +212,14 @@ int main(int argc, char ** argv) {
}
}
wav_writer wavWriter;
// save wav file
if (params.save_audio) {
// Get current date/time for filename
time_t now = time(0);
char buffer[80];
strftime(buffer, sizeof(buffer), "%Y%m%d%H%M%S", localtime(&now));
std::string filename = std::string(buffer) + ".wav";
wavWriter.open(filename, WHISPER_SAMPLE_RATE, 16, 1);
}
printf("[Start speaking]\n");
printf("[Start speaking]");
fflush(stdout);
auto t_last = std::chrono::high_resolution_clock::now();
auto t_last = std::chrono::high_resolution_clock::now();
const auto t_start = t_last;
// main audio loop
while (is_running) {
if (params.save_audio) {
wavWriter.write(pcmf32_new.data(), pcmf32_new.size());
}
// handle Ctrl + C
is_running = sdl_poll_events();
@ -387,7 +371,7 @@ int main(int argc, char ** argv) {
fout << std::endl;
}
if (use_vad) {
if (use_vad){
printf("\n");
printf("### Transcription %d END\n", n_iter);
}
@ -424,4 +408,4 @@ int main(int argc, char ** argv) {
whisper_free(ctx);
return 0;
}
}

View File

@ -2,12 +2,6 @@
Talk with an LLaMA AI in your terminal
*Latest perf as of 2 Nov 2023 using Whisper Medium + LLaMA v2 13B Q8_0 on M2 Ultra:*
https://github.com/ggerganov/whisper.cpp/assets/1991296/d97a3788-bf2a-4756-9a43-60c6b391649e
*Previous demo running on CPUs*
[Demo Talk](https://user-images.githubusercontent.com/1991296/228024237-848f998c-c334-46a6-bef8-3271590da83b.mp4)
## Building
@ -25,7 +19,7 @@ brew install sdl2
make talk-llama
# Run it
./talk-llama -mw ./models/ggml-small.en.bin -ml ../llama.cpp/models/llama-13b/ggml-model-q4_0.gguf -p "Georgi" -t 8
./talk-llama -mw ./models/ggml-small.en.bin -ml ../llama.cpp/models/13B/ggml-model-q4_0.bin -p "Georgi" -t 8
```
- The `-mw` argument specifies the Whisper model that you would like to use. Recommended `base` or `small` for real-time experience
@ -42,7 +36,7 @@ This feature is especially helpful for maintaining context in long conversations
Example usage:
```bash
./talk-llama --session ./my-session-file -mw ./models/ggml-small.en.bin -ml ../llama.cpp/models/llama-13b/ggml-model-q4_0.gguf -p "Georgi" -t 8
./talk-llama --session ./my-session-file -mw ./models/ggml-small.en.bin -ml ../llama.cpp/models/13B/ggml-model-q4_0.bin -p "Georgi" -t 8
```
## TTS

View File

@ -0,0 +1,474 @@
// Internal header to be included only by llama.cpp.
// Contains wrappers around OS interfaces.
#ifndef LLAMA_UTIL_H
#define LLAMA_UTIL_H
#include <cstdio>
#include <cstdint>
#include <cerrno>
#include <cstring>
#include <cstdarg>
#include <cstdlib>
#include <climits>
#include <string>
#include <vector>
#include <stdexcept>
#ifdef __has_include
#if __has_include(<unistd.h>)
#include <unistd.h>
#if defined(_POSIX_MAPPED_FILES)
#include <sys/mman.h>
#endif
#if defined(_POSIX_MEMLOCK_RANGE)
#include <sys/resource.h>
#endif
#endif
#endif
#if defined(_WIN32)
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
#define NOMINMAX
#endif
#include <windows.h>
#include <io.h>
#include <stdio.h> // for _fseeki64
#endif
#define LLAMA_ASSERT(x) \
do { \
if (!(x)) { \
fprintf(stderr, "LLAMA_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
abort(); \
} \
} while (0)
#ifdef __GNUC__
#ifdef __MINGW32__
__attribute__((format(gnu_printf, 1, 2)))
#else
__attribute__((format(printf, 1, 2)))
#endif
#endif
static std::string format(const char * fmt, ...) {
va_list ap, ap2;
va_start(ap, fmt);
va_copy(ap2, ap);
int size = vsnprintf(NULL, 0, fmt, ap);
LLAMA_ASSERT(size >= 0 && size < INT_MAX);
std::vector<char> buf(size + 1);
int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
LLAMA_ASSERT(size2 == size);
va_end(ap2);
va_end(ap);
return std::string(buf.data(), size);
}
struct llama_file {
// use FILE * so we don't have to re-open the file to mmap
FILE * fp;
size_t size;
llama_file(const char * fname, const char * mode) {
fp = std::fopen(fname, mode);
if (fp == NULL) {
throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno)));
}
seek(0, SEEK_END);
size = tell();
seek(0, SEEK_SET);
}
size_t tell() const {
#ifdef _WIN32
__int64 ret = _ftelli64(fp);
#else
long ret = std::ftell(fp);
#endif
LLAMA_ASSERT(ret != -1); // this really shouldn't fail
return (size_t) ret;
}
void seek(size_t offset, int whence) {
#ifdef _WIN32
int ret = _fseeki64(fp, (__int64) offset, whence);
#else
int ret = std::fseek(fp, (long) offset, whence);
#endif
LLAMA_ASSERT(ret == 0); // same
}
void read_raw(void * ptr, size_t len) const {
if (len == 0) {
return;
}
errno = 0;
std::size_t ret = std::fread(ptr, len, 1, fp);
if (ferror(fp)) {
throw std::runtime_error(format("read error: %s", strerror(errno)));
}
if (ret != 1) {
throw std::runtime_error(std::string("unexpectedly reached end of file"));
}
}
std::uint32_t read_u32() {
std::uint32_t ret;
read_raw(&ret, sizeof(ret));
return ret;
}
std::string read_string(std::uint32_t len) {
std::vector<char> chars(len);
read_raw(chars.data(), len);
return std::string(chars.data(), len);
}
void write_raw(const void * ptr, size_t len) const {
if (len == 0) {
return;
}
errno = 0;
size_t ret = std::fwrite(ptr, len, 1, fp);
if (ret != 1) {
throw std::runtime_error(format("write error: %s", strerror(errno)));
}
}
void write_u32(std::uint32_t val) {
write_raw(&val, sizeof(val));
}
~llama_file() {
if (fp) {
std::fclose(fp);
}
}
};
#if defined(_WIN32)
static std::string llama_format_win_err(DWORD err) {
LPSTR buf;
size_t size = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
NULL, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&buf, 0, NULL);
if (!size) {
return "FormatMessageA failed";
}
std::string ret(buf, size);
LocalFree(buf);
return ret;
}
#endif
struct llama_mmap {
void * addr;
size_t size;
llama_mmap(const llama_mmap &) = delete;
#ifdef _POSIX_MAPPED_FILES
static constexpr bool SUPPORTED = true;
llama_mmap(struct llama_file * file, size_t prefetch = (size_t) -1 /* -1 = max value */) {
size = file->size;
int fd = fileno(file->fp);
int flags = MAP_SHARED;
#ifdef __linux__
flags |= MAP_POPULATE;
#endif
addr = mmap(NULL, file->size, PROT_READ, flags, fd, 0);
if (addr == MAP_FAILED) {
throw std::runtime_error(format("mmap failed: %s", strerror(errno)));
}
if (prefetch > 0) {
// Advise the kernel to preload the mapped memory
if (posix_madvise(addr, std::min(file->size, prefetch), POSIX_MADV_WILLNEED)) {
fprintf(stderr, "warning: posix_madvise(.., POSIX_MADV_WILLNEED) failed: %s\n",
strerror(errno));
}
}
}
~llama_mmap() {
munmap(addr, size);
}
#elif defined(_WIN32)
static constexpr bool SUPPORTED = true;
llama_mmap(struct llama_file * file, bool prefetch = true) {
size = file->size;
HANDLE hFile = (HANDLE) _get_osfhandle(_fileno(file->fp));
HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL);
DWORD error = GetLastError();
if (hMapping == NULL) {
throw std::runtime_error(format("CreateFileMappingA failed: %s", llama_format_win_err(error).c_str()));
}
addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0);
error = GetLastError();
CloseHandle(hMapping);
if (addr == NULL) {
throw std::runtime_error(format("MapViewOfFile failed: %s", llama_format_win_err(error).c_str()));
}
#if _WIN32_WINNT >= _WIN32_WINNT_WIN8
if (prefetch) {
// Advise the kernel to preload the mapped memory
WIN32_MEMORY_RANGE_ENTRY range;
range.VirtualAddress = addr;
range.NumberOfBytes = (SIZE_T)size;
if (!PrefetchVirtualMemory(GetCurrentProcess(), 1, &range, 0)) {
fprintf(stderr, "warning: PrefetchVirtualMemory failed: %s\n",
llama_format_win_err(GetLastError()).c_str());
}
}
#else
#pragma message("warning: You are building for pre-Windows 8; prefetch not supported")
#endif // _WIN32_WINNT >= _WIN32_WINNT_WIN8
}
~llama_mmap() {
if (!UnmapViewOfFile(addr)) {
fprintf(stderr, "warning: UnmapViewOfFile failed: %s\n",
llama_format_win_err(GetLastError()).c_str());
}
}
#else
static constexpr bool SUPPORTED = false;
llama_mmap(struct llama_file *, bool prefetch = true) {
(void)prefetch;
throw std::runtime_error(std::string("mmap not supported"));
}
#endif
};
// Represents some region of memory being locked using mlock or VirtualLock;
// will automatically unlock on destruction.
struct llama_mlock {
void * addr = NULL;
size_t size = 0;
bool failed_already = false;
llama_mlock() {}
llama_mlock(const llama_mlock &) = delete;
~llama_mlock() {
if (size) {
raw_unlock(addr, size);
}
}
void init(void * ptr) {
LLAMA_ASSERT(addr == NULL && size == 0);
addr = ptr;
}
void grow_to(size_t target_size) {
LLAMA_ASSERT(addr);
if (failed_already) {
return;
}
size_t granularity = lock_granularity();
target_size = (target_size + granularity - 1) & ~(granularity - 1);
if (target_size > size) {
if (raw_lock((uint8_t *) addr + size, target_size - size)) {
size = target_size;
} else {
failed_already = true;
}
}
}
#ifdef _POSIX_MEMLOCK_RANGE
static constexpr bool SUPPORTED = true;
size_t lock_granularity() {
return (size_t) sysconf(_SC_PAGESIZE);
}
#ifdef __APPLE__
#define MLOCK_SUGGESTION \
"Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or " \
"decreasing 'vm.global_no_user_wire_amount'. Also try increasing RLIMIT_MLOCK (ulimit -l).\n"
#else
#define MLOCK_SUGGESTION \
"Try increasing RLIMIT_MLOCK ('ulimit -l' as root).\n"
#endif
bool raw_lock(const void * addr, size_t size) {
if (!mlock(addr, size)) {
return true;
} else {
char* errmsg = std::strerror(errno);
bool suggest = (errno == ENOMEM);
// Check if the resource limit is fine after all
struct rlimit lock_limit;
if (suggest && getrlimit(RLIMIT_MEMLOCK, &lock_limit))
suggest = false;
if (suggest && (lock_limit.rlim_max > lock_limit.rlim_cur + size))
suggest = false;
fprintf(stderr, "warning: failed to mlock %zu-byte buffer (after previously locking %zu bytes): %s\n%s",
size, this->size, errmsg, suggest ? MLOCK_SUGGESTION : "");
return false;
}
}
#undef MLOCK_SUGGESTION
void raw_unlock(void * addr, size_t size) {
if (munlock(addr, size)) {
fprintf(stderr, "warning: failed to munlock buffer: %s\n", std::strerror(errno));
}
}
#elif defined(_WIN32)
static constexpr bool SUPPORTED = true;
size_t lock_granularity() {
SYSTEM_INFO si;
GetSystemInfo(&si);
return (size_t) si.dwPageSize;
}
bool raw_lock(void * ptr, size_t len) {
for (int tries = 1; ; tries++) {
if (VirtualLock(ptr, len)) {
return true;
}
if (tries == 2) {
fprintf(stderr, "warning: failed to VirtualLock %zu-byte buffer (after previously locking %zu bytes): %s\n",
len, size, llama_format_win_err(GetLastError()).c_str());
return false;
}
// It failed but this was only the first try; increase the working
// set size and try again.
SIZE_T min_ws_size, max_ws_size;
if (!GetProcessWorkingSetSize(GetCurrentProcess(), &min_ws_size, &max_ws_size)) {
fprintf(stderr, "warning: GetProcessWorkingSetSize failed: %s\n",
llama_format_win_err(GetLastError()).c_str());
return false;
}
// Per MSDN: "The maximum number of pages that a process can lock
// is equal to the number of pages in its minimum working set minus
// a small overhead."
// Hopefully a megabyte is enough overhead:
size_t increment = len + 1048576;
// The minimum must be <= the maximum, so we need to increase both:
min_ws_size += increment;
max_ws_size += increment;
if (!SetProcessWorkingSetSize(GetCurrentProcess(), min_ws_size, max_ws_size)) {
fprintf(stderr, "warning: SetProcessWorkingSetSize failed: %s\n",
llama_format_win_err(GetLastError()).c_str());
return false;
}
}
}
void raw_unlock(void * ptr, size_t len) {
if (!VirtualUnlock(ptr, len)) {
fprintf(stderr, "warning: failed to VirtualUnlock buffer: %s\n",
llama_format_win_err(GetLastError()).c_str());
}
}
#else
static constexpr bool SUPPORTED = false;
size_t lock_granularity() {
return (size_t) 65536;
}
bool raw_lock(const void * addr, size_t len) {
fprintf(stderr, "warning: mlock not supported on this system\n");
return false;
}
void raw_unlock(const void * addr, size_t len) {}
#endif
};
// Replacement for std::vector<uint8_t> that doesn't require zero-initialization.
struct llama_buffer {
uint8_t * addr = NULL;
size_t size = 0;
llama_buffer() = default;
void resize(size_t len) {
delete[] addr;
addr = new uint8_t[len];
size = len;
}
~llama_buffer() {
delete[] addr;
}
// disable copy and move
llama_buffer(const llama_buffer&) = delete;
llama_buffer(llama_buffer&&) = delete;
llama_buffer& operator=(const llama_buffer&) = delete;
llama_buffer& operator=(llama_buffer&&) = delete;
};
#ifdef GGML_USE_CUBLAS
#include "ggml-cuda.h"
struct llama_ctx_buffer {
uint8_t * addr = NULL;
bool is_cuda;
size_t size = 0;
llama_ctx_buffer() = default;
void resize(size_t size) {
free();
addr = (uint8_t *) ggml_cuda_host_malloc(size);
if (addr) {
is_cuda = true;
}
else {
// fall back to pageable memory
addr = new uint8_t[size];
is_cuda = false;
}
this->size = size;
}
void free() {
if (addr) {
if (is_cuda) {
ggml_cuda_host_free(addr);
}
else {
delete[] addr;
}
}
addr = NULL;
}
~llama_ctx_buffer() {
free();
}
// disable copy and move
llama_ctx_buffer(const llama_ctx_buffer&) = delete;
llama_ctx_buffer(llama_ctx_buffer&&) = delete;
llama_ctx_buffer& operator=(const llama_ctx_buffer&) = delete;
llama_ctx_buffer& operator=(llama_ctx_buffer&&) = delete;
};
#else
typedef llama_buffer llama_ctx_buffer;
#endif
#endif

File diff suppressed because it is too large Load Diff

View File

@ -1,16 +1,8 @@
#ifndef LLAMA_H
#define LLAMA_H
#include "ggml.h"
#ifdef GGML_USE_CUBLAS
#include "ggml-cuda.h"
#define LLAMA_MAX_DEVICES GGML_CUDA_MAX_DEVICES
#else
#define LLAMA_MAX_DEVICES 1
#endif // GGML_USE_CUBLAS
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdbool.h>
#ifdef LLAMA_SHARED
@ -27,25 +19,17 @@
# define LLAMA_API
#endif
#ifdef __GNUC__
# define DEPRECATED(func, hint) func __attribute__((deprecated(hint)))
#elif defined(_MSC_VER)
# define DEPRECATED(func, hint) __declspec(deprecated(hint)) func
#else
# define DEPRECATED(func, hint) func
#endif
#define LLAMA_FILE_MAGIC_GGJT 0x67676a74u // 'ggjt'
#define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
#define LLAMA_FILE_MAGIC_GGMF 0x67676d66u // 'ggmf'
#define LLAMA_FILE_MAGIC_GGML 0x67676d6cu // 'ggml'
#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
#define LLAMA_DEFAULT_SEED 0xFFFFFFFF
#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
#define LLAMA_SESSION_VERSION 1
#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL)
// Defined when llama.cpp is compiled with support for offloading model layers to GPU.
#define LLAMA_SUPPORTS_GPU_OFFLOAD
#endif
#define LLAMA_FILE_VERSION 3
#define LLAMA_FILE_MAGIC LLAMA_FILE_MAGIC_GGJT
#define LLAMA_FILE_MAGIC_UNVERSIONED LLAMA_FILE_MAGIC_GGML
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
#define LLAMA_SESSION_VERSION 1
#ifdef __cplusplus
extern "C" {
@ -57,57 +41,10 @@ extern "C" {
// TODO: show sample usage
//
struct llama_model;
struct llama_context;
typedef int llama_token;
enum llama_log_level {
LLAMA_LOG_LEVEL_ERROR = 2,
LLAMA_LOG_LEVEL_WARN = 3,
LLAMA_LOG_LEVEL_INFO = 4
};
enum llama_vocab_type {
LLAMA_VOCAB_TYPE_SPM = 0, // SentencePiece
LLAMA_VOCAB_TYPE_BPE = 1, // Byte Pair Encoding
};
enum llama_token_type {
LLAMA_TOKEN_TYPE_UNDEFINED = 0,
LLAMA_TOKEN_TYPE_NORMAL = 1,
LLAMA_TOKEN_TYPE_UNKNOWN = 2,
LLAMA_TOKEN_TYPE_CONTROL = 3,
LLAMA_TOKEN_TYPE_USER_DEFINED = 4,
LLAMA_TOKEN_TYPE_UNUSED = 5,
LLAMA_TOKEN_TYPE_BYTE = 6,
};
// model file types
enum llama_ftype {
LLAMA_FTYPE_ALL_F32 = 0,
LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
// LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed
// LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed
LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q2_K = 10,// except 1d tensors
LLAMA_FTYPE_MOSTLY_Q3_K_S = 11,// except 1d tensors
LLAMA_FTYPE_MOSTLY_Q3_K_M = 12,// except 1d tensors
LLAMA_FTYPE_MOSTLY_Q3_K_L = 13,// except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_K_S = 14,// except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_K_M = 15,// except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_K_S = 16,// except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_K_M = 17,// except 1d tensors
LLAMA_FTYPE_MOSTLY_Q6_K = 18,// except 1d tensors
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
};
typedef struct llama_token_data {
llama_token id; // token id
float logit; // log-odds of the token
@ -123,152 +60,67 @@ extern "C" {
typedef void (*llama_progress_callback)(float progress, void *ctx);
struct llama_context_params {
uint32_t seed; // RNG seed, -1 for random
int32_t n_ctx; // text context
int32_t n_batch; // prompt processing batch size
int32_t n_gpu_layers; // number of layers to store in VRAM
int32_t main_gpu; // the GPU that is used for scratch and small tensors
int n_ctx; // text context
int n_gpu_layers; // number of layers to store in VRAM
int seed; // RNG seed, -1 for random
const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)
// ref: https://github.com/ggerganov/llama.cpp/pull/2054
float rope_freq_base; // RoPE base frequency
float rope_freq_scale; // RoPE frequency scaling factor
// called with a progress value between 0 and 1, pass NULL to disable
llama_progress_callback progress_callback;
// context pointer passed to the progress callback
void * progress_callback_user_data;
// Keep the booleans together to avoid misalignment during copy-by-value.
bool low_vram; // if true, reduce VRAM usage at the cost of performance
bool mul_mat_q; // if true, use experimental mul_mat_q kernels
bool f16_kv; // use fp16 for KV cache
bool logits_all; // the llama_eval() call computes all logits, not just the last one
bool vocab_only; // only load the vocabulary, no weights
bool use_mmap; // use mmap if possible
bool use_mlock; // force system to keep model in RAM
bool embedding; // embedding mode only
// called with a progress value between 0 and 1, pass NULL to disable
llama_progress_callback progress_callback;
// context pointer passed to the progress callback
void * progress_callback_user_data;
};
// Signature for logging events
// Note that text includes the new line character at the end for most events.
// If your logging mechanism cannot handle that, check if the last character is '\n' and strip it
// if it exists.
// It might not exist for progress report where '.' is output repeatedly.
typedef void (*llama_log_callback)(enum llama_log_level level, const char * text, void * user_data);
// model quantization parameters
typedef struct llama_model_quantize_params {
int nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
enum llama_ftype ftype; // quantize to this llama_ftype
bool allow_requantize; // allow quantizing non-f32/f16 tensors
bool quantize_output_tensor; // quantize output.weight
bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
} llama_model_quantize_params;
// grammar types
struct llama_grammar;
// grammar element type
enum llama_gretype {
// end of rule definition
LLAMA_GRETYPE_END = 0,
// start of alternate definition for rule
LLAMA_GRETYPE_ALT = 1,
// non-terminal element: reference to rule
LLAMA_GRETYPE_RULE_REF = 2,
// terminal element: character (code point)
LLAMA_GRETYPE_CHAR = 3,
// inverse char(s) ([^a], [^a-b] [^abc])
LLAMA_GRETYPE_CHAR_NOT = 4,
// modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
// be an inclusive range ([a-z])
LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,
// modifies a preceding LLAMA_GRETYPE_CHAR or
// LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
LLAMA_GRETYPE_CHAR_ALT = 6,
// model file types
enum llama_ftype {
LLAMA_FTYPE_ALL_F32 = 0,
LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
// LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed
// LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed
LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors
};
typedef struct llama_grammar_element {
enum llama_gretype type;
uint32_t value; // Unicode code point or rule ID
} llama_grammar_element;
LLAMA_API struct llama_context_params llama_context_default_params();
// performance timing information
struct llama_timings {
double t_start_ms;
double t_end_ms;
double t_load_ms;
double t_sample_ms;
double t_p_eval_ms;
double t_eval_ms;
int32_t n_sample;
int32_t n_p_eval;
int32_t n_eval;
};
LLAMA_API struct llama_context_params llama_context_default_params(void);
LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void);
LLAMA_API bool llama_mmap_supported();
LLAMA_API bool llama_mlock_supported();
// TODO: not great API - very likely to change
// Initialize the llama + ggml backend
// If numa is true, use NUMA optimizations
// Call once at the start of the program
LLAMA_API void llama_backend_init(bool numa);
LLAMA_API void llama_init_backend();
// Call once at the end of the program - currently only used for MPI
LLAMA_API void llama_backend_free(void);
LLAMA_API int64_t llama_time_us();
LLAMA_API struct llama_model * llama_load_model_from_file(
// Various functions for loading a ggml llama model.
// Allocate (almost) all memory needed for the model.
// Return NULL on failure
LLAMA_API struct llama_context * llama_init_from_file(
const char * path_model,
struct llama_context_params params);
LLAMA_API void llama_free_model(struct llama_model * model);
LLAMA_API struct llama_context * llama_new_context_with_model(
struct llama_model * model,
struct llama_context_params params);
// Frees all allocated memory
LLAMA_API void llama_free(struct llama_context * ctx);
LLAMA_API int64_t llama_time_us(void);
LLAMA_API int llama_max_devices (void);
LLAMA_API bool llama_mmap_supported (void);
LLAMA_API bool llama_mlock_supported(void);
LLAMA_API int llama_n_vocab (const struct llama_context * ctx);
LLAMA_API int llama_n_ctx (const struct llama_context * ctx);
LLAMA_API int llama_n_ctx_train(const struct llama_context * ctx);
LLAMA_API int llama_n_embd (const struct llama_context * ctx);
LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_context * ctx);
LLAMA_API int llama_model_n_vocab (const struct llama_model * model);
LLAMA_API int llama_model_n_ctx (const struct llama_model * model);
LLAMA_API int llama_model_n_ctx_train(const struct llama_model * model);
LLAMA_API int llama_model_n_embd (const struct llama_model * model);
// Get a string describing the model type
LLAMA_API int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size);
// Returns the total size of all the tensors in the model in bytes
LLAMA_API uint64_t llama_model_size(const struct llama_model * model);
// Returns the total number of parameters in the model
LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model);
// TODO: not great API - very likely to change
// Returns 0 on success
// nthread - how many threads to use. If <=0, will use std::thread::hardware_concurrency(), else the number given
LLAMA_API int llama_model_quantize(
const char * fname_inp,
const char * fname_out,
const llama_model_quantize_params * params);
enum llama_ftype ftype,
int nthread);
// Apply a LoRA adapter to a loaded model
// path_base_model is the path to a higher quality model to use as a base for
@ -276,24 +128,17 @@ extern "C" {
// The model needs to be reloaded before applying a new adapter, otherwise the adapter
// will be applied on top of the previous one
// Returns 0 on success
LLAMA_API DEPRECATED(int llama_apply_lora_from_file(
LLAMA_API int llama_apply_lora_from_file(
struct llama_context * ctx,
const char * path_lora,
const char * path_base_model,
int n_threads),
"please use llama_model_apply_lora_from_file instead");
LLAMA_API int llama_model_apply_lora_from_file(
const struct llama_model * model,
const char * path_lora,
const char * path_base_model,
int n_threads);
int n_threads);
// Returns the number of tokens in the KV cache
LLAMA_API int llama_get_kv_cache_token_count(const struct llama_context * ctx);
// Sets the current rng seed.
LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed);
LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, int seed);
// Returns the maximum size in bytes of the state (rng, logits, embedding
// and kv_cache) - will often be smaller after compacting tokens
@ -323,19 +168,21 @@ extern "C" {
int n_past,
int n_threads);
// Same as llama_eval, but use float matrix input directly.
LLAMA_API int llama_eval_embd(
// Convert the provided text into tokens.
// The tokens pointer must be large enough to hold the resulting tokens.
// Returns the number of tokens on success, no more than n_max_tokens
// Returns a negative number on failure - the number of tokens that would have been returned
// TODO: not sure if correct
LLAMA_API int llama_tokenize(
struct llama_context * ctx,
const float * embd,
int n_tokens,
int n_past,
int n_threads);
const char * text,
llama_token * tokens,
int n_max_tokens,
bool add_bos);
// Export a static computation graph for context of 511 and batch size of 1
// NOTE: since this functionality is mostly for debugging and demonstration purposes, we hardcode these
// parameters here to keep things simple
// IMPORTANT: do not use for anything else other than debugging and testing!
LLAMA_API int llama_eval_export(struct llama_context * ctx, const char * fname);
LLAMA_API int llama_n_vocab(const struct llama_context * ctx);
LLAMA_API int llama_n_ctx (const struct llama_context * ctx);
LLAMA_API int llama_n_embd (const struct llama_context * ctx);
// Token logits obtained from the last call to llama_eval()
// The logits for the last token are stored in the last row
@ -348,75 +195,15 @@ extern "C" {
// shape: [n_embd] (1-dimensional)
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
//
// Vocab
//
LLAMA_API const char * llama_token_get_text(const struct llama_context * ctx, llama_token token);
LLAMA_API float llama_token_get_score(const struct llama_context * ctx, llama_token token);
LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_context * ctx, llama_token token);
// Token Id -> String. Uses the vocabulary in the provided context
LLAMA_API const char * llama_token_to_str(const struct llama_context * ctx, llama_token token);
// Special tokens
LLAMA_API llama_token llama_token_bos(const struct llama_context * ctx); // beginning-of-sentence
LLAMA_API llama_token llama_token_eos(const struct llama_context * ctx); // end-of-sentence
LLAMA_API llama_token llama_token_nl (const struct llama_context * ctx); // next-line
LLAMA_API llama_token llama_token_bos();
LLAMA_API llama_token llama_token_eos();
LLAMA_API llama_token llama_token_nl();
//
// Tokenization
//
// Convert the provided text into tokens.
// The tokens pointer must be large enough to hold the resulting tokens.
// Returns the number of tokens on success, no more than n_max_tokens
// Returns a negative number on failure - the number of tokens that would have been returned
LLAMA_API int llama_tokenize(
struct llama_context * ctx,
const char * text,
llama_token * tokens,
int n_max_tokens,
bool add_bos);
LLAMA_API int llama_tokenize_with_model(
const struct llama_model * model,
const char * text,
llama_token * tokens,
int n_max_tokens,
bool add_bos);
// Token Id -> Piece.
// Uses the vocabulary in the provided context.
// Does not write null terminator to the buffer.
// User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens.
LLAMA_API int llama_token_to_piece(
const struct llama_context * ctx,
llama_token token,
char * buf,
int length);
LLAMA_API int llama_token_to_piece_with_model(
const struct llama_model * model,
llama_token token,
char * buf,
int length);
//
// Grammar
//
LLAMA_API struct llama_grammar * llama_grammar_init(
const llama_grammar_element ** rules,
size_t n_rules,
size_t start_rule_index);
LLAMA_API void llama_grammar_free(struct llama_grammar * grammar);
LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar);
//
// Sampling functions
//
/// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty);
@ -424,16 +211,6 @@ extern "C" {
/// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence);
/// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted.
/// @params guidance_ctx A separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
/// @params scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
LLAMA_API void llama_sample_classifier_free_guidance(
struct llama_context * ctx,
llama_token_data_array * candidates,
struct llama_context * guidance_ctx,
float scale);
/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates);
@ -450,9 +227,6 @@ extern "C" {
LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep);
LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp);
/// @details Apply constraints from grammar
LLAMA_API void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar);
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
@ -474,60 +248,13 @@ extern "C" {
/// @details Randomly selects a token from the candidates based on their probabilities.
LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates);
/// @details Accepts the sampled token into the grammar
LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token);
//
// Beam search
//
struct llama_beam_view {
const llama_token * tokens;
size_t n_tokens;
float p; // Cumulative beam probability (renormalized relative to all beams)
bool eob; // Callback should set this to true when a beam is at end-of-beam.
};
// Passed to beam_search_callback function.
// Whenever 0 < common_prefix_length, this number of tokens should be copied from any of the beams
// (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks.
// These pointers are valid only during the synchronous callback, so should not be saved.
struct llama_beams_state {
struct llama_beam_view * beam_views;
size_t n_beams; // Number of elements in beam_views[].
size_t common_prefix_length; // Current max length of prefix tokens shared by all beams.
bool last_call; // True iff this is the last callback invocation.
};
// Type of pointer to the beam_search_callback function.
// void* callback_data is any custom data passed to llama_beam_search, that is subsequently
// passed back to beam_search_callback. This avoids having to use global variables in the callback.
typedef void (*llama_beam_search_callback_fn_t)(void * callback_data, struct llama_beams_state);
/// @details Deterministically returns entire sentence constructed by a beam search.
/// @param ctx Pointer to the llama_context.
/// @param callback Invoked for each iteration of the beam_search loop, passing in beams_state.
/// @param callback_data A pointer that is simply passed back to callback.
/// @param n_beams Number of beams to use.
/// @param n_past Number of tokens already evaluated.
/// @param n_predict Maximum number of tokens to predict. EOS may occur earlier.
/// @param n_threads Number of threads as passed to llama_eval().
LLAMA_API void llama_beam_search(struct llama_context * ctx, llama_beam_search_callback_fn_t callback, void * callback_data, size_t n_beams, int n_past, int n_predict, int n_threads);
// Performance information
LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
LLAMA_API void llama_print_timings(struct llama_context * ctx);
LLAMA_API void llama_reset_timings(struct llama_context * ctx);
// Print system information
LLAMA_API const char * llama_print_system_info(void);
// Set callback for all future logging events.
// If this is not called, or NULL is supplied, everything is output on stderr.
LLAMA_API void llama_log_set(llama_log_callback log_callback, void * user_data);
LLAMA_API void llama_dump_timing_info_yaml(FILE * stream, const struct llama_context * ctx);
#ifdef __cplusplus
}
#endif
@ -537,11 +264,10 @@ extern "C" {
#include <vector>
#include <string>
struct ggml_tensor;
const std::vector<std::pair<std::string, struct ggml_tensor *>>& llama_internal_get_tensor_map(struct llama_context * ctx);
std::vector<std::pair<std::string, struct ggml_tensor *>>& llama_internal_get_tensor_map(struct llama_context * ctx);
#endif // LLAMA_API_INTERNAL
#endif
#endif // LLAMA_H

0
examples/talk-llama/speak Executable file → Normal file
View File

View File

@ -25,20 +25,6 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
return res;
}
std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) {
std::vector<char> result(8, 0);
const int n_tokens = llama_token_to_piece(ctx, token, result.data(), result.size());
if (n_tokens < 0) {
result.resize(-n_tokens);
int check = llama_token_to_piece(ctx, token, result.data(), result.size());
GGML_ASSERT(check == -n_tokens);
} else {
result.resize(n_tokens);
}
return std::string(result.data(), result.size());
}
// command-line parameters
struct whisper_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
@ -47,14 +33,14 @@ struct whisper_params {
int32_t max_tokens = 32;
int32_t audio_ctx = 0;
float vad_thold = 0.6f;
float freq_thold = 100.0f;
float vad_thold = 0.6f;
float freq_thold = 100.0f;
bool speed_up = false;
bool translate = false;
bool print_special = false;
bool print_energy = false;
bool no_timestamps = true;
bool speed_up = false;
bool translate = false;
bool print_special = false;
bool print_energy = false;
bool no_timestamps = true;
bool verbose_prompt = false;
std::string person = "Georgi";
@ -249,7 +235,7 @@ int main(int argc, char ** argv) {
// llama init
llama_backend_init(true);
llama_init_backend();
auto lparams = llama_context_default_params();
@ -258,9 +244,7 @@ int main(int argc, char ** argv) {
lparams.seed = 1;
lparams.f16_kv = true;
struct llama_model * model_llama = llama_load_model_from_file(params.model_llama.c_str(), lparams);
struct llama_context * ctx_llama = llama_new_context_with_model(model_llama, lparams);
struct llama_context * ctx_llama = llama_init_from_file(params.model_llama.c_str(), lparams);
// print some info about the processing
{
@ -283,6 +267,7 @@ int main(int argc, char ** argv) {
fprintf(stderr, "\n");
}
// init audio
audio_async audio(30*1000);
@ -293,6 +278,8 @@ int main(int argc, char ** argv) {
audio.resume();
int n_iter = 0;
bool is_running = true;
bool force_speak = false;
@ -527,7 +514,7 @@ int main(int argc, char ** argv) {
//printf("\n---\n");
//printf("resetting: '");
//for (int i = 0; i < (int) embd.size(); i++) {
// printf("%s", llama_token_to_piece(ctx_llama, embd[i]));
// printf("%s", llama_token_to_str(ctx_llama, embd[i]));
//}
//printf("'\n");
//printf("\n---\n");
@ -595,7 +582,7 @@ int main(int argc, char ** argv) {
auto logits = llama_get_logits(ctx_llama);
auto n_vocab = llama_n_vocab(ctx_llama);
logits[llama_token_eos(ctx_llama)] = 0;
logits[llama_token_eos()] = 0;
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
@ -606,13 +593,13 @@ int main(int argc, char ** argv) {
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
// apply repeat penalty
const float nl_logit = logits[llama_token_nl(ctx_llama)];
const float nl_logit = logits[llama_token_nl()];
llama_sample_repetition_penalty(ctx_llama, &candidates_p,
embd_inp.data() + std::max(0, n_past - repeat_last_n),
repeat_last_n, repeat_penalty);
logits[llama_token_nl(ctx_llama)] = nl_logit;
logits[llama_token_nl()] = nl_logit;
if (temp <= 0) {
// Greedy sampling
@ -626,22 +613,22 @@ int main(int argc, char ** argv) {
}
}
if (id != llama_token_eos(ctx_llama)) {
if (id != llama_token_eos()) {
// add it to the context
embd.push_back(id);
text_to_speak += llama_token_to_piece(ctx_llama, id);
text_to_speak += llama_token_to_str(ctx_llama, id);
printf("%s", llama_token_to_piece(ctx_llama, id).c_str());
printf("%s", llama_token_to_str(ctx_llama, id));
}
}
{
std::string last_output;
for (int i = embd_inp.size() - 16; i < (int) embd_inp.size(); i++) {
last_output += llama_token_to_piece(ctx_llama, embd_inp[i]);
last_output += llama_token_to_str(ctx_llama, embd_inp[i]);
}
last_output += llama_token_to_piece(ctx_llama, embd[0]);
last_output += llama_token_to_str(ctx_llama, embd[0]);
for (std::string & antiprompt : antiprompts) {
if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
@ -668,6 +655,8 @@ int main(int argc, char ** argv) {
}
audio.clear();
++n_iter;
}
}
}

View File

@ -31,7 +31,7 @@ To run this, you will need a ggml GPT-2 model: [instructions](https://github.com
Alternatively, you can simply download the smallest ggml GPT-2 117M model (240 MB) like this:
```
wget --quiet --show-progress -O models/ggml-gpt-2-117M.bin https://huggingface.co/ggerganov/ggml/resolve/main/ggml-model-gpt-2-117M.bin
wget --quiet --show-progress -O models/ggml-gpt-2-117M.bin https://huggingface.co/ggerganov/ggml/raw/main/ggml-model-gpt-2-117M.bin
```
## TTS

View File

@ -44,8 +44,8 @@ if [ "$encoder_only" -eq 0 ]; then
printf "\n"
fi
printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "PP" "Commit"
printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---"
printf "| %6s | %6s | %12s | %9s | %3s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "PP" "Commit"
printf "| %6s | %6s | %12s | %9s | %3s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---"
for model in "${models[@]}"; do
# actual run
@ -83,13 +83,9 @@ for model in "${models[@]}"; do
config="$config COREML"
fi
if [[ $system_info == *"METAL = 1"* ]]; then
config="$config METAL"
fi
commit=$(git rev-parse --short HEAD)
if [ $ret -eq 0 ]; then
printf "| <todo> | <todo> | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$prompt_time" "$commit"
printf "| <todo> | <todo> | %12s | %9s | %3s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$prompt_time" "$commit"
fi
done

View File

@ -1,222 +0,0 @@
import os
import subprocess
import re
import csv
import wave
import contextlib
import argparse
# Custom action to handle comma-separated list
class ListAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
setattr(namespace, self.dest, [int(val) for val in values.split(",")])
parser = argparse.ArgumentParser(description="Benchmark the speech recognition model")
# Define the argument to accept a list
parser.add_argument(
"-t",
"--threads",
dest="threads",
action=ListAction,
default=[4],
help="List of thread counts to benchmark (comma-separated, default: 4)",
)
parser.add_argument(
"-p",
"--processors",
dest="processors",
action=ListAction,
default=[1],
help="List of processor counts to benchmark (comma-separated, default: 1)",
)
parser.add_argument(
"-f",
"--filename",
type=str,
default="./samples/jfk.wav",
help="Relative path of the file to transcribe (default: ./samples/jfk.wav)",
)
# Parse the command line arguments
args = parser.parse_args()
sample_file = args.filename
threads = args.threads
processors = args.processors
# Define the models, threads, and processor counts to benchmark
models = [
"ggml-tiny.en.bin",
"ggml-tiny.bin",
"ggml-base.en.bin",
"ggml-base.bin",
"ggml-small.en.bin",
"ggml-small.bin",
"ggml-medium.en.bin",
"ggml-medium.bin",
"ggml-large.bin",
]
metal_device = ""
# Initialize a dictionary to hold the results
results = {}
gitHashHeader = "Commit"
modelHeader = "Model"
hardwareHeader = "Hardware"
recordingLengthHeader = "Recording Length (seconds)"
threadHeader = "Thread"
processorCountHeader = "Processor Count"
loadTimeHeader = "Load Time (ms)"
sampleTimeHeader = "Sample Time (ms)"
encodeTimeHeader = "Encode Time (ms)"
decodeTimeHeader = "Decode Time (ms)"
sampleTimePerRunHeader = "Sample Time per Run (ms)"
encodeTimePerRunHeader = "Encode Time per Run (ms)"
decodeTimePerRunHeader = "Decode Time per Run (ms)"
totalTimeHeader = "Total Time (ms)"
def check_file_exists(file: str) -> bool:
return os.path.isfile(file)
def get_git_short_hash() -> str:
try:
return (
subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
.decode()
.strip()
)
except subprocess.CalledProcessError as e:
return ""
def wav_file_length(file: str = sample_file) -> float:
with contextlib.closing(wave.open(file, "r")) as f:
frames = f.getnframes()
rate = f.getframerate()
duration = frames / float(rate)
return duration
def extract_metrics(output: str, label: str) -> tuple[float, float]:
match = re.search(rf"{label} \s*=\s*(\d+\.\d+)\s*ms\s*/\s*(\d+)\s*runs", output)
time = float(match.group(1)) if match else None
runs = float(match.group(2)) if match else None
return time, runs
def extract_device(output: str) -> str:
match = re.search(r"picking default device: (.*)", output)
device = match.group(1) if match else "Not found"
return device
# Check if the sample file exists
if not check_file_exists(sample_file):
raise FileNotFoundError(f"Sample file {sample_file} not found")
recording_length = wav_file_length()
# Check that all models exist
# Filter out models from list that are not downloaded
filtered_models = []
for model in models:
if check_file_exists(f"models/{model}"):
filtered_models.append(model)
else:
print(f"Model {model} not found, removing from list")
models = filtered_models
# Loop over each combination of parameters
for model in filtered_models:
for thread in threads:
for processor_count in processors:
# Construct the command to run
cmd = f"./main -m models/{model} -t {thread} -p {processor_count} -f {sample_file}"
# Run the command and get the output
process = subprocess.Popen(
cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
)
output = ""
while process.poll() is None:
output += process.stdout.read().decode()
# Parse the output
load_time_match = re.search(r"load time\s*=\s*(\d+\.\d+)\s*ms", output)
load_time = float(load_time_match.group(1)) if load_time_match else None
metal_device = extract_device(output)
sample_time, sample_runs = extract_metrics(output, "sample time")
encode_time, encode_runs = extract_metrics(output, "encode time")
decode_time, decode_runs = extract_metrics(output, "decode time")
total_time_match = re.search(r"total time\s*=\s*(\d+\.\d+)\s*ms", output)
total_time = float(total_time_match.group(1)) if total_time_match else None
model_name = model.replace("ggml-", "").replace(".bin", "")
print(
f"Ran model={model_name} threads={thread} processor_count={processor_count}, took {total_time}ms"
)
# Store the times in the results dictionary
results[(model_name, thread, processor_count)] = {
loadTimeHeader: load_time,
sampleTimeHeader: sample_time,
encodeTimeHeader: encode_time,
decodeTimeHeader: decode_time,
sampleTimePerRunHeader: round(sample_time / sample_runs, 2),
encodeTimePerRunHeader: round(encode_time / encode_runs, 2),
decodeTimePerRunHeader: round(decode_time / decode_runs, 2),
totalTimeHeader: total_time,
}
# Write the results to a CSV file
with open("benchmark_results.csv", "w", newline="") as csvfile:
fieldnames = [
gitHashHeader,
modelHeader,
hardwareHeader,
recordingLengthHeader,
threadHeader,
processorCountHeader,
loadTimeHeader,
sampleTimeHeader,
encodeTimeHeader,
decodeTimeHeader,
sampleTimePerRunHeader,
encodeTimePerRunHeader,
decodeTimePerRunHeader,
totalTimeHeader,
]
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
writer.writeheader()
shortHash = get_git_short_hash()
# Sort the results by total time in ascending order
sorted_results = sorted(results.items(), key=lambda x: x[1].get(totalTimeHeader, 0))
for params, times in sorted_results:
row = {
gitHashHeader: shortHash,
modelHeader: params[0],
hardwareHeader: metal_device,
recordingLengthHeader: recording_length,
threadHeader: params[1],
processorCountHeader: params[2],
}
row.update(times)
writer.writerow(row)

View File

@ -5745,7 +5745,6 @@ inline void ggml_cuda_op_rope(
(void) dst;
(void) src0_ddq_i;
(void) src1_ddf_i;
(void) i02;
(void) i1;
}
@ -5781,7 +5780,6 @@ inline void ggml_cuda_op_alibi(
(void) src1;
(void) src0_ddq_i;
(void) src1_ddf_i;
(void) i02;
(void) i1;
}

View File

@ -78,7 +78,6 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(get_rows_q6_K);
GGML_METAL_DECL_KERNEL(rms_norm);
GGML_METAL_DECL_KERNEL(norm);
GGML_METAL_DECL_KERNEL(mul_mat_f32_f32);
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4);
@ -90,7 +89,6 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
@ -239,7 +237,6 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(get_rows_q6_K);
GGML_METAL_ADD_KERNEL(rms_norm);
GGML_METAL_ADD_KERNEL(norm);
GGML_METAL_ADD_KERNEL(mul_mat_f32_f32);
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4);
@ -251,7 +248,6 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
@ -313,7 +309,6 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(get_rows_q6_K);
GGML_METAL_DEL_KERNEL(rms_norm);
GGML_METAL_DEL_KERNEL(norm);
GGML_METAL_DEL_KERNEL(mul_mat_f32_f32);
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4);
@ -325,7 +320,6 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
@ -891,7 +885,6 @@ void ggml_metal_graph_compute(
ne00%32 == 0 &&
ne11 > 1) {
switch (src0->type) {
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
@ -926,18 +919,15 @@ void ggml_metal_graph_compute(
// use custom matrix x vector kernel
switch (src0t) {
case GGML_TYPE_F32:
{
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f32_f32];
nrows = 4;
} break;
case GGML_TYPE_F16:
{
nth0 = 32;
nth1 = 1;
if (ne11 * ne12 < 4) {
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
//} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
} else if (false) {
// TODO: with ggml_mul_mat_pad this kernel no longer seems to be needed
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4];
nrows = ne11;
} else {

View File

@ -523,79 +523,6 @@ kernel void kernel_mul_mat_q8_0_f32(
}
}
#define N_F32_F32 4
kernel void kernel_mul_mat_f32_f32(
device const char * src0,
device const char * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {
const int64_t r0 = tgpig.x;
const int64_t rb = tgpig.y*N_F32_F32;
const int64_t im = tgpig.z;
device const float * x = (device const float *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
if (ne00 < 128) {
for (int row = 0; row < N_F32_F32; ++row) {
int r1 = rb + row;
if (r1 >= ne11) {
break;
}
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
float sumf = 0;
for (int i = tiisg; i < ne00; i += 32) {
sumf += (float) x[i] * (float) y[i];
}
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
}
} else {
device const float4 * x4 = (device const float4 *)x;
for (int row = 0; row < N_F32_F32; ++row) {
int r1 = rb + row;
if (r1 >= ne11) {
break;
}
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
device const float4 * y4 = (device const float4 *) y;
float sumf = 0;
for (int i = tiisg; i < ne00/4; i += 32) {
for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
}
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
}
}
}
kernel void kernel_mul_mat_f16_f32_1row(
device const char * src0,
device const char * src1,
@ -1472,13 +1399,13 @@ kernel void kernel_mul_mat_q4_K_f32(
device const float * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01 [[buffer(4)]],
constant int64_t & ne02 [[buffer(5)]],
constant int64_t & ne10 [[buffer(9)]],
constant int64_t & ne12 [[buffer(11)]],
constant int64_t & ne0 [[buffer(15)]],
constant int64_t & ne1 [[buffer(16)]],
constant uint & gqa [[buffer(17)]],
constant int64_t & ne01[[buffer(4)]],
constant int64_t & ne02[[buffer(5)]],
constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]],
constant int64_t & ne0[[buffer(15)]],
constant int64_t & ne1[[buffer(16)]],
constant uint & gqa[[buffer(17)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@ -2341,7 +2268,6 @@ typedef void (mat_mm_t)(
constant uint & gqa,
threadgroup uchar *, uint3, uint, uint);
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;

72
ggml.c
View File

@ -20119,27 +20119,27 @@ const char * gguf_type_name(enum gguf_type type) {
return GGUF_TYPE_NAME[type];
}
int gguf_get_version(const struct gguf_context * ctx) {
int gguf_get_version(struct gguf_context * ctx) {
return ctx->header.version;
}
size_t gguf_get_alignment(const struct gguf_context * ctx) {
size_t gguf_get_alignment(struct gguf_context * ctx) {
return ctx->alignment;
}
size_t gguf_get_data_offset(const struct gguf_context * ctx) {
size_t gguf_get_data_offset(struct gguf_context * ctx) {
return ctx->offset;
}
void * gguf_get_data(const struct gguf_context * ctx) {
void * gguf_get_data(struct gguf_context * ctx) {
return ctx->data;
}
int gguf_get_n_kv(const struct gguf_context * ctx) {
int gguf_get_n_kv(struct gguf_context * ctx) {
return ctx->header.n_kv;
}
int gguf_find_key(const struct gguf_context * ctx, const char * key) {
int gguf_find_key(struct gguf_context * ctx, const char * key) {
// return -1 if key not found
int keyfound = -1;
@ -20155,85 +20155,85 @@ int gguf_find_key(const struct gguf_context * ctx, const char * key) {
return keyfound;
}
const char * gguf_get_key(const struct gguf_context * ctx, int i) {
const char * gguf_get_key(struct gguf_context * ctx, int i) {
return ctx->kv[i].key.data;
}
enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int i) {
enum gguf_type gguf_get_kv_type(struct gguf_context * ctx, int i) {
return ctx->kv[i].type;
}
enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int i) {
enum gguf_type gguf_get_arr_type(struct gguf_context * ctx, int i) {
return ctx->kv[i].value.arr.type;
}
const void * gguf_get_arr_data(const struct gguf_context * ctx, int i) {
const void * gguf_get_arr_data(struct gguf_context * ctx, int i) {
return ctx->kv[i].value.arr.data;
}
const char * gguf_get_arr_str(const struct gguf_context * ctx, int key_id, int i) {
const char * gguf_get_arr_str(struct gguf_context * ctx, int key_id, int i) {
struct gguf_kv * kv = &ctx->kv[key_id];
struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[i];
return str->data;
}
int gguf_get_arr_n(const struct gguf_context * ctx, int i) {
int gguf_get_arr_n(struct gguf_context * ctx, int i) {
return ctx->kv[i].value.arr.n;
}
uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int i) {
uint8_t gguf_get_val_u8(struct gguf_context * ctx, int i) {
return ctx->kv[i].value.uint8;
}
int8_t gguf_get_val_i8(const struct gguf_context * ctx, int i) {
int8_t gguf_get_val_i8(struct gguf_context * ctx, int i) {
return ctx->kv[i].value.int8;
}
uint16_t gguf_get_val_u16(const struct gguf_context * ctx, int i) {
uint16_t gguf_get_val_u16(struct gguf_context * ctx, int i) {
return ctx->kv[i].value.uint16;
}
int16_t gguf_get_val_i16(const struct gguf_context * ctx, int i) {
int16_t gguf_get_val_i16(struct gguf_context * ctx, int i) {
return ctx->kv[i].value.int16;
}
uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int i) {
uint32_t gguf_get_val_u32(struct gguf_context * ctx, int i) {
return ctx->kv[i].value.uint32;
}
int32_t gguf_get_val_i32(const struct gguf_context * ctx, int i) {
int32_t gguf_get_val_i32(struct gguf_context * ctx, int i) {
return ctx->kv[i].value.int32;
}
float gguf_get_val_f32(const struct gguf_context * ctx, int i) {
float gguf_get_val_f32(struct gguf_context * ctx, int i) {
return ctx->kv[i].value.float32;
}
uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int i) {
uint64_t gguf_get_val_u64(struct gguf_context * ctx, int i) {
return ctx->kv[i].value.uint64;
}
int64_t gguf_get_val_i64(const struct gguf_context * ctx, int i) {
int64_t gguf_get_val_i64(struct gguf_context * ctx, int i) {
return ctx->kv[i].value.int64;
}
double gguf_get_val_f64(const struct gguf_context * ctx, int i) {
double gguf_get_val_f64(struct gguf_context * ctx, int i) {
return ctx->kv[i].value.float64;
}
bool gguf_get_val_bool(const struct gguf_context * ctx, int i) {
bool gguf_get_val_bool(struct gguf_context * ctx, int i) {
return ctx->kv[i].value.bool_;
}
const char * gguf_get_val_str (const struct gguf_context * ctx, int i) {
const char * gguf_get_val_str (struct gguf_context * ctx, int i) {
return ctx->kv[i].value.str.data;
}
int gguf_get_n_tensors(const struct gguf_context * ctx) {
int gguf_get_n_tensors(struct gguf_context * ctx) {
return ctx->header.n_tensors;
}
int gguf_find_tensor(const struct gguf_context * ctx, const char * name) {
int gguf_find_tensor(struct gguf_context * ctx, const char * name) {
// return -1 if tensor not found
int tensorfound = -1;
@ -20249,11 +20249,11 @@ int gguf_find_tensor(const struct gguf_context * ctx, const char * name) {
return tensorfound;
}
size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int i) {
size_t gguf_get_tensor_offset(struct gguf_context * ctx, int i) {
return ctx->infos[i].offset;
}
char * gguf_get_tensor_name(const struct gguf_context * ctx, int i) {
char * gguf_get_tensor_name(struct gguf_context * ctx, int i) {
return ctx->infos[i].name.data;
}
@ -20536,7 +20536,7 @@ static void gguf_bwrite_el(struct gguf_buf * buf, const void * val, size_t el_si
buf->offset += el_size;
}
static void gguf_write_to_buf(const struct gguf_context * ctx, struct gguf_buf * buf, bool only_meta) {
static void gguf_write_to_buf(struct gguf_context * ctx, struct gguf_buf * buf, bool only_meta) {
// write header
gguf_bwrite_el(buf, &ctx->header.magic, sizeof(ctx->header.magic));
gguf_bwrite_el(buf, &ctx->header.version, sizeof(ctx->header.version));
@ -20651,7 +20651,7 @@ static void gguf_write_to_buf(const struct gguf_context * ctx, struct gguf_buf *
}
}
void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta) {
void gguf_write_to_file(struct gguf_context * ctx, const char * fname, bool only_meta) {
FILE * file = fopen(fname, "wb");
if (!file) {
GGML_ASSERT(false && "failed to open file for writing");
@ -20668,7 +20668,7 @@ void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, boo
fclose(file);
}
size_t gguf_get_meta_size(const struct gguf_context * ctx) {
size_t gguf_get_meta_size(struct gguf_context * ctx) {
// no allocs - only compute size
struct gguf_buf buf = gguf_buf_init(0);
@ -20677,7 +20677,7 @@ size_t gguf_get_meta_size(const struct gguf_context * ctx) {
return buf.offset;
}
void gguf_get_meta_data(const struct gguf_context * ctx, void * data) {
void gguf_get_meta_data(struct gguf_context * ctx, void * data) {
struct gguf_buf buf = gguf_buf_init(16*1024);
gguf_write_to_buf(ctx, &buf, true);
@ -20753,14 +20753,6 @@ int ggml_cpu_has_arm_fma(void) {
#endif
}
int ggml_cpu_has_metal(void) {
#if defined(GGML_USE_METAL)
return 1;
#else
return 0;
#endif
}
int ggml_cpu_has_f16c(void) {
#if defined(__F16C__)
return 1;

72
ggml.h
View File

@ -195,14 +195,6 @@
# define GGML_DEPRECATED(func, hint) func
#endif
#ifndef __GNUC__
# define GGML_ATTRIBUTE_FORMAT(...)
#elif defined(__MINGW32__)
# define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
#else
# define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
#endif
#include <stdint.h>
#include <stddef.h>
#include <stdbool.h>
@ -693,7 +685,6 @@ extern "C" {
GGML_API const char * ggml_get_name (const struct ggml_tensor * tensor);
GGML_API struct ggml_tensor * ggml_set_name ( struct ggml_tensor * tensor, const char * name);
GGML_ATTRIBUTE_FORMAT(2, 3)
GGML_API struct ggml_tensor * ggml_format_name( struct ggml_tensor * tensor, const char * fmt, ...);
//
@ -1875,39 +1866,39 @@ extern "C" {
GGML_API const char * gguf_type_name(enum gguf_type type);
GGML_API int gguf_get_version (const struct gguf_context * ctx);
GGML_API size_t gguf_get_alignment (const struct gguf_context * ctx);
GGML_API size_t gguf_get_data_offset(const struct gguf_context * ctx);
GGML_API void * gguf_get_data (const struct gguf_context * ctx);
GGML_API int gguf_get_version (struct gguf_context * ctx);
GGML_API size_t gguf_get_alignment (struct gguf_context * ctx);
GGML_API size_t gguf_get_data_offset(struct gguf_context * ctx);
GGML_API void * gguf_get_data (struct gguf_context * ctx);
GGML_API int gguf_get_n_kv(const struct gguf_context * ctx);
GGML_API int gguf_find_key(const struct gguf_context * ctx, const char * key);
GGML_API const char * gguf_get_key (const struct gguf_context * ctx, int i);
GGML_API int gguf_get_n_kv(struct gguf_context * ctx);
GGML_API int gguf_find_key(struct gguf_context * ctx, const char * key);
GGML_API const char * gguf_get_key (struct gguf_context * ctx, int i);
GGML_API enum gguf_type gguf_get_kv_type (const struct gguf_context * ctx, int i);
GGML_API enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int i);
GGML_API enum gguf_type gguf_get_kv_type (struct gguf_context * ctx, int i);
GGML_API enum gguf_type gguf_get_arr_type(struct gguf_context * ctx, int i);
// results are undefined if the wrong type is used for the key
GGML_API uint8_t gguf_get_val_u8 (const struct gguf_context * ctx, int i);
GGML_API int8_t gguf_get_val_i8 (const struct gguf_context * ctx, int i);
GGML_API uint16_t gguf_get_val_u16 (const struct gguf_context * ctx, int i);
GGML_API int16_t gguf_get_val_i16 (const struct gguf_context * ctx, int i);
GGML_API uint32_t gguf_get_val_u32 (const struct gguf_context * ctx, int i);
GGML_API int32_t gguf_get_val_i32 (const struct gguf_context * ctx, int i);
GGML_API float gguf_get_val_f32 (const struct gguf_context * ctx, int i);
GGML_API uint64_t gguf_get_val_u64 (const struct gguf_context * ctx, int i);
GGML_API int64_t gguf_get_val_i64 (const struct gguf_context * ctx, int i);
GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int i);
GGML_API bool gguf_get_val_bool(const struct gguf_context * ctx, int i);
GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int i);
GGML_API int gguf_get_arr_n (const struct gguf_context * ctx, int i);
GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int i);
GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i);
GGML_API uint8_t gguf_get_val_u8 (struct gguf_context * ctx, int i);
GGML_API int8_t gguf_get_val_i8 (struct gguf_context * ctx, int i);
GGML_API uint16_t gguf_get_val_u16 (struct gguf_context * ctx, int i);
GGML_API int16_t gguf_get_val_i16 (struct gguf_context * ctx, int i);
GGML_API uint32_t gguf_get_val_u32 (struct gguf_context * ctx, int i);
GGML_API int32_t gguf_get_val_i32 (struct gguf_context * ctx, int i);
GGML_API float gguf_get_val_f32 (struct gguf_context * ctx, int i);
GGML_API uint64_t gguf_get_val_u64 (struct gguf_context * ctx, int i);
GGML_API int64_t gguf_get_val_i64 (struct gguf_context * ctx, int i);
GGML_API double gguf_get_val_f64 (struct gguf_context * ctx, int i);
GGML_API bool gguf_get_val_bool(struct gguf_context * ctx, int i);
GGML_API const char * gguf_get_val_str (struct gguf_context * ctx, int i);
GGML_API int gguf_get_arr_n (struct gguf_context * ctx, int i);
GGML_API const void * gguf_get_arr_data(struct gguf_context * ctx, int i);
GGML_API const char * gguf_get_arr_str (struct gguf_context * ctx, int key_id, int i);
GGML_API int gguf_get_n_tensors (const struct gguf_context * ctx);
GGML_API int gguf_find_tensor (const struct gguf_context * ctx, const char * name);
GGML_API size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int i);
GGML_API char * gguf_get_tensor_name (const struct gguf_context * ctx, int i);
GGML_API int gguf_get_n_tensors (struct gguf_context * ctx);
GGML_API int gguf_find_tensor (struct gguf_context * ctx, const char * name);
GGML_API size_t gguf_get_tensor_offset(struct gguf_context * ctx, int i);
GGML_API char * gguf_get_tensor_name (struct gguf_context * ctx, int i);
// overrides existing values or adds a new one
GGML_API void gguf_set_val_u8 (struct gguf_context * ctx, const char * key, uint8_t val);
@ -1952,11 +1943,11 @@ extern "C" {
//
// write the entire context to a binary file
GGML_API void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta);
GGML_API void gguf_write_to_file(struct gguf_context * ctx, const char * fname, bool only_meta);
// get the size in bytes of the meta data (header, kv pairs, tensor info) including padding
GGML_API size_t gguf_get_meta_size(const struct gguf_context * ctx);
GGML_API void gguf_get_meta_data(const struct gguf_context * ctx, void * data);
GGML_API size_t gguf_get_meta_size(struct gguf_context * ctx);
GGML_API void gguf_get_meta_data(struct gguf_context * ctx, void * data);
//
// system info
@ -1970,7 +1961,6 @@ extern "C" {
GGML_API int ggml_cpu_has_fma (void);
GGML_API int ggml_cpu_has_neon (void);
GGML_API int ggml_cpu_has_arm_fma (void);
GGML_API int ggml_cpu_has_metal (void);
GGML_API int ggml_cpu_has_f16c (void);
GGML_API int ggml_cpu_has_fp16_va (void);
GGML_API int ggml_cpu_has_wasm_simd (void);

View File

@ -1,117 +0,0 @@
import argparse
import importlib.util
spec = importlib.util.spec_from_file_location('whisper_to_coreml', 'models/convert-whisper-to-coreml.py')
whisper_to_coreml = importlib.util.module_from_spec(spec)
spec.loader.exec_module(whisper_to_coreml)
from whisper import load_model
from copy import deepcopy
import torch
from transformers import WhisperForConditionalGeneration
from huggingface_hub import metadata_update
# https://github.com/bayartsogt-ya/whisper-multiple-hf-datasets/blob/main/src/multiple_datasets/hub_default_utils.py
WHISPER_MAPPING = {
"layers": "blocks",
"fc1": "mlp.0",
"fc2": "mlp.2",
"final_layer_norm": "mlp_ln",
"layers": "blocks",
".self_attn.q_proj": ".attn.query",
".self_attn.k_proj": ".attn.key",
".self_attn.v_proj": ".attn.value",
".self_attn_layer_norm": ".attn_ln",
".self_attn.out_proj": ".attn.out",
".encoder_attn.q_proj": ".cross_attn.query",
".encoder_attn.k_proj": ".cross_attn.key",
".encoder_attn.v_proj": ".cross_attn.value",
".encoder_attn_layer_norm": ".cross_attn_ln",
".encoder_attn.out_proj": ".cross_attn.out",
"decoder.layer_norm.": "decoder.ln.",
"encoder.layer_norm.": "encoder.ln_post.",
"embed_tokens": "token_embedding",
"encoder.embed_positions.weight": "encoder.positional_embedding",
"decoder.embed_positions.weight": "decoder.positional_embedding",
"layer_norm": "ln_post",
}
# https://github.com/bayartsogt-ya/whisper-multiple-hf-datasets/blob/main/src/multiple_datasets/hub_default_utils.py
def rename_keys(s_dict):
keys = list(s_dict.keys())
for key in keys:
new_key = key
for k, v in WHISPER_MAPPING.items():
if k in key:
new_key = new_key.replace(k, v)
print(f"{key} -> {new_key}")
s_dict[new_key] = s_dict.pop(key)
return s_dict
# https://github.com/bayartsogt-ya/whisper-multiple-hf-datasets/blob/main/src/multiple_datasets/hub_default_utils.py
def convert_hf_whisper(hf_model_name_or_path: str, whisper_state_path: str):
transformer_model = WhisperForConditionalGeneration.from_pretrained(hf_model_name_or_path)
config = transformer_model.config
# first build dims
dims = {
'n_mels': config.num_mel_bins,
'n_vocab': config.vocab_size,
'n_audio_ctx': config.max_source_positions,
'n_audio_state': config.d_model,
'n_audio_head': config.encoder_attention_heads,
'n_audio_layer': config.encoder_layers,
'n_text_ctx': config.max_target_positions,
'n_text_state': config.d_model,
'n_text_head': config.decoder_attention_heads,
'n_text_layer': config.decoder_layers
}
state_dict = deepcopy(transformer_model.model.state_dict())
state_dict = rename_keys(state_dict)
torch.save({"dims": dims, "model_state_dict": state_dict}, whisper_state_path)
# Ported from models/convert-whisper-to-coreml.py
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-name", type=str, help="name of model to convert (e.g. tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large, large-v1)", required=True)
parser.add_argument("--model-path", type=str, help="path to the model (e.g. if published on HuggingFace: Oblivion208/whisper-tiny-cantonese)", required=True)
parser.add_argument("--encoder-only", type=bool, help="only convert encoder", default=False)
parser.add_argument("--quantize", type=bool, help="quantize weights to F16", default=False)
parser.add_argument("--optimize-ane", type=bool, help="optimize for ANE execution (currently broken)", default=False)
args = parser.parse_args()
if args.model_name not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large", "large-v1"]:
raise ValueError("Invalid model name")
pt_target_path = f"models/hf-{args.model_name}.pt"
convert_hf_whisper(args.model_path, pt_target_path)
whisper = load_model(pt_target_path).cpu()
hparams = whisper.dims
print(hparams)
if args.optimize_ane:
whisperANE = whisper_to_coreml.WhisperANE(hparams).eval()
whisperANE.load_state_dict(whisper.state_dict())
encoder = whisperANE.encoder
decoder = whisperANE.decoder
else:
encoder = whisper.encoder
decoder = whisper.decoder
# Convert encoder
encoder = whisper_to_coreml.convert_encoder(hparams, encoder, quantize=args.quantize)
encoder.save(f"models/coreml-encoder-{args.model_name}.mlpackage")
if args.encoder_only is False:
# Convert decoder
decoder = whisper_to_coreml.convert_decoder(hparams, decoder, quantize=args.quantize)
decoder.save(f"models/coreml-decoder-{args.model_name}.mlpackage")
print("done converting")

View File

@ -29,7 +29,7 @@ def convert_encoder(hparams, encoder, mname):
# use model optimizer to convert onnx to OpenVINO IR format
encoder_model = mo.convert_model(onnx_path, compress_to_fp16=True)
serialize(encoder_model, xml_path=os.path.join(os.path.dirname(__file__),"ggml-" + mname + "-encoder-openvino.xml"))
serialize(encoder_model, xml_path='ggml-' + mname + '-encoder-openvino.xml')
#cleanup
if os.path.isdir(onnx_folder):

View File

@ -40,7 +40,7 @@ if exist "ggml-%model%.bin" (
goto :eof
)
PowerShell -NoProfile -ExecutionPolicy Bypass -Command "Start-BitsTransfer -Source https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-%model%.bin -Destination ggml-%model%.bin"
PowerShell -NoProfile -ExecutionPolicy Bypass -Command "Invoke-WebRequest -Uri https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-%model%.bin -OutFile ggml-%model%.bin"
if %ERRORLEVEL% neq 0 (
echo Failed to download ggml model %model%

View File

@ -1,15 +1,11 @@
#!/bin/bash
# Usage: ./generate-coreml-model.sh <model-name>
if [ $# -eq 0 ]; then
echo "No model name supplied"
echo "Usage for Whisper models: ./generate-coreml-model.sh <model-name>"
echo "Usage for HuggingFace models: ./generate-coreml-model.sh -h5 <model-name> <model-path>"
exit 1
elif [[ "$1" == "-h5" && $# != 3 ]]; then
echo "No model name and model path supplied for a HuggingFace model"
echo "Usage for HuggingFace models: ./generate-coreml-model.sh -h5 <model-name> <model-path>"
exit 1
if [ $# -eq 0 ]
then
echo "No model name supplied"
echo "Usage: ./generate-coreml-model.sh <model-name>"
exit 1
fi
mname="$1"
@ -17,14 +13,7 @@ mname="$1"
wd=$(dirname "$0")
cd "$wd/../"
if [[ $mname == "-h5" ]]; then
mname="$2"
mpath="$3"
echo $mpath
python3 models/convert-h5-to-coreml.py --model-name $mname --model-path $mpath --encoder-only True
else
python3 models/convert-whisper-to-coreml.py --model $mname --encoder-only True
fi
python3 models/convert-whisper-to-coreml.py --model $mname --encoder-only True
xcrun coremlc compile models/coreml-encoder-${mname}.mlpackage models/
rm -rf models/ggml-${mname}-encoder.mlmodelc

View File

@ -125,17 +125,9 @@ static void byteswap_tensor(ggml_tensor * tensor) {
// ggml helpers
//
static void ggml_graph_compute_helper(
std::vector<uint8_t> & buf,
ggml_cgraph * graph,
int n_threads,
whisper_abort_callback abort_callback,
void * abort_callback_data) {
static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
plan.abort_callback = abort_callback;
plan.abort_callback_data = abort_callback_data;
if (plan.work_size > 0) {
buf.resize(plan.work_size);
plan.work_data = buf.data();
@ -1930,9 +1922,7 @@ static bool whisper_encode_internal(
whisper_context & wctx,
whisper_state & wstate,
const int mel_offset,
const int n_threads,
whisper_abort_callback abort_callback,
void * abort_callback_data) {
const int n_threads) {
const int64_t t_start_us = ggml_time_us();
// conv
@ -1946,7 +1936,7 @@ static bool whisper_encode_internal(
ggml_allocr_alloc_graph(alloc, gf);
if (!whisper_encode_external(wstate)) {
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
}
}
@ -1965,10 +1955,10 @@ static bool whisper_encode_internal(
ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
ggml_metal_graph_compute(wstate.ctx_metal, gf);
} else {
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
}
#else
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
#endif
}
@ -1987,10 +1977,10 @@ static bool whisper_encode_internal(
ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
ggml_metal_graph_compute(wstate.ctx_metal, gf);
} else {
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
}
#else
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
#endif
}
@ -2356,9 +2346,7 @@ static bool whisper_decode_internal(
const whisper_token * tokens,
const int n_tokens,
const int n_past,
const int n_threads,
whisper_abort_callback abort_callback,
void * abort_callback_data) {
const int n_threads) {
const int64_t t_start_us = ggml_time_us();
const auto & model = wctx.model;
@ -2387,10 +2375,10 @@ static bool whisper_decode_internal(
ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
ggml_metal_graph_compute(wstate.ctx_metal, gf);
} else {
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
}
#else
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
#endif
}
@ -3302,7 +3290,7 @@ int whisper_set_mel(
}
int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) {
if (!whisper_encode_internal(*ctx, *state, offset, n_threads, nullptr, nullptr)) {
if (!whisper_encode_internal(*ctx, *state, offset, n_threads)) {
log("%s: failed to eval\n", __func__);
return -1;
}
@ -3311,7 +3299,7 @@ int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state
}
int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads, nullptr, nullptr)) {
if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads)) {
log("%s: failed to eval\n", __func__);
return -1;
}
@ -3322,7 +3310,7 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
const int selected_decoder_id = 0;
if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
log("%s: failed to eval\n", __func__);
return 1;
}
@ -3339,7 +3327,7 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
return false;
}
if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
log("%s: failed to eval\n", __func__);
return 1;
}
@ -3681,7 +3669,6 @@ const char * whisper_print_system_info(void) {
s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | ";
s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | ";
s += "METAL = " + std::to_string(ggml_cpu_has_metal()) + " | ";
s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | ";
s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
@ -3773,9 +3760,6 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.encoder_begin_callback =*/ nullptr,
/*.encoder_begin_callback_user_data =*/ nullptr,
/*.abort_callback =*/ nullptr,
/*.abort_callback_user_data =*/ nullptr,
/*.logits_filter_callback =*/ nullptr,
/*.logits_filter_callback_user_data =*/ nullptr,
};
@ -3942,7 +3926,6 @@ static void whisper_process_logits(
// suppress task tokens
logits[vocab.token_translate] = -INFINITY;
logits[vocab.token_transcribe] = -INFINITY;
logits[vocab.token_prev] = -INFINITY;
if (params.logits_filter_callback) {
params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
@ -4536,7 +4519,7 @@ int whisper_full_with_state(
// initial prompt
if (!params.prompt_tokens && params.initial_prompt) {
prompt_tokens.resize(2048);
prompt_tokens.resize(1024);
prompt_tokens.resize(whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size()));
params.prompt_tokens = prompt_tokens.data();
params.prompt_n_tokens = prompt_tokens.size();
@ -4561,7 +4544,6 @@ int whisper_full_with_state(
// these tokens determine the task that will be performed
std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
if (whisper_is_multilingual(ctx)) {
const int lang_id = whisper_lang_id(params.language);
state->lang_id = lang_id;
@ -4573,17 +4555,6 @@ int whisper_full_with_state(
}
}
{
const bool is_distil = ctx->model.hparams.n_text_layer == 2;
// distilled models require the "no_timestamps" token
// TODO: add input parameter (#1229)
if (is_distil) {
log("%s: using distilled model - forcing no_timestamps\n", __func__);
prompt_init.push_back(whisper_token_not(ctx));
}
}
int seek = seek_start;
std::vector<whisper_token> prompt;
@ -4622,7 +4593,7 @@ int whisper_full_with_state(
}
// encode audio features starting at offset seek
if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads)) {
log("%s: failed to encode\n", __func__);
return -6;
}
@ -4705,7 +4676,7 @@ int whisper_full_with_state(
}
WHISPER_PRINT_DEBUG("\n\n");
if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
log("%s: failed to decode\n", __func__);
return -7;
}
@ -4929,7 +4900,7 @@ int whisper_full_with_state(
//WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) {
log("%s: failed to decode\n", __func__);
return -8;
}
@ -5298,10 +5269,6 @@ int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment)
return ctx->state->result_all[i_segment].t1;
}
bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment) {
return state->result_all[i_segment].speaker_turn_next;
}
bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment) {
return ctx->state->result_all[i_segment].speaker_turn_next;
}
@ -5501,12 +5468,12 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
double tsum = 0.0;
// heat-up
ggml_graph_compute_helper(work, &gf, n_threads, nullptr , nullptr);
ggml_graph_compute_helper(work, &gf, n_threads);
for (int i = 0; i < n_max; ++i) {
const int64_t t0 = ggml_time_us();
ggml_graph_compute_helper(work, &gf, n_threads, nullptr, nullptr);
ggml_graph_compute_helper(work, &gf, n_threads);
const int64_t t1 = ggml_time_us();

View File

@ -334,11 +334,6 @@ extern "C" {
// If it returns false, the computation is aborted
typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, struct whisper_state * state, void * user_data);
// Abort callback
// If not NULL, called before ggml computation
// If it returns true, the computation is aborted
typedef bool (*whisper_abort_callback)(void * user_data);
// Logits filter callback
// Can be used to modify the logits before sampling
// If not NULL, called after applying temperature to logits
@ -433,10 +428,6 @@ extern "C" {
whisper_encoder_begin_callback encoder_begin_callback;
void * encoder_begin_callback_user_data;
// called each time before ggml computation starts
whisper_abort_callback abort_callback;
void * abort_callback_user_data;
// called by each decoder to filter obtained logits
whisper_logits_filter_callback logits_filter_callback;
void * logits_filter_callback_user_data;
@ -494,7 +485,6 @@ extern "C" {
// Get whether the next segment is predicted as a speaker turn
WHISPER_API bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment);
WHISPER_API bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment);
// Get the text of the specified segment
WHISPER_API const char * whisper_full_get_segment_text (struct whisper_context * ctx, int i_segment);