mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-08-10 22:15:09 +02:00
Compare commits
44 Commits
distil-sup
...
metal-and-
Author | SHA1 | Date | |
---|---|---|---|
3ac0558009 | |||
a1664574fe | |||
bfcb2a2ab9 | |||
0d5e4cdc36 | |||
2b4160af29 | |||
f36554382a | |||
c46167f8c5 | |||
af947cb72e | |||
e81c67a125 | |||
f408c64564 | |||
d863f725a1 | |||
d37f56e7a9 | |||
23277d21ce | |||
ecb23fb1eb | |||
8e8daa8451 | |||
16db4da3f1 | |||
257d7942af | |||
181bb8cb28 | |||
796f84cd95 | |||
77f4bf49c8 | |||
b6f09669a2 | |||
254b687239 | |||
b19888cfb4 | |||
905c944143 | |||
3074a7ff14 | |||
ec9a7db74c | |||
cd476375b4 | |||
9fdd415367 | |||
79a88057bd | |||
fbc9ddc582 | |||
3b9979a373 | |||
de94c783ee | |||
d3b2dd4955 | |||
4845b9ed09 | |||
2770d46ef5 | |||
4d9acc60c3 | |||
06d1d2836b | |||
9a78b72246 | |||
794e8fe0ea | |||
fa672b46e6 | |||
af6f67b251 | |||
bed5ad69dd | |||
949ab6328d | |||
fbc3f8033e |
@ -1,28 +0,0 @@
|
||||
ARG UBUNTU_VERSION=22.04
|
||||
|
||||
# This needs to generally match the container host's environment.
|
||||
ARG CUDA_VERSION=11.7.1
|
||||
|
||||
# Target the CUDA build image
|
||||
ARG BASE_CUDA_DEV_CONTAINER=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION}
|
||||
|
||||
FROM ${BASE_CUDA_DEV_CONTAINER} as build
|
||||
|
||||
# Unless otherwise specified, we make a fat build.
|
||||
ARG CUDA_DOCKER_ARCH=all
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y build-essential git cmake
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY . .
|
||||
|
||||
# Set nvcc architecture
|
||||
ENV CUDA_DOCKER_ARCH=${CUDA_DOCKER_ARCH}
|
||||
# Enable cuBLAS
|
||||
ENV WHISPER_CUBLAS=1
|
||||
|
||||
RUN make
|
||||
|
||||
ENTRYPOINT ["/app/main"]
|
2
.gitignore
vendored
2
.gitignore
vendored
@ -46,5 +46,3 @@ models/*.mlpackage
|
||||
bindings/java/.gradle/
|
||||
bindings/java/.idea/
|
||||
.idea/
|
||||
|
||||
benchmark_results.csv
|
||||
|
@ -117,7 +117,7 @@ if (APPLE)
|
||||
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ${ACCELERATE_FRAMEWORK})
|
||||
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_USE_ACCELERATE)
|
||||
else()
|
||||
message(FATAL_ERROR "Accelerate framework not found")
|
||||
message(WARNING "Accelerate framework not found")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@ -140,7 +140,7 @@ if (APPLE)
|
||||
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_METAL_NDEBUG)
|
||||
endif()
|
||||
else()
|
||||
message(FATAL_ERROR "Metal framework not found")
|
||||
message(WARNING "Metal framework not found")
|
||||
endif()
|
||||
|
||||
set(GGML_SOURCES_METAL ggml-metal.m ggml-metal.h)
|
||||
@ -158,7 +158,7 @@ if (APPLE)
|
||||
|
||||
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DWHISPER_USE_COREML)
|
||||
else()
|
||||
message(FATAL_ERROR "CoreML framework not found")
|
||||
message(WARNING "CoreML framework not found")
|
||||
endif()
|
||||
|
||||
if (WHISPER_COREML_ALLOW_FALLBACK)
|
||||
@ -181,13 +181,13 @@ if (WHISPER_BLAS)
|
||||
include_directories($ENV{OPENBLAS_PATH}/include)
|
||||
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ${BLAS_LIBRARIES})
|
||||
else ()
|
||||
message(FATAL_ERROR "BLAS library was not found. Environment variable OPENBLAS_PATH not defined.")
|
||||
message(WARNING "BLAS library was not found. Environment variable OPENBLAS_PATH not defined.")
|
||||
endif ()
|
||||
else ()
|
||||
set(BLA_STATIC 1)
|
||||
set(BLA_VENDOR ${WHISPER_BLAS_VENDOR})
|
||||
# set(BLA_PREFER_PKGCONFIG 1)
|
||||
set(BLA_SIZEOF_INTEGER 8)
|
||||
set(BLA_PREFER_PKGCONFIG 1)
|
||||
find_package(BLAS)
|
||||
|
||||
if(BLAS_FOUND)
|
||||
@ -198,7 +198,7 @@ if (WHISPER_BLAS)
|
||||
include_directories(${BLAS_INCLUDE_DIRS})
|
||||
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ${BLAS_LIBRARIES})
|
||||
else()
|
||||
message(FATAL_ERROR "BLAS library was not found")
|
||||
message(WARNING "BLAS library was not found")
|
||||
endif()
|
||||
endif ()
|
||||
endif ()
|
||||
@ -224,7 +224,7 @@ if (WHISPER_CUBLAS)
|
||||
endif()
|
||||
|
||||
else()
|
||||
message(FATAL_ERROR "cuBLAS not found")
|
||||
message(WARNING "cuBLAS not found")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@ -255,7 +255,7 @@ if (WHISPER_HIPBLAS)
|
||||
endif()
|
||||
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ggml-rocm)
|
||||
else()
|
||||
message(FATAL_ERROR "hipBLAS or HIP not found. Try setting CMAKE_PREFIX_PATH=/opt/rocm")
|
||||
message(WARNING "hipBLAS or HIP not found. Try setting CMAKE_PREFIX_PATH=/opt/rocm")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@ -270,7 +270,7 @@ if (WHISPER_CLBLAST)
|
||||
|
||||
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} clblast)
|
||||
else()
|
||||
message(FATAL_ERROR "CLBlast not found")
|
||||
message(WARNING "CLBlast not found")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
1
Makefile
1
Makefile
@ -186,7 +186,6 @@ ifndef WHISPER_NO_METAL
|
||||
ifeq ($(UNAME_S),Darwin)
|
||||
WHISPER_METAL := 1
|
||||
|
||||
CFLAGS += -DGGML_USE_METAL
|
||||
CXXFLAGS += -DGGML_USE_METAL
|
||||
LDFLAGS += -framework Foundation -framework Metal -framework MetalKit
|
||||
endif
|
||||
|
32
README.md
32
README.md
@ -50,7 +50,7 @@ You can also easily make your own offline voice assistant application: [command]
|
||||
|
||||
https://user-images.githubusercontent.com/1991296/204038393-2f846eae-c255-4099-a76d-5735c25c49da.mp4
|
||||
|
||||
On Apple Silicon, the inference runs fully on the GPU via Metal:
|
||||
On Apply Silicon, the inference runs fully on the GPU via Metal:
|
||||
|
||||
https://github.com/ggerganov/whisper.cpp/assets/1991296/c82e8f86-60dc-49f2-b048-d2fdbd6b5225
|
||||
|
||||
@ -113,37 +113,30 @@ options:
|
||||
-d N, --duration N [0 ] duration of audio to process in milliseconds
|
||||
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
|
||||
-ml N, --max-len N [0 ] maximum segment length in characters
|
||||
-sow, --split-on-word [false ] split on word rather than on token
|
||||
-bo N, --best-of N [2 ] number of best candidates to keep
|
||||
-bo N, --best-of N [5 ] number of best candidates to keep
|
||||
-bs N, --beam-size N [-1 ] beam size for beam search
|
||||
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
|
||||
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
|
||||
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
|
||||
-debug, --debug-mode [false ] enable debug mode (eg. dump log_mel)
|
||||
-su, --speed-up [false ] speed up audio by x2 (reduced accuracy)
|
||||
-tr, --translate [false ] translate from source language to english
|
||||
-di, --diarize [false ] stereo audio diarization
|
||||
-tdrz, --tinydiarize [false ] enable tinydiarize (requires a tdrz model)
|
||||
-di, --diarize [false ] stereo audio diarization
|
||||
-nf, --no-fallback [false ] do not use temperature fallback while decoding
|
||||
-otxt, --output-txt [false ] output result in a text file
|
||||
-ovtt, --output-vtt [false ] output result in a vtt file
|
||||
-osrt, --output-srt [false ] output result in a srt file
|
||||
-olrc, --output-lrc [false ] output result in a lrc file
|
||||
-owts, --output-words [false ] output script for generating karaoke video
|
||||
-fp, --font-path [/System/Library/Fonts/Supplemental/Courier New Bold.ttf] path to a monospace font for karaoke video
|
||||
-ocsv, --output-csv [false ] output result in a CSV file
|
||||
-oj, --output-json [false ] output result in a JSON file
|
||||
-of FNAME, --output-file FNAME [ ] output file path (without file extension)
|
||||
-ps, --print-special [false ] print special tokens
|
||||
-pc, --print-colors [false ] print colors
|
||||
-pp, --print-progress [false ] print progress
|
||||
-nt, --no-timestamps [false ] do not print timestamps
|
||||
-nt, --no-timestamps [true ] do not print timestamps
|
||||
-l LANG, --language LANG [en ] spoken language ('auto' for auto-detect)
|
||||
-dl, --detect-language [false ] exit after automatically detecting language
|
||||
--prompt PROMPT [ ] initial prompt
|
||||
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
|
||||
-f FNAME, --file FNAME [ ] input WAV file path
|
||||
-oved D, --ov-e-device DNAME [CPU ] the OpenVINO device used for encode inference
|
||||
-ls, --log-score [false ] log best decoder scores of token
|
||||
|
||||
|
||||
bash ./models/download-ggml-model.sh base.en
|
||||
@ -709,19 +702,6 @@ took to execute it. The results are summarized in the following Github issue:
|
||||
|
||||
[Benchmark results](https://github.com/ggerganov/whisper.cpp/issues/89)
|
||||
|
||||
Additionally a script to run whisper.cpp with different models and audio files is provided [bench.py](bench.py).
|
||||
|
||||
You can run it with the following command, by default it will run against any standard model in the models folder.
|
||||
|
||||
```bash
|
||||
python3 extra/bench.py -f samples/jfk.wav -t 2,4,8 -p 1,2
|
||||
```
|
||||
|
||||
It is written in python with the intention of being easy to modify and extend for your benchmarking use case.
|
||||
|
||||
It outputs a csv file with the results of the benchmarking.
|
||||
|
||||
|
||||
## ggml format
|
||||
|
||||
The original models are converted to a custom binary format. This allows to pack everything needed into a single file:
|
||||
@ -743,7 +723,7 @@ in [models](models).
|
||||
## [Bindings](https://github.com/ggerganov/whisper.cpp/discussions/categories/bindings)
|
||||
|
||||
- [X] Rust: [tazz4843/whisper-rs](https://github.com/tazz4843/whisper-rs) | [#310](https://github.com/ggerganov/whisper.cpp/discussions/310)
|
||||
- [X] JavaScript: [bindings/javascript](bindings/javascript) | [#309](https://github.com/ggerganov/whisper.cpp/discussions/309)
|
||||
- [X] Javascript: [bindings/javascript](bindings/javascript) | [#309](https://github.com/ggerganov/whisper.cpp/discussions/309)
|
||||
- React Native (iOS / Android): [whisper.rn](https://github.com/mybigday/whisper.rn)
|
||||
- [X] Go: [bindings/go](bindings/go) | [#312](https://github.com/ggerganov/whisper.cpp/discussions/312)
|
||||
- [X] Java:
|
||||
|
@ -118,11 +118,6 @@ func (p *Params) SetMaxTokensPerSegment(n int) {
|
||||
p.max_tokens = C.int(n)
|
||||
}
|
||||
|
||||
// Set audio encoder context
|
||||
func (p *Params) SetAudioCtx(n int) {
|
||||
p.audio_ctx = C.int(n)
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PRIVATE METHODS
|
||||
|
||||
@ -146,7 +141,6 @@ func (p *Params) String() string {
|
||||
str += fmt.Sprintf(" n_max_text_ctx=%d", p.n_max_text_ctx)
|
||||
str += fmt.Sprintf(" offset_ms=%d", p.offset_ms)
|
||||
str += fmt.Sprintf(" duration_ms=%d", p.duration_ms)
|
||||
str += fmt.Sprintf(" audio_ctx=%d", p.audio_ctx)
|
||||
if p.translate {
|
||||
str += " translate"
|
||||
}
|
||||
|
@ -82,7 +82,7 @@ func (context *context) SetSpeedup(v bool) {
|
||||
}
|
||||
|
||||
func (context *context) SetSplitOnWord(v bool) {
|
||||
context.params.SetSplitOnWord(v)
|
||||
context.params.SetSplitOnWord(v)
|
||||
}
|
||||
|
||||
// Set number of threads to use
|
||||
@ -125,11 +125,6 @@ func (context *context) SetMaxTokensPerSegment(n uint) {
|
||||
context.params.SetMaxTokensPerSegment(int(n))
|
||||
}
|
||||
|
||||
// Set audio encoder context
|
||||
func (context *context) SetAudioCtx(n uint) {
|
||||
context.params.SetAudioCtx(int(n))
|
||||
}
|
||||
|
||||
// ResetTimings resets the mode timings. Should be called before processing
|
||||
func (context *context) ResetTimings() {
|
||||
context.model.ctx.Whisper_reset_timings()
|
||||
|
@ -48,7 +48,6 @@ type Context interface {
|
||||
SetMaxSegmentLength(uint) // Set max segment length in characters
|
||||
SetTokenTimestamps(bool) // Set token timestamps flag
|
||||
SetMaxTokensPerSegment(uint) // Set max tokens per segment (0 = no limit)
|
||||
SetAudioCtx(uint) // Set audio encoder context
|
||||
|
||||
// Process mono audio data and return any errors.
|
||||
// If defined, newly generated segments are passed to the
|
||||
|
2
bindings/ruby/ext/.gitignore
vendored
2
bindings/ruby/ext/.gitignore
vendored
@ -1,8 +1,6 @@
|
||||
Makefile
|
||||
ggml.c
|
||||
ggml.h
|
||||
ggml-alloc.c
|
||||
ggml-alloc.h
|
||||
whisper.bundle
|
||||
whisper.cpp
|
||||
whisper.h
|
||||
|
@ -3,8 +3,6 @@ system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.cpp')} .")
|
||||
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.h')} .")
|
||||
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.h')} .")
|
||||
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.c')} .")
|
||||
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml-alloc.h')} .")
|
||||
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml-alloc.c')} .")
|
||||
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','examples','dr_wav.h')} .")
|
||||
|
||||
|
||||
|
@ -1,7 +1,6 @@
|
||||
#include "whisper.h"
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
|
||||
|
@ -7,8 +7,6 @@
|
||||
#include <vector>
|
||||
#include <random>
|
||||
#include <thread>
|
||||
#include <ctime>
|
||||
#include <fstream>
|
||||
|
||||
#define COMMON_SAMPLE_RATE 16000
|
||||
|
||||
@ -141,104 +139,6 @@ bool read_wav(
|
||||
std::vector<std::vector<float>> & pcmf32s,
|
||||
bool stereo);
|
||||
|
||||
// Write PCM data into WAV audio file
|
||||
class wav_writer {
|
||||
private:
|
||||
std::ofstream file;
|
||||
uint32_t dataSize = 0;
|
||||
std::string wav_filename;
|
||||
|
||||
bool write_header(const uint32_t sample_rate,
|
||||
const uint16_t bits_per_sample,
|
||||
const uint16_t channels) {
|
||||
|
||||
file.write("RIFF", 4);
|
||||
file.write("\0\0\0\0", 4); // Placeholder for file size
|
||||
file.write("WAVE", 4);
|
||||
file.write("fmt ", 4);
|
||||
|
||||
const uint32_t sub_chunk_size = 16;
|
||||
const uint16_t audio_format = 1; // PCM format
|
||||
const uint32_t byte_rate = sample_rate * channels * bits_per_sample / 8;
|
||||
const uint16_t block_align = channels * bits_per_sample / 8;
|
||||
|
||||
file.write(reinterpret_cast<const char *>(&sub_chunk_size), 4);
|
||||
file.write(reinterpret_cast<const char *>(&audio_format), 2);
|
||||
file.write(reinterpret_cast<const char *>(&channels), 2);
|
||||
file.write(reinterpret_cast<const char *>(&sample_rate), 4);
|
||||
file.write(reinterpret_cast<const char *>(&byte_rate), 4);
|
||||
file.write(reinterpret_cast<const char *>(&block_align), 2);
|
||||
file.write(reinterpret_cast<const char *>(&bits_per_sample), 2);
|
||||
file.write("data", 4);
|
||||
file.write("\0\0\0\0", 4); // Placeholder for data size
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// It is assumed that PCM data is normalized to a range from -1 to 1
|
||||
bool write_audio(const float * data, size_t length) {
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
const auto intSample = static_cast<const int16_t>(data[i] * 32767);
|
||||
file.write(reinterpret_cast<const char *>(&intSample), sizeof(int16_t));
|
||||
dataSize += sizeof(int16_t);
|
||||
}
|
||||
if (file.is_open()) {
|
||||
file.seekp(4, std::ios::beg);
|
||||
uint32_t fileSize = 36 + dataSize;
|
||||
file.write(reinterpret_cast<char *>(&fileSize), 4);
|
||||
file.seekp(40, std::ios::beg);
|
||||
file.write(reinterpret_cast<char *>(&dataSize), 4);
|
||||
file.seekp(0, std::ios::end);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool open_wav(const std::string & filename) {
|
||||
if (filename != wav_filename) {
|
||||
if (file.is_open()) {
|
||||
file.close();
|
||||
}
|
||||
}
|
||||
if (!file.is_open()) {
|
||||
file.open(filename, std::ios::binary);
|
||||
wav_filename = filename;
|
||||
dataSize = 0;
|
||||
}
|
||||
return file.is_open();
|
||||
}
|
||||
|
||||
public:
|
||||
bool open(const std::string & filename,
|
||||
const uint32_t sample_rate,
|
||||
const uint16_t bits_per_sample,
|
||||
const uint16_t channels) {
|
||||
|
||||
if (open_wav(filename)) {
|
||||
write_header(sample_rate, bits_per_sample, channels);
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool close() {
|
||||
file.close();
|
||||
return true;
|
||||
}
|
||||
|
||||
bool write(const float * data, size_t length) {
|
||||
return write_audio(data, length);
|
||||
}
|
||||
|
||||
~wav_writer() {
|
||||
if (file.is_open()) {
|
||||
file.close();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
// Apply a high-pass frequency filter to PCM audio
|
||||
// Suppresses frequencies below cutoff Hz
|
||||
void high_pass_filter(
|
||||
|
@ -83,7 +83,6 @@ struct whisper_params {
|
||||
bool output_wts = false;
|
||||
bool output_csv = false;
|
||||
bool output_jsn = false;
|
||||
bool output_jsn_full = false;
|
||||
bool output_lrc = false;
|
||||
bool print_special = false;
|
||||
bool print_colors = false;
|
||||
@ -152,7 +151,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; }
|
||||
else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
|
||||
else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; }
|
||||
else if (arg == "-ojf" || arg == "--output-json-full"){ params.output_jsn_full = params.output_jsn = true; }
|
||||
else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(argv[++i]); }
|
||||
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
||||
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
|
||||
@ -208,7 +206,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||
fprintf(stderr, " -fp, --font-path [%-7s] path to a monospace font for karaoke video\n", params.font_path.c_str());
|
||||
fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false");
|
||||
fprintf(stderr, " -oj, --output-json [%-7s] output result in a JSON file\n", params.output_jsn ? "true" : "false");
|
||||
fprintf(stderr, " -ojf, --output-json-full [%-7s] include more information in the JSON file\n", params.output_jsn_full ? "true" : "false");
|
||||
fprintf(stderr, " -of FNAME, --output-file FNAME [%-7s] output file path (without file extension)\n", "");
|
||||
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
||||
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
|
||||
@ -514,12 +511,7 @@ bool output_score(struct whisper_context * ctx, const char * fname, const whispe
|
||||
return true;
|
||||
}
|
||||
|
||||
bool output_json(
|
||||
struct whisper_context * ctx,
|
||||
const char * fname,
|
||||
const whisper_params & params,
|
||||
std::vector<std::vector<float>> pcmf32s,
|
||||
bool full) {
|
||||
bool output_json(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
|
||||
std::ofstream fout(fname);
|
||||
int indent = 0;
|
||||
|
||||
@ -536,7 +528,7 @@ bool output_json(
|
||||
auto end_arr = [&](bool end) {
|
||||
indent--;
|
||||
doindent();
|
||||
fout << (end ? "]\n" : "],\n");
|
||||
fout << (end ? "]\n" : "},\n");
|
||||
};
|
||||
|
||||
auto start_obj = [&](const char *name) {
|
||||
@ -577,29 +569,12 @@ bool output_json(
|
||||
end_value(end);
|
||||
};
|
||||
|
||||
auto value_f = [&](const char *name, const float val, bool end) {
|
||||
start_value(name);
|
||||
fout << val;
|
||||
end_value(end);
|
||||
};
|
||||
|
||||
auto value_b = [&](const char *name, const bool val, bool end) {
|
||||
start_value(name);
|
||||
fout << (val ? "true" : "false");
|
||||
end_value(end);
|
||||
};
|
||||
|
||||
auto times_o = [&](int64_t t0, int64_t t1, bool end) {
|
||||
start_obj("timestamps");
|
||||
value_s("from", to_timestamp(t0, true).c_str(), false);
|
||||
value_s("to", to_timestamp(t1, true).c_str(), true);
|
||||
end_obj(false);
|
||||
start_obj("offsets");
|
||||
value_i("from", t0 * 10, false);
|
||||
value_i("to", t1 * 10, true);
|
||||
end_obj(end);
|
||||
};
|
||||
|
||||
if (!fout.is_open()) {
|
||||
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
|
||||
return false;
|
||||
@ -645,26 +620,15 @@ bool output_json(
|
||||
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
||||
|
||||
start_obj(nullptr);
|
||||
times_o(t0, t1, false);
|
||||
value_s("text", text, !params.diarize && !params.tinydiarize && !full);
|
||||
|
||||
if (full) {
|
||||
start_arr("tokens");
|
||||
const int n = whisper_full_n_tokens(ctx, i);
|
||||
for (int j = 0; j < n; ++j) {
|
||||
auto token = whisper_full_get_token_data(ctx, i, j);
|
||||
start_obj(nullptr);
|
||||
value_s("text", whisper_token_to_str(ctx, token.id), false);
|
||||
if(token.t0 > -1 && token.t1 > -1) {
|
||||
// If we have per-token timestamps, write them out
|
||||
times_o(token.t0, token.t1, false);
|
||||
}
|
||||
value_i("id", token.id, false);
|
||||
value_f("p", token.p, true);
|
||||
end_obj(j == (n - 1));
|
||||
}
|
||||
end_arr(!params.diarize && !params.tinydiarize);
|
||||
}
|
||||
start_obj("timestamps");
|
||||
value_s("from", to_timestamp(t0, true).c_str(), false);
|
||||
value_s("to", to_timestamp(t1, true).c_str(), true);
|
||||
end_obj(false);
|
||||
start_obj("offsets");
|
||||
value_i("from", t0 * 10, false);
|
||||
value_i("to", t1 * 10, true);
|
||||
end_obj(false);
|
||||
value_s("text", text, !params.diarize && !params.tinydiarize);
|
||||
|
||||
if (params.diarize && pcmf32s.size() == 2) {
|
||||
value_s("speaker", estimate_diarization_speaker(pcmf32s, t0, t1, true).c_str(), true);
|
||||
@ -948,7 +912,7 @@ int main(int argc, char ** argv) {
|
||||
wparams.offset_ms = params.offset_t_ms;
|
||||
wparams.duration_ms = params.duration_ms;
|
||||
|
||||
wparams.token_timestamps = params.output_wts || params.output_jsn_full || params.max_len > 0;
|
||||
wparams.token_timestamps = params.output_wts || params.max_len > 0;
|
||||
wparams.thold_pt = params.word_thold;
|
||||
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
|
||||
wparams.split_on_word = params.split_on_word;
|
||||
@ -980,9 +944,8 @@ int main(int argc, char ** argv) {
|
||||
wparams.progress_callback_user_data = &user_data;
|
||||
}
|
||||
|
||||
// examples for abort mechanism
|
||||
// in examples below, we do not abort the processing, but we could if the flag is set to true
|
||||
|
||||
// example for abort mechanism
|
||||
// in this example, we do not abort the processing, but we could if the flag is set to true
|
||||
// the callback is called before every encoder run - if it returns false, the processing is aborted
|
||||
{
|
||||
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
|
||||
@ -994,17 +957,6 @@ int main(int argc, char ** argv) {
|
||||
wparams.encoder_begin_callback_user_data = &is_aborted;
|
||||
}
|
||||
|
||||
// the callback is called before every computation - if it returns true, the computation is aborted
|
||||
{
|
||||
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
|
||||
|
||||
wparams.abort_callback = [](void * user_data) {
|
||||
bool is_aborted = *(bool*)user_data;
|
||||
return is_aborted;
|
||||
};
|
||||
wparams.abort_callback_user_data = &is_aborted;
|
||||
}
|
||||
|
||||
if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
|
||||
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
|
||||
return 10;
|
||||
@ -1048,7 +1000,7 @@ int main(int argc, char ** argv) {
|
||||
// output to JSON file
|
||||
if (params.output_jsn) {
|
||||
const auto fname_jsn = fname_out + ".json";
|
||||
output_json(ctx, fname_jsn.c_str(), params, pcmf32s, params.output_jsn_full);
|
||||
output_json(ctx, fname_jsn.c_str(), params, pcmf32s);
|
||||
}
|
||||
|
||||
// output to LRC file
|
||||
|
@ -39,20 +39,6 @@ brew install sdl2
|
||||
make stream
|
||||
```
|
||||
|
||||
Ensure you are at the root of the repo when running `make stream`. Not within the `examples/stream` dir
|
||||
as the libraries needed like `common-sdl.h` are located within `examples`. Attempting to compile within
|
||||
`examples/steam` means your compiler cannot find them and it gives an error it cannot find the file.
|
||||
|
||||
```bash
|
||||
whisper.cpp/examples/stream$ make stream
|
||||
g++ stream.cpp -o stream
|
||||
stream.cpp:6:10: fatal error: common/sdl.h: No such file or directory
|
||||
6 | #include "common/sdl.h"
|
||||
| ^~~~~~~~~~~~~~
|
||||
compilation terminated.
|
||||
make: *** [<builtin>: stream] Error 1
|
||||
```
|
||||
|
||||
## Web version
|
||||
|
||||
This tool can also run in the browser: [examples/stream.wasm](/examples/stream.wasm)
|
||||
|
@ -2,6 +2,7 @@
|
||||
//
|
||||
// A very quick-n-dirty implementation serving mainly as a proof of concept.
|
||||
//
|
||||
|
||||
#include "common-sdl.h"
|
||||
#include "common.h"
|
||||
#include "whisper.h"
|
||||
@ -13,7 +14,6 @@
|
||||
#include <vector>
|
||||
#include <fstream>
|
||||
|
||||
|
||||
// 500 -> 00:05.000
|
||||
// 6000 -> 01:00.000
|
||||
std::string to_timestamp(int64_t t) {
|
||||
@ -52,7 +52,6 @@ struct whisper_params {
|
||||
std::string language = "en";
|
||||
std::string model = "models/ggml-base.en.bin";
|
||||
std::string fname_out;
|
||||
bool save_audio = false; // save audio to wav file
|
||||
};
|
||||
|
||||
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
||||
@ -83,7 +82,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
|
||||
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
|
||||
else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; }
|
||||
else if (arg == "-sa" || arg == "--save-audio") { params.save_audio = true; }
|
||||
|
||||
else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
@ -119,7 +117,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
||||
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
|
||||
fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false");
|
||||
fprintf(stderr, " -sa, --save-audio [%-7s] save the recorded audio to a file\n", params.save_audio ? "true" : "false");
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
@ -157,6 +154,7 @@ int main(int argc, char ** argv) {
|
||||
audio.resume();
|
||||
|
||||
// whisper init
|
||||
|
||||
if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1){
|
||||
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
@ -214,28 +212,14 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
}
|
||||
|
||||
wav_writer wavWriter;
|
||||
// save wav file
|
||||
if (params.save_audio) {
|
||||
// Get current date/time for filename
|
||||
time_t now = time(0);
|
||||
char buffer[80];
|
||||
strftime(buffer, sizeof(buffer), "%Y%m%d%H%M%S", localtime(&now));
|
||||
std::string filename = std::string(buffer) + ".wav";
|
||||
|
||||
wavWriter.open(filename, WHISPER_SAMPLE_RATE, 16, 1);
|
||||
}
|
||||
printf("[Start speaking]\n");
|
||||
printf("[Start speaking]");
|
||||
fflush(stdout);
|
||||
|
||||
auto t_last = std::chrono::high_resolution_clock::now();
|
||||
auto t_last = std::chrono::high_resolution_clock::now();
|
||||
const auto t_start = t_last;
|
||||
|
||||
// main audio loop
|
||||
while (is_running) {
|
||||
if (params.save_audio) {
|
||||
wavWriter.write(pcmf32_new.data(), pcmf32_new.size());
|
||||
}
|
||||
// handle Ctrl + C
|
||||
is_running = sdl_poll_events();
|
||||
|
||||
@ -387,7 +371,7 @@ int main(int argc, char ** argv) {
|
||||
fout << std::endl;
|
||||
}
|
||||
|
||||
if (use_vad) {
|
||||
if (use_vad){
|
||||
printf("\n");
|
||||
printf("### Transcription %d END\n", n_iter);
|
||||
}
|
||||
@ -424,4 +408,4 @@ int main(int argc, char ** argv) {
|
||||
whisper_free(ctx);
|
||||
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
@ -2,12 +2,6 @@
|
||||
|
||||
Talk with an LLaMA AI in your terminal
|
||||
|
||||
*Latest perf as of 2 Nov 2023 using Whisper Medium + LLaMA v2 13B Q8_0 on M2 Ultra:*
|
||||
|
||||
https://github.com/ggerganov/whisper.cpp/assets/1991296/d97a3788-bf2a-4756-9a43-60c6b391649e
|
||||
|
||||
*Previous demo running on CPUs*
|
||||
|
||||
[Demo Talk](https://user-images.githubusercontent.com/1991296/228024237-848f998c-c334-46a6-bef8-3271590da83b.mp4)
|
||||
|
||||
## Building
|
||||
@ -25,7 +19,7 @@ brew install sdl2
|
||||
make talk-llama
|
||||
|
||||
# Run it
|
||||
./talk-llama -mw ./models/ggml-small.en.bin -ml ../llama.cpp/models/llama-13b/ggml-model-q4_0.gguf -p "Georgi" -t 8
|
||||
./talk-llama -mw ./models/ggml-small.en.bin -ml ../llama.cpp/models/13B/ggml-model-q4_0.bin -p "Georgi" -t 8
|
||||
```
|
||||
|
||||
- The `-mw` argument specifies the Whisper model that you would like to use. Recommended `base` or `small` for real-time experience
|
||||
@ -42,7 +36,7 @@ This feature is especially helpful for maintaining context in long conversations
|
||||
Example usage:
|
||||
|
||||
```bash
|
||||
./talk-llama --session ./my-session-file -mw ./models/ggml-small.en.bin -ml ../llama.cpp/models/llama-13b/ggml-model-q4_0.gguf -p "Georgi" -t 8
|
||||
./talk-llama --session ./my-session-file -mw ./models/ggml-small.en.bin -ml ../llama.cpp/models/13B/ggml-model-q4_0.bin -p "Georgi" -t 8
|
||||
```
|
||||
|
||||
## TTS
|
||||
|
474
examples/talk-llama/llama-util.h
Normal file
474
examples/talk-llama/llama-util.h
Normal file
@ -0,0 +1,474 @@
|
||||
// Internal header to be included only by llama.cpp.
|
||||
// Contains wrappers around OS interfaces.
|
||||
|
||||
#ifndef LLAMA_UTIL_H
|
||||
#define LLAMA_UTIL_H
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstdint>
|
||||
#include <cerrno>
|
||||
#include <cstring>
|
||||
#include <cstdarg>
|
||||
#include <cstdlib>
|
||||
#include <climits>
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <stdexcept>
|
||||
|
||||
#ifdef __has_include
|
||||
#if __has_include(<unistd.h>)
|
||||
#include <unistd.h>
|
||||
#if defined(_POSIX_MAPPED_FILES)
|
||||
#include <sys/mman.h>
|
||||
#endif
|
||||
#if defined(_POSIX_MEMLOCK_RANGE)
|
||||
#include <sys/resource.h>
|
||||
#endif
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if defined(_WIN32)
|
||||
#define WIN32_LEAN_AND_MEAN
|
||||
#ifndef NOMINMAX
|
||||
#define NOMINMAX
|
||||
#endif
|
||||
#include <windows.h>
|
||||
#include <io.h>
|
||||
#include <stdio.h> // for _fseeki64
|
||||
#endif
|
||||
|
||||
#define LLAMA_ASSERT(x) \
|
||||
do { \
|
||||
if (!(x)) { \
|
||||
fprintf(stderr, "LLAMA_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
|
||||
abort(); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#ifdef __GNUC__
|
||||
#ifdef __MINGW32__
|
||||
__attribute__((format(gnu_printf, 1, 2)))
|
||||
#else
|
||||
__attribute__((format(printf, 1, 2)))
|
||||
#endif
|
||||
#endif
|
||||
static std::string format(const char * fmt, ...) {
|
||||
va_list ap, ap2;
|
||||
va_start(ap, fmt);
|
||||
va_copy(ap2, ap);
|
||||
int size = vsnprintf(NULL, 0, fmt, ap);
|
||||
LLAMA_ASSERT(size >= 0 && size < INT_MAX);
|
||||
std::vector<char> buf(size + 1);
|
||||
int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
|
||||
LLAMA_ASSERT(size2 == size);
|
||||
va_end(ap2);
|
||||
va_end(ap);
|
||||
return std::string(buf.data(), size);
|
||||
}
|
||||
|
||||
struct llama_file {
|
||||
// use FILE * so we don't have to re-open the file to mmap
|
||||
FILE * fp;
|
||||
size_t size;
|
||||
|
||||
llama_file(const char * fname, const char * mode) {
|
||||
fp = std::fopen(fname, mode);
|
||||
if (fp == NULL) {
|
||||
throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno)));
|
||||
}
|
||||
seek(0, SEEK_END);
|
||||
size = tell();
|
||||
seek(0, SEEK_SET);
|
||||
}
|
||||
|
||||
size_t tell() const {
|
||||
#ifdef _WIN32
|
||||
__int64 ret = _ftelli64(fp);
|
||||
#else
|
||||
long ret = std::ftell(fp);
|
||||
#endif
|
||||
LLAMA_ASSERT(ret != -1); // this really shouldn't fail
|
||||
return (size_t) ret;
|
||||
}
|
||||
|
||||
void seek(size_t offset, int whence) {
|
||||
#ifdef _WIN32
|
||||
int ret = _fseeki64(fp, (__int64) offset, whence);
|
||||
#else
|
||||
int ret = std::fseek(fp, (long) offset, whence);
|
||||
#endif
|
||||
LLAMA_ASSERT(ret == 0); // same
|
||||
}
|
||||
|
||||
void read_raw(void * ptr, size_t len) const {
|
||||
if (len == 0) {
|
||||
return;
|
||||
}
|
||||
errno = 0;
|
||||
std::size_t ret = std::fread(ptr, len, 1, fp);
|
||||
if (ferror(fp)) {
|
||||
throw std::runtime_error(format("read error: %s", strerror(errno)));
|
||||
}
|
||||
if (ret != 1) {
|
||||
throw std::runtime_error(std::string("unexpectedly reached end of file"));
|
||||
}
|
||||
}
|
||||
|
||||
std::uint32_t read_u32() {
|
||||
std::uint32_t ret;
|
||||
read_raw(&ret, sizeof(ret));
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::string read_string(std::uint32_t len) {
|
||||
std::vector<char> chars(len);
|
||||
read_raw(chars.data(), len);
|
||||
return std::string(chars.data(), len);
|
||||
}
|
||||
|
||||
void write_raw(const void * ptr, size_t len) const {
|
||||
if (len == 0) {
|
||||
return;
|
||||
}
|
||||
errno = 0;
|
||||
size_t ret = std::fwrite(ptr, len, 1, fp);
|
||||
if (ret != 1) {
|
||||
throw std::runtime_error(format("write error: %s", strerror(errno)));
|
||||
}
|
||||
}
|
||||
|
||||
void write_u32(std::uint32_t val) {
|
||||
write_raw(&val, sizeof(val));
|
||||
}
|
||||
|
||||
~llama_file() {
|
||||
if (fp) {
|
||||
std::fclose(fp);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(_WIN32)
|
||||
static std::string llama_format_win_err(DWORD err) {
|
||||
LPSTR buf;
|
||||
size_t size = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
|
||||
NULL, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&buf, 0, NULL);
|
||||
if (!size) {
|
||||
return "FormatMessageA failed";
|
||||
}
|
||||
std::string ret(buf, size);
|
||||
LocalFree(buf);
|
||||
return ret;
|
||||
}
|
||||
#endif
|
||||
|
||||
struct llama_mmap {
|
||||
void * addr;
|
||||
size_t size;
|
||||
|
||||
llama_mmap(const llama_mmap &) = delete;
|
||||
|
||||
#ifdef _POSIX_MAPPED_FILES
|
||||
static constexpr bool SUPPORTED = true;
|
||||
|
||||
llama_mmap(struct llama_file * file, size_t prefetch = (size_t) -1 /* -1 = max value */) {
|
||||
size = file->size;
|
||||
int fd = fileno(file->fp);
|
||||
int flags = MAP_SHARED;
|
||||
#ifdef __linux__
|
||||
flags |= MAP_POPULATE;
|
||||
#endif
|
||||
addr = mmap(NULL, file->size, PROT_READ, flags, fd, 0);
|
||||
if (addr == MAP_FAILED) {
|
||||
throw std::runtime_error(format("mmap failed: %s", strerror(errno)));
|
||||
}
|
||||
|
||||
if (prefetch > 0) {
|
||||
// Advise the kernel to preload the mapped memory
|
||||
if (posix_madvise(addr, std::min(file->size, prefetch), POSIX_MADV_WILLNEED)) {
|
||||
fprintf(stderr, "warning: posix_madvise(.., POSIX_MADV_WILLNEED) failed: %s\n",
|
||||
strerror(errno));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
~llama_mmap() {
|
||||
munmap(addr, size);
|
||||
}
|
||||
#elif defined(_WIN32)
|
||||
static constexpr bool SUPPORTED = true;
|
||||
|
||||
llama_mmap(struct llama_file * file, bool prefetch = true) {
|
||||
size = file->size;
|
||||
|
||||
HANDLE hFile = (HANDLE) _get_osfhandle(_fileno(file->fp));
|
||||
|
||||
HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL);
|
||||
DWORD error = GetLastError();
|
||||
|
||||
if (hMapping == NULL) {
|
||||
throw std::runtime_error(format("CreateFileMappingA failed: %s", llama_format_win_err(error).c_str()));
|
||||
}
|
||||
|
||||
addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0);
|
||||
error = GetLastError();
|
||||
CloseHandle(hMapping);
|
||||
|
||||
if (addr == NULL) {
|
||||
throw std::runtime_error(format("MapViewOfFile failed: %s", llama_format_win_err(error).c_str()));
|
||||
}
|
||||
|
||||
#if _WIN32_WINNT >= _WIN32_WINNT_WIN8
|
||||
if (prefetch) {
|
||||
// Advise the kernel to preload the mapped memory
|
||||
WIN32_MEMORY_RANGE_ENTRY range;
|
||||
range.VirtualAddress = addr;
|
||||
range.NumberOfBytes = (SIZE_T)size;
|
||||
if (!PrefetchVirtualMemory(GetCurrentProcess(), 1, &range, 0)) {
|
||||
fprintf(stderr, "warning: PrefetchVirtualMemory failed: %s\n",
|
||||
llama_format_win_err(GetLastError()).c_str());
|
||||
}
|
||||
}
|
||||
#else
|
||||
#pragma message("warning: You are building for pre-Windows 8; prefetch not supported")
|
||||
#endif // _WIN32_WINNT >= _WIN32_WINNT_WIN8
|
||||
}
|
||||
|
||||
~llama_mmap() {
|
||||
if (!UnmapViewOfFile(addr)) {
|
||||
fprintf(stderr, "warning: UnmapViewOfFile failed: %s\n",
|
||||
llama_format_win_err(GetLastError()).c_str());
|
||||
}
|
||||
}
|
||||
#else
|
||||
static constexpr bool SUPPORTED = false;
|
||||
|
||||
llama_mmap(struct llama_file *, bool prefetch = true) {
|
||||
(void)prefetch;
|
||||
throw std::runtime_error(std::string("mmap not supported"));
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
// Represents some region of memory being locked using mlock or VirtualLock;
|
||||
// will automatically unlock on destruction.
|
||||
struct llama_mlock {
|
||||
void * addr = NULL;
|
||||
size_t size = 0;
|
||||
bool failed_already = false;
|
||||
|
||||
llama_mlock() {}
|
||||
llama_mlock(const llama_mlock &) = delete;
|
||||
|
||||
~llama_mlock() {
|
||||
if (size) {
|
||||
raw_unlock(addr, size);
|
||||
}
|
||||
}
|
||||
|
||||
void init(void * ptr) {
|
||||
LLAMA_ASSERT(addr == NULL && size == 0);
|
||||
addr = ptr;
|
||||
}
|
||||
|
||||
void grow_to(size_t target_size) {
|
||||
LLAMA_ASSERT(addr);
|
||||
if (failed_already) {
|
||||
return;
|
||||
}
|
||||
size_t granularity = lock_granularity();
|
||||
target_size = (target_size + granularity - 1) & ~(granularity - 1);
|
||||
if (target_size > size) {
|
||||
if (raw_lock((uint8_t *) addr + size, target_size - size)) {
|
||||
size = target_size;
|
||||
} else {
|
||||
failed_already = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef _POSIX_MEMLOCK_RANGE
|
||||
static constexpr bool SUPPORTED = true;
|
||||
|
||||
size_t lock_granularity() {
|
||||
return (size_t) sysconf(_SC_PAGESIZE);
|
||||
}
|
||||
|
||||
#ifdef __APPLE__
|
||||
#define MLOCK_SUGGESTION \
|
||||
"Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or " \
|
||||
"decreasing 'vm.global_no_user_wire_amount'. Also try increasing RLIMIT_MLOCK (ulimit -l).\n"
|
||||
#else
|
||||
#define MLOCK_SUGGESTION \
|
||||
"Try increasing RLIMIT_MLOCK ('ulimit -l' as root).\n"
|
||||
#endif
|
||||
|
||||
bool raw_lock(const void * addr, size_t size) {
|
||||
if (!mlock(addr, size)) {
|
||||
return true;
|
||||
} else {
|
||||
char* errmsg = std::strerror(errno);
|
||||
bool suggest = (errno == ENOMEM);
|
||||
|
||||
// Check if the resource limit is fine after all
|
||||
struct rlimit lock_limit;
|
||||
if (suggest && getrlimit(RLIMIT_MEMLOCK, &lock_limit))
|
||||
suggest = false;
|
||||
if (suggest && (lock_limit.rlim_max > lock_limit.rlim_cur + size))
|
||||
suggest = false;
|
||||
|
||||
fprintf(stderr, "warning: failed to mlock %zu-byte buffer (after previously locking %zu bytes): %s\n%s",
|
||||
size, this->size, errmsg, suggest ? MLOCK_SUGGESTION : "");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
#undef MLOCK_SUGGESTION
|
||||
|
||||
void raw_unlock(void * addr, size_t size) {
|
||||
if (munlock(addr, size)) {
|
||||
fprintf(stderr, "warning: failed to munlock buffer: %s\n", std::strerror(errno));
|
||||
}
|
||||
}
|
||||
#elif defined(_WIN32)
|
||||
static constexpr bool SUPPORTED = true;
|
||||
|
||||
size_t lock_granularity() {
|
||||
SYSTEM_INFO si;
|
||||
GetSystemInfo(&si);
|
||||
return (size_t) si.dwPageSize;
|
||||
}
|
||||
|
||||
bool raw_lock(void * ptr, size_t len) {
|
||||
for (int tries = 1; ; tries++) {
|
||||
if (VirtualLock(ptr, len)) {
|
||||
return true;
|
||||
}
|
||||
if (tries == 2) {
|
||||
fprintf(stderr, "warning: failed to VirtualLock %zu-byte buffer (after previously locking %zu bytes): %s\n",
|
||||
len, size, llama_format_win_err(GetLastError()).c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
// It failed but this was only the first try; increase the working
|
||||
// set size and try again.
|
||||
SIZE_T min_ws_size, max_ws_size;
|
||||
if (!GetProcessWorkingSetSize(GetCurrentProcess(), &min_ws_size, &max_ws_size)) {
|
||||
fprintf(stderr, "warning: GetProcessWorkingSetSize failed: %s\n",
|
||||
llama_format_win_err(GetLastError()).c_str());
|
||||
return false;
|
||||
}
|
||||
// Per MSDN: "The maximum number of pages that a process can lock
|
||||
// is equal to the number of pages in its minimum working set minus
|
||||
// a small overhead."
|
||||
// Hopefully a megabyte is enough overhead:
|
||||
size_t increment = len + 1048576;
|
||||
// The minimum must be <= the maximum, so we need to increase both:
|
||||
min_ws_size += increment;
|
||||
max_ws_size += increment;
|
||||
if (!SetProcessWorkingSetSize(GetCurrentProcess(), min_ws_size, max_ws_size)) {
|
||||
fprintf(stderr, "warning: SetProcessWorkingSetSize failed: %s\n",
|
||||
llama_format_win_err(GetLastError()).c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void raw_unlock(void * ptr, size_t len) {
|
||||
if (!VirtualUnlock(ptr, len)) {
|
||||
fprintf(stderr, "warning: failed to VirtualUnlock buffer: %s\n",
|
||||
llama_format_win_err(GetLastError()).c_str());
|
||||
}
|
||||
}
|
||||
#else
|
||||
static constexpr bool SUPPORTED = false;
|
||||
|
||||
size_t lock_granularity() {
|
||||
return (size_t) 65536;
|
||||
}
|
||||
|
||||
bool raw_lock(const void * addr, size_t len) {
|
||||
fprintf(stderr, "warning: mlock not supported on this system\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
void raw_unlock(const void * addr, size_t len) {}
|
||||
#endif
|
||||
};
|
||||
|
||||
// Replacement for std::vector<uint8_t> that doesn't require zero-initialization.
|
||||
struct llama_buffer {
|
||||
uint8_t * addr = NULL;
|
||||
size_t size = 0;
|
||||
|
||||
llama_buffer() = default;
|
||||
|
||||
void resize(size_t len) {
|
||||
delete[] addr;
|
||||
addr = new uint8_t[len];
|
||||
size = len;
|
||||
}
|
||||
|
||||
~llama_buffer() {
|
||||
delete[] addr;
|
||||
}
|
||||
|
||||
// disable copy and move
|
||||
llama_buffer(const llama_buffer&) = delete;
|
||||
llama_buffer(llama_buffer&&) = delete;
|
||||
llama_buffer& operator=(const llama_buffer&) = delete;
|
||||
llama_buffer& operator=(llama_buffer&&) = delete;
|
||||
};
|
||||
|
||||
#ifdef GGML_USE_CUBLAS
|
||||
#include "ggml-cuda.h"
|
||||
struct llama_ctx_buffer {
|
||||
uint8_t * addr = NULL;
|
||||
bool is_cuda;
|
||||
size_t size = 0;
|
||||
|
||||
llama_ctx_buffer() = default;
|
||||
|
||||
void resize(size_t size) {
|
||||
free();
|
||||
|
||||
addr = (uint8_t *) ggml_cuda_host_malloc(size);
|
||||
if (addr) {
|
||||
is_cuda = true;
|
||||
}
|
||||
else {
|
||||
// fall back to pageable memory
|
||||
addr = new uint8_t[size];
|
||||
is_cuda = false;
|
||||
}
|
||||
this->size = size;
|
||||
}
|
||||
|
||||
void free() {
|
||||
if (addr) {
|
||||
if (is_cuda) {
|
||||
ggml_cuda_host_free(addr);
|
||||
}
|
||||
else {
|
||||
delete[] addr;
|
||||
}
|
||||
}
|
||||
addr = NULL;
|
||||
}
|
||||
|
||||
~llama_ctx_buffer() {
|
||||
free();
|
||||
}
|
||||
|
||||
// disable copy and move
|
||||
llama_ctx_buffer(const llama_ctx_buffer&) = delete;
|
||||
llama_ctx_buffer(llama_ctx_buffer&&) = delete;
|
||||
llama_ctx_buffer& operator=(const llama_ctx_buffer&) = delete;
|
||||
llama_ctx_buffer& operator=(llama_ctx_buffer&&) = delete;
|
||||
};
|
||||
#else
|
||||
typedef llama_buffer llama_ctx_buffer;
|
||||
#endif
|
||||
|
||||
#endif
|
File diff suppressed because it is too large
Load Diff
@ -1,16 +1,8 @@
|
||||
#ifndef LLAMA_H
|
||||
#define LLAMA_H
|
||||
|
||||
#include "ggml.h"
|
||||
#ifdef GGML_USE_CUBLAS
|
||||
#include "ggml-cuda.h"
|
||||
#define LLAMA_MAX_DEVICES GGML_CUDA_MAX_DEVICES
|
||||
#else
|
||||
#define LLAMA_MAX_DEVICES 1
|
||||
#endif // GGML_USE_CUBLAS
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <stdbool.h>
|
||||
|
||||
#ifdef LLAMA_SHARED
|
||||
@ -27,25 +19,17 @@
|
||||
# define LLAMA_API
|
||||
#endif
|
||||
|
||||
#ifdef __GNUC__
|
||||
# define DEPRECATED(func, hint) func __attribute__((deprecated(hint)))
|
||||
#elif defined(_MSC_VER)
|
||||
# define DEPRECATED(func, hint) __declspec(deprecated(hint)) func
|
||||
#else
|
||||
# define DEPRECATED(func, hint) func
|
||||
#endif
|
||||
#define LLAMA_FILE_MAGIC_GGJT 0x67676a74u // 'ggjt'
|
||||
#define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
|
||||
#define LLAMA_FILE_MAGIC_GGMF 0x67676d66u // 'ggmf'
|
||||
#define LLAMA_FILE_MAGIC_GGML 0x67676d6cu // 'ggml'
|
||||
#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
|
||||
|
||||
#define LLAMA_DEFAULT_SEED 0xFFFFFFFF
|
||||
|
||||
#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
|
||||
|
||||
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
|
||||
#define LLAMA_SESSION_VERSION 1
|
||||
|
||||
#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL)
|
||||
// Defined when llama.cpp is compiled with support for offloading model layers to GPU.
|
||||
#define LLAMA_SUPPORTS_GPU_OFFLOAD
|
||||
#endif
|
||||
#define LLAMA_FILE_VERSION 3
|
||||
#define LLAMA_FILE_MAGIC LLAMA_FILE_MAGIC_GGJT
|
||||
#define LLAMA_FILE_MAGIC_UNVERSIONED LLAMA_FILE_MAGIC_GGML
|
||||
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
|
||||
#define LLAMA_SESSION_VERSION 1
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
@ -57,57 +41,10 @@ extern "C" {
|
||||
// TODO: show sample usage
|
||||
//
|
||||
|
||||
struct llama_model;
|
||||
struct llama_context;
|
||||
|
||||
typedef int llama_token;
|
||||
|
||||
enum llama_log_level {
|
||||
LLAMA_LOG_LEVEL_ERROR = 2,
|
||||
LLAMA_LOG_LEVEL_WARN = 3,
|
||||
LLAMA_LOG_LEVEL_INFO = 4
|
||||
};
|
||||
|
||||
enum llama_vocab_type {
|
||||
LLAMA_VOCAB_TYPE_SPM = 0, // SentencePiece
|
||||
LLAMA_VOCAB_TYPE_BPE = 1, // Byte Pair Encoding
|
||||
};
|
||||
|
||||
enum llama_token_type {
|
||||
LLAMA_TOKEN_TYPE_UNDEFINED = 0,
|
||||
LLAMA_TOKEN_TYPE_NORMAL = 1,
|
||||
LLAMA_TOKEN_TYPE_UNKNOWN = 2,
|
||||
LLAMA_TOKEN_TYPE_CONTROL = 3,
|
||||
LLAMA_TOKEN_TYPE_USER_DEFINED = 4,
|
||||
LLAMA_TOKEN_TYPE_UNUSED = 5,
|
||||
LLAMA_TOKEN_TYPE_BYTE = 6,
|
||||
};
|
||||
|
||||
// model file types
|
||||
enum llama_ftype {
|
||||
LLAMA_FTYPE_ALL_F32 = 0,
|
||||
LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
|
||||
// LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed
|
||||
// LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed
|
||||
LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q2_K = 10,// except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q3_K_S = 11,// except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q3_K_M = 12,// except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q3_K_L = 13,// except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q4_K_S = 14,// except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q4_K_M = 15,// except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q5_K_S = 16,// except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q5_K_M = 17,// except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q6_K = 18,// except 1d tensors
|
||||
|
||||
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
|
||||
};
|
||||
|
||||
typedef struct llama_token_data {
|
||||
llama_token id; // token id
|
||||
float logit; // log-odds of the token
|
||||
@ -123,152 +60,67 @@ extern "C" {
|
||||
typedef void (*llama_progress_callback)(float progress, void *ctx);
|
||||
|
||||
struct llama_context_params {
|
||||
uint32_t seed; // RNG seed, -1 for random
|
||||
int32_t n_ctx; // text context
|
||||
int32_t n_batch; // prompt processing batch size
|
||||
int32_t n_gpu_layers; // number of layers to store in VRAM
|
||||
int32_t main_gpu; // the GPU that is used for scratch and small tensors
|
||||
int n_ctx; // text context
|
||||
int n_gpu_layers; // number of layers to store in VRAM
|
||||
int seed; // RNG seed, -1 for random
|
||||
|
||||
const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)
|
||||
|
||||
// ref: https://github.com/ggerganov/llama.cpp/pull/2054
|
||||
float rope_freq_base; // RoPE base frequency
|
||||
float rope_freq_scale; // RoPE frequency scaling factor
|
||||
|
||||
// called with a progress value between 0 and 1, pass NULL to disable
|
||||
llama_progress_callback progress_callback;
|
||||
// context pointer passed to the progress callback
|
||||
void * progress_callback_user_data;
|
||||
|
||||
// Keep the booleans together to avoid misalignment during copy-by-value.
|
||||
bool low_vram; // if true, reduce VRAM usage at the cost of performance
|
||||
bool mul_mat_q; // if true, use experimental mul_mat_q kernels
|
||||
bool f16_kv; // use fp16 for KV cache
|
||||
bool logits_all; // the llama_eval() call computes all logits, not just the last one
|
||||
bool vocab_only; // only load the vocabulary, no weights
|
||||
bool use_mmap; // use mmap if possible
|
||||
bool use_mlock; // force system to keep model in RAM
|
||||
bool embedding; // embedding mode only
|
||||
|
||||
// called with a progress value between 0 and 1, pass NULL to disable
|
||||
llama_progress_callback progress_callback;
|
||||
// context pointer passed to the progress callback
|
||||
void * progress_callback_user_data;
|
||||
};
|
||||
|
||||
// Signature for logging events
|
||||
// Note that text includes the new line character at the end for most events.
|
||||
// If your logging mechanism cannot handle that, check if the last character is '\n' and strip it
|
||||
// if it exists.
|
||||
// It might not exist for progress report where '.' is output repeatedly.
|
||||
typedef void (*llama_log_callback)(enum llama_log_level level, const char * text, void * user_data);
|
||||
|
||||
// model quantization parameters
|
||||
typedef struct llama_model_quantize_params {
|
||||
int nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
|
||||
enum llama_ftype ftype; // quantize to this llama_ftype
|
||||
bool allow_requantize; // allow quantizing non-f32/f16 tensors
|
||||
bool quantize_output_tensor; // quantize output.weight
|
||||
bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
|
||||
} llama_model_quantize_params;
|
||||
|
||||
// grammar types
|
||||
struct llama_grammar;
|
||||
|
||||
// grammar element type
|
||||
enum llama_gretype {
|
||||
// end of rule definition
|
||||
LLAMA_GRETYPE_END = 0,
|
||||
|
||||
// start of alternate definition for rule
|
||||
LLAMA_GRETYPE_ALT = 1,
|
||||
|
||||
// non-terminal element: reference to rule
|
||||
LLAMA_GRETYPE_RULE_REF = 2,
|
||||
|
||||
// terminal element: character (code point)
|
||||
LLAMA_GRETYPE_CHAR = 3,
|
||||
|
||||
// inverse char(s) ([^a], [^a-b] [^abc])
|
||||
LLAMA_GRETYPE_CHAR_NOT = 4,
|
||||
|
||||
// modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
|
||||
// be an inclusive range ([a-z])
|
||||
LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,
|
||||
|
||||
// modifies a preceding LLAMA_GRETYPE_CHAR or
|
||||
// LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
|
||||
LLAMA_GRETYPE_CHAR_ALT = 6,
|
||||
// model file types
|
||||
enum llama_ftype {
|
||||
LLAMA_FTYPE_ALL_F32 = 0,
|
||||
LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
|
||||
// LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed
|
||||
// LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed
|
||||
LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors
|
||||
};
|
||||
|
||||
typedef struct llama_grammar_element {
|
||||
enum llama_gretype type;
|
||||
uint32_t value; // Unicode code point or rule ID
|
||||
} llama_grammar_element;
|
||||
LLAMA_API struct llama_context_params llama_context_default_params();
|
||||
|
||||
// performance timing information
|
||||
struct llama_timings {
|
||||
double t_start_ms;
|
||||
double t_end_ms;
|
||||
double t_load_ms;
|
||||
double t_sample_ms;
|
||||
double t_p_eval_ms;
|
||||
double t_eval_ms;
|
||||
|
||||
int32_t n_sample;
|
||||
int32_t n_p_eval;
|
||||
int32_t n_eval;
|
||||
};
|
||||
|
||||
LLAMA_API struct llama_context_params llama_context_default_params(void);
|
||||
LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void);
|
||||
LLAMA_API bool llama_mmap_supported();
|
||||
LLAMA_API bool llama_mlock_supported();
|
||||
|
||||
// TODO: not great API - very likely to change
|
||||
// Initialize the llama + ggml backend
|
||||
// If numa is true, use NUMA optimizations
|
||||
// Call once at the start of the program
|
||||
LLAMA_API void llama_backend_init(bool numa);
|
||||
LLAMA_API void llama_init_backend();
|
||||
|
||||
// Call once at the end of the program - currently only used for MPI
|
||||
LLAMA_API void llama_backend_free(void);
|
||||
LLAMA_API int64_t llama_time_us();
|
||||
|
||||
LLAMA_API struct llama_model * llama_load_model_from_file(
|
||||
// Various functions for loading a ggml llama model.
|
||||
// Allocate (almost) all memory needed for the model.
|
||||
// Return NULL on failure
|
||||
LLAMA_API struct llama_context * llama_init_from_file(
|
||||
const char * path_model,
|
||||
struct llama_context_params params);
|
||||
|
||||
LLAMA_API void llama_free_model(struct llama_model * model);
|
||||
|
||||
LLAMA_API struct llama_context * llama_new_context_with_model(
|
||||
struct llama_model * model,
|
||||
struct llama_context_params params);
|
||||
|
||||
// Frees all allocated memory
|
||||
LLAMA_API void llama_free(struct llama_context * ctx);
|
||||
|
||||
LLAMA_API int64_t llama_time_us(void);
|
||||
|
||||
LLAMA_API int llama_max_devices (void);
|
||||
LLAMA_API bool llama_mmap_supported (void);
|
||||
LLAMA_API bool llama_mlock_supported(void);
|
||||
|
||||
LLAMA_API int llama_n_vocab (const struct llama_context * ctx);
|
||||
LLAMA_API int llama_n_ctx (const struct llama_context * ctx);
|
||||
LLAMA_API int llama_n_ctx_train(const struct llama_context * ctx);
|
||||
LLAMA_API int llama_n_embd (const struct llama_context * ctx);
|
||||
|
||||
LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_context * ctx);
|
||||
|
||||
LLAMA_API int llama_model_n_vocab (const struct llama_model * model);
|
||||
LLAMA_API int llama_model_n_ctx (const struct llama_model * model);
|
||||
LLAMA_API int llama_model_n_ctx_train(const struct llama_model * model);
|
||||
LLAMA_API int llama_model_n_embd (const struct llama_model * model);
|
||||
|
||||
// Get a string describing the model type
|
||||
LLAMA_API int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size);
|
||||
// Returns the total size of all the tensors in the model in bytes
|
||||
LLAMA_API uint64_t llama_model_size(const struct llama_model * model);
|
||||
// Returns the total number of parameters in the model
|
||||
LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model);
|
||||
|
||||
// TODO: not great API - very likely to change
|
||||
// Returns 0 on success
|
||||
// nthread - how many threads to use. If <=0, will use std::thread::hardware_concurrency(), else the number given
|
||||
LLAMA_API int llama_model_quantize(
|
||||
const char * fname_inp,
|
||||
const char * fname_out,
|
||||
const llama_model_quantize_params * params);
|
||||
enum llama_ftype ftype,
|
||||
int nthread);
|
||||
|
||||
// Apply a LoRA adapter to a loaded model
|
||||
// path_base_model is the path to a higher quality model to use as a base for
|
||||
@ -276,24 +128,17 @@ extern "C" {
|
||||
// The model needs to be reloaded before applying a new adapter, otherwise the adapter
|
||||
// will be applied on top of the previous one
|
||||
// Returns 0 on success
|
||||
LLAMA_API DEPRECATED(int llama_apply_lora_from_file(
|
||||
LLAMA_API int llama_apply_lora_from_file(
|
||||
struct llama_context * ctx,
|
||||
const char * path_lora,
|
||||
const char * path_base_model,
|
||||
int n_threads),
|
||||
"please use llama_model_apply_lora_from_file instead");
|
||||
|
||||
LLAMA_API int llama_model_apply_lora_from_file(
|
||||
const struct llama_model * model,
|
||||
const char * path_lora,
|
||||
const char * path_base_model,
|
||||
int n_threads);
|
||||
int n_threads);
|
||||
|
||||
// Returns the number of tokens in the KV cache
|
||||
LLAMA_API int llama_get_kv_cache_token_count(const struct llama_context * ctx);
|
||||
|
||||
// Sets the current rng seed.
|
||||
LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed);
|
||||
LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, int seed);
|
||||
|
||||
// Returns the maximum size in bytes of the state (rng, logits, embedding
|
||||
// and kv_cache) - will often be smaller after compacting tokens
|
||||
@ -323,19 +168,21 @@ extern "C" {
|
||||
int n_past,
|
||||
int n_threads);
|
||||
|
||||
// Same as llama_eval, but use float matrix input directly.
|
||||
LLAMA_API int llama_eval_embd(
|
||||
// Convert the provided text into tokens.
|
||||
// The tokens pointer must be large enough to hold the resulting tokens.
|
||||
// Returns the number of tokens on success, no more than n_max_tokens
|
||||
// Returns a negative number on failure - the number of tokens that would have been returned
|
||||
// TODO: not sure if correct
|
||||
LLAMA_API int llama_tokenize(
|
||||
struct llama_context * ctx,
|
||||
const float * embd,
|
||||
int n_tokens,
|
||||
int n_past,
|
||||
int n_threads);
|
||||
const char * text,
|
||||
llama_token * tokens,
|
||||
int n_max_tokens,
|
||||
bool add_bos);
|
||||
|
||||
// Export a static computation graph for context of 511 and batch size of 1
|
||||
// NOTE: since this functionality is mostly for debugging and demonstration purposes, we hardcode these
|
||||
// parameters here to keep things simple
|
||||
// IMPORTANT: do not use for anything else other than debugging and testing!
|
||||
LLAMA_API int llama_eval_export(struct llama_context * ctx, const char * fname);
|
||||
LLAMA_API int llama_n_vocab(const struct llama_context * ctx);
|
||||
LLAMA_API int llama_n_ctx (const struct llama_context * ctx);
|
||||
LLAMA_API int llama_n_embd (const struct llama_context * ctx);
|
||||
|
||||
// Token logits obtained from the last call to llama_eval()
|
||||
// The logits for the last token are stored in the last row
|
||||
@ -348,75 +195,15 @@ extern "C" {
|
||||
// shape: [n_embd] (1-dimensional)
|
||||
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
|
||||
|
||||
//
|
||||
// Vocab
|
||||
//
|
||||
|
||||
LLAMA_API const char * llama_token_get_text(const struct llama_context * ctx, llama_token token);
|
||||
|
||||
LLAMA_API float llama_token_get_score(const struct llama_context * ctx, llama_token token);
|
||||
|
||||
LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_context * ctx, llama_token token);
|
||||
// Token Id -> String. Uses the vocabulary in the provided context
|
||||
LLAMA_API const char * llama_token_to_str(const struct llama_context * ctx, llama_token token);
|
||||
|
||||
// Special tokens
|
||||
LLAMA_API llama_token llama_token_bos(const struct llama_context * ctx); // beginning-of-sentence
|
||||
LLAMA_API llama_token llama_token_eos(const struct llama_context * ctx); // end-of-sentence
|
||||
LLAMA_API llama_token llama_token_nl (const struct llama_context * ctx); // next-line
|
||||
LLAMA_API llama_token llama_token_bos();
|
||||
LLAMA_API llama_token llama_token_eos();
|
||||
LLAMA_API llama_token llama_token_nl();
|
||||
|
||||
//
|
||||
// Tokenization
|
||||
//
|
||||
|
||||
// Convert the provided text into tokens.
|
||||
// The tokens pointer must be large enough to hold the resulting tokens.
|
||||
// Returns the number of tokens on success, no more than n_max_tokens
|
||||
// Returns a negative number on failure - the number of tokens that would have been returned
|
||||
LLAMA_API int llama_tokenize(
|
||||
struct llama_context * ctx,
|
||||
const char * text,
|
||||
llama_token * tokens,
|
||||
int n_max_tokens,
|
||||
bool add_bos);
|
||||
|
||||
LLAMA_API int llama_tokenize_with_model(
|
||||
const struct llama_model * model,
|
||||
const char * text,
|
||||
llama_token * tokens,
|
||||
int n_max_tokens,
|
||||
bool add_bos);
|
||||
|
||||
// Token Id -> Piece.
|
||||
// Uses the vocabulary in the provided context.
|
||||
// Does not write null terminator to the buffer.
|
||||
// User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens.
|
||||
LLAMA_API int llama_token_to_piece(
|
||||
const struct llama_context * ctx,
|
||||
llama_token token,
|
||||
char * buf,
|
||||
int length);
|
||||
|
||||
LLAMA_API int llama_token_to_piece_with_model(
|
||||
const struct llama_model * model,
|
||||
llama_token token,
|
||||
char * buf,
|
||||
int length);
|
||||
|
||||
//
|
||||
// Grammar
|
||||
//
|
||||
|
||||
LLAMA_API struct llama_grammar * llama_grammar_init(
|
||||
const llama_grammar_element ** rules,
|
||||
size_t n_rules,
|
||||
size_t start_rule_index);
|
||||
|
||||
LLAMA_API void llama_grammar_free(struct llama_grammar * grammar);
|
||||
|
||||
LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar);
|
||||
|
||||
//
|
||||
// Sampling functions
|
||||
//
|
||||
|
||||
/// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
|
||||
LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty);
|
||||
@ -424,16 +211,6 @@ extern "C" {
|
||||
/// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
|
||||
LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence);
|
||||
|
||||
/// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
|
||||
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted.
|
||||
/// @params guidance_ctx A separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
|
||||
/// @params scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
|
||||
LLAMA_API void llama_sample_classifier_free_guidance(
|
||||
struct llama_context * ctx,
|
||||
llama_token_data_array * candidates,
|
||||
struct llama_context * guidance_ctx,
|
||||
float scale);
|
||||
|
||||
/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
|
||||
LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates);
|
||||
|
||||
@ -450,9 +227,6 @@ extern "C" {
|
||||
LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep);
|
||||
LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp);
|
||||
|
||||
/// @details Apply constraints from grammar
|
||||
LLAMA_API void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar);
|
||||
|
||||
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
||||
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
|
||||
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
|
||||
@ -474,60 +248,13 @@ extern "C" {
|
||||
/// @details Randomly selects a token from the candidates based on their probabilities.
|
||||
LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates);
|
||||
|
||||
/// @details Accepts the sampled token into the grammar
|
||||
LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token);
|
||||
|
||||
//
|
||||
// Beam search
|
||||
//
|
||||
|
||||
struct llama_beam_view {
|
||||
const llama_token * tokens;
|
||||
size_t n_tokens;
|
||||
float p; // Cumulative beam probability (renormalized relative to all beams)
|
||||
bool eob; // Callback should set this to true when a beam is at end-of-beam.
|
||||
};
|
||||
|
||||
// Passed to beam_search_callback function.
|
||||
// Whenever 0 < common_prefix_length, this number of tokens should be copied from any of the beams
|
||||
// (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks.
|
||||
// These pointers are valid only during the synchronous callback, so should not be saved.
|
||||
struct llama_beams_state {
|
||||
struct llama_beam_view * beam_views;
|
||||
size_t n_beams; // Number of elements in beam_views[].
|
||||
size_t common_prefix_length; // Current max length of prefix tokens shared by all beams.
|
||||
bool last_call; // True iff this is the last callback invocation.
|
||||
};
|
||||
|
||||
// Type of pointer to the beam_search_callback function.
|
||||
// void* callback_data is any custom data passed to llama_beam_search, that is subsequently
|
||||
// passed back to beam_search_callback. This avoids having to use global variables in the callback.
|
||||
typedef void (*llama_beam_search_callback_fn_t)(void * callback_data, struct llama_beams_state);
|
||||
|
||||
/// @details Deterministically returns entire sentence constructed by a beam search.
|
||||
/// @param ctx Pointer to the llama_context.
|
||||
/// @param callback Invoked for each iteration of the beam_search loop, passing in beams_state.
|
||||
/// @param callback_data A pointer that is simply passed back to callback.
|
||||
/// @param n_beams Number of beams to use.
|
||||
/// @param n_past Number of tokens already evaluated.
|
||||
/// @param n_predict Maximum number of tokens to predict. EOS may occur earlier.
|
||||
/// @param n_threads Number of threads as passed to llama_eval().
|
||||
LLAMA_API void llama_beam_search(struct llama_context * ctx, llama_beam_search_callback_fn_t callback, void * callback_data, size_t n_beams, int n_past, int n_predict, int n_threads);
|
||||
|
||||
// Performance information
|
||||
LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
|
||||
LLAMA_API void llama_print_timings(struct llama_context * ctx);
|
||||
LLAMA_API void llama_reset_timings(struct llama_context * ctx);
|
||||
|
||||
// Print system information
|
||||
LLAMA_API const char * llama_print_system_info(void);
|
||||
|
||||
// Set callback for all future logging events.
|
||||
// If this is not called, or NULL is supplied, everything is output on stderr.
|
||||
LLAMA_API void llama_log_set(llama_log_callback log_callback, void * user_data);
|
||||
|
||||
LLAMA_API void llama_dump_timing_info_yaml(FILE * stream, const struct llama_context * ctx);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
@ -537,11 +264,10 @@ extern "C" {
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
struct ggml_tensor;
|
||||
|
||||
const std::vector<std::pair<std::string, struct ggml_tensor *>>& llama_internal_get_tensor_map(struct llama_context * ctx);
|
||||
std::vector<std::pair<std::string, struct ggml_tensor *>>& llama_internal_get_tensor_map(struct llama_context * ctx);
|
||||
|
||||
#endif // LLAMA_API_INTERNAL
|
||||
#endif
|
||||
|
||||
#endif // LLAMA_H
|
||||
|
0
examples/talk-llama/speak
Executable file → Normal file
0
examples/talk-llama/speak
Executable file → Normal file
@ -25,20 +25,6 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
|
||||
return res;
|
||||
}
|
||||
|
||||
std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) {
|
||||
std::vector<char> result(8, 0);
|
||||
const int n_tokens = llama_token_to_piece(ctx, token, result.data(), result.size());
|
||||
if (n_tokens < 0) {
|
||||
result.resize(-n_tokens);
|
||||
int check = llama_token_to_piece(ctx, token, result.data(), result.size());
|
||||
GGML_ASSERT(check == -n_tokens);
|
||||
} else {
|
||||
result.resize(n_tokens);
|
||||
}
|
||||
|
||||
return std::string(result.data(), result.size());
|
||||
}
|
||||
|
||||
// command-line parameters
|
||||
struct whisper_params {
|
||||
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
||||
@ -47,14 +33,14 @@ struct whisper_params {
|
||||
int32_t max_tokens = 32;
|
||||
int32_t audio_ctx = 0;
|
||||
|
||||
float vad_thold = 0.6f;
|
||||
float freq_thold = 100.0f;
|
||||
float vad_thold = 0.6f;
|
||||
float freq_thold = 100.0f;
|
||||
|
||||
bool speed_up = false;
|
||||
bool translate = false;
|
||||
bool print_special = false;
|
||||
bool print_energy = false;
|
||||
bool no_timestamps = true;
|
||||
bool speed_up = false;
|
||||
bool translate = false;
|
||||
bool print_special = false;
|
||||
bool print_energy = false;
|
||||
bool no_timestamps = true;
|
||||
bool verbose_prompt = false;
|
||||
|
||||
std::string person = "Georgi";
|
||||
@ -249,7 +235,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// llama init
|
||||
|
||||
llama_backend_init(true);
|
||||
llama_init_backend();
|
||||
|
||||
auto lparams = llama_context_default_params();
|
||||
|
||||
@ -258,9 +244,7 @@ int main(int argc, char ** argv) {
|
||||
lparams.seed = 1;
|
||||
lparams.f16_kv = true;
|
||||
|
||||
struct llama_model * model_llama = llama_load_model_from_file(params.model_llama.c_str(), lparams);
|
||||
|
||||
struct llama_context * ctx_llama = llama_new_context_with_model(model_llama, lparams);
|
||||
struct llama_context * ctx_llama = llama_init_from_file(params.model_llama.c_str(), lparams);
|
||||
|
||||
// print some info about the processing
|
||||
{
|
||||
@ -283,6 +267,7 @@ int main(int argc, char ** argv) {
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
|
||||
// init audio
|
||||
|
||||
audio_async audio(30*1000);
|
||||
@ -293,6 +278,8 @@ int main(int argc, char ** argv) {
|
||||
|
||||
audio.resume();
|
||||
|
||||
int n_iter = 0;
|
||||
|
||||
bool is_running = true;
|
||||
bool force_speak = false;
|
||||
|
||||
@ -527,7 +514,7 @@ int main(int argc, char ** argv) {
|
||||
//printf("\n---\n");
|
||||
//printf("resetting: '");
|
||||
//for (int i = 0; i < (int) embd.size(); i++) {
|
||||
// printf("%s", llama_token_to_piece(ctx_llama, embd[i]));
|
||||
// printf("%s", llama_token_to_str(ctx_llama, embd[i]));
|
||||
//}
|
||||
//printf("'\n");
|
||||
//printf("\n---\n");
|
||||
@ -595,7 +582,7 @@ int main(int argc, char ** argv) {
|
||||
auto logits = llama_get_logits(ctx_llama);
|
||||
auto n_vocab = llama_n_vocab(ctx_llama);
|
||||
|
||||
logits[llama_token_eos(ctx_llama)] = 0;
|
||||
logits[llama_token_eos()] = 0;
|
||||
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
@ -606,13 +593,13 @@ int main(int argc, char ** argv) {
|
||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||
|
||||
// apply repeat penalty
|
||||
const float nl_logit = logits[llama_token_nl(ctx_llama)];
|
||||
const float nl_logit = logits[llama_token_nl()];
|
||||
|
||||
llama_sample_repetition_penalty(ctx_llama, &candidates_p,
|
||||
embd_inp.data() + std::max(0, n_past - repeat_last_n),
|
||||
repeat_last_n, repeat_penalty);
|
||||
|
||||
logits[llama_token_nl(ctx_llama)] = nl_logit;
|
||||
logits[llama_token_nl()] = nl_logit;
|
||||
|
||||
if (temp <= 0) {
|
||||
// Greedy sampling
|
||||
@ -626,22 +613,22 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
}
|
||||
|
||||
if (id != llama_token_eos(ctx_llama)) {
|
||||
if (id != llama_token_eos()) {
|
||||
// add it to the context
|
||||
embd.push_back(id);
|
||||
|
||||
text_to_speak += llama_token_to_piece(ctx_llama, id);
|
||||
text_to_speak += llama_token_to_str(ctx_llama, id);
|
||||
|
||||
printf("%s", llama_token_to_piece(ctx_llama, id).c_str());
|
||||
printf("%s", llama_token_to_str(ctx_llama, id));
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
std::string last_output;
|
||||
for (int i = embd_inp.size() - 16; i < (int) embd_inp.size(); i++) {
|
||||
last_output += llama_token_to_piece(ctx_llama, embd_inp[i]);
|
||||
last_output += llama_token_to_str(ctx_llama, embd_inp[i]);
|
||||
}
|
||||
last_output += llama_token_to_piece(ctx_llama, embd[0]);
|
||||
last_output += llama_token_to_str(ctx_llama, embd[0]);
|
||||
|
||||
for (std::string & antiprompt : antiprompts) {
|
||||
if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
|
||||
@ -668,6 +655,8 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
audio.clear();
|
||||
|
||||
++n_iter;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -31,7 +31,7 @@ To run this, you will need a ggml GPT-2 model: [instructions](https://github.com
|
||||
Alternatively, you can simply download the smallest ggml GPT-2 117M model (240 MB) like this:
|
||||
|
||||
```
|
||||
wget --quiet --show-progress -O models/ggml-gpt-2-117M.bin https://huggingface.co/ggerganov/ggml/resolve/main/ggml-model-gpt-2-117M.bin
|
||||
wget --quiet --show-progress -O models/ggml-gpt-2-117M.bin https://huggingface.co/ggerganov/ggml/raw/main/ggml-model-gpt-2-117M.bin
|
||||
```
|
||||
|
||||
## TTS
|
||||
|
@ -44,8 +44,8 @@ if [ "$encoder_only" -eq 0 ]; then
|
||||
printf "\n"
|
||||
fi
|
||||
|
||||
printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "PP" "Commit"
|
||||
printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---"
|
||||
printf "| %6s | %6s | %12s | %9s | %3s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "PP" "Commit"
|
||||
printf "| %6s | %6s | %12s | %9s | %3s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---"
|
||||
|
||||
for model in "${models[@]}"; do
|
||||
# actual run
|
||||
@ -83,13 +83,9 @@ for model in "${models[@]}"; do
|
||||
config="$config COREML"
|
||||
fi
|
||||
|
||||
if [[ $system_info == *"METAL = 1"* ]]; then
|
||||
config="$config METAL"
|
||||
fi
|
||||
|
||||
commit=$(git rev-parse --short HEAD)
|
||||
|
||||
if [ $ret -eq 0 ]; then
|
||||
printf "| <todo> | <todo> | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$prompt_time" "$commit"
|
||||
printf "| <todo> | <todo> | %12s | %9s | %3s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$prompt_time" "$commit"
|
||||
fi
|
||||
done
|
||||
|
222
extra/bench.py
222
extra/bench.py
@ -1,222 +0,0 @@
|
||||
import os
|
||||
import subprocess
|
||||
import re
|
||||
import csv
|
||||
import wave
|
||||
import contextlib
|
||||
import argparse
|
||||
|
||||
|
||||
# Custom action to handle comma-separated list
|
||||
class ListAction(argparse.Action):
|
||||
def __call__(self, parser, namespace, values, option_string=None):
|
||||
setattr(namespace, self.dest, [int(val) for val in values.split(",")])
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description="Benchmark the speech recognition model")
|
||||
|
||||
# Define the argument to accept a list
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--threads",
|
||||
dest="threads",
|
||||
action=ListAction,
|
||||
default=[4],
|
||||
help="List of thread counts to benchmark (comma-separated, default: 4)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--processors",
|
||||
dest="processors",
|
||||
action=ListAction,
|
||||
default=[1],
|
||||
help="List of processor counts to benchmark (comma-separated, default: 1)",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"-f",
|
||||
"--filename",
|
||||
type=str,
|
||||
default="./samples/jfk.wav",
|
||||
help="Relative path of the file to transcribe (default: ./samples/jfk.wav)",
|
||||
)
|
||||
|
||||
# Parse the command line arguments
|
||||
args = parser.parse_args()
|
||||
|
||||
sample_file = args.filename
|
||||
|
||||
threads = args.threads
|
||||
processors = args.processors
|
||||
|
||||
# Define the models, threads, and processor counts to benchmark
|
||||
models = [
|
||||
"ggml-tiny.en.bin",
|
||||
"ggml-tiny.bin",
|
||||
"ggml-base.en.bin",
|
||||
"ggml-base.bin",
|
||||
"ggml-small.en.bin",
|
||||
"ggml-small.bin",
|
||||
"ggml-medium.en.bin",
|
||||
"ggml-medium.bin",
|
||||
"ggml-large.bin",
|
||||
]
|
||||
|
||||
|
||||
metal_device = ""
|
||||
|
||||
# Initialize a dictionary to hold the results
|
||||
results = {}
|
||||
|
||||
gitHashHeader = "Commit"
|
||||
modelHeader = "Model"
|
||||
hardwareHeader = "Hardware"
|
||||
recordingLengthHeader = "Recording Length (seconds)"
|
||||
threadHeader = "Thread"
|
||||
processorCountHeader = "Processor Count"
|
||||
loadTimeHeader = "Load Time (ms)"
|
||||
sampleTimeHeader = "Sample Time (ms)"
|
||||
encodeTimeHeader = "Encode Time (ms)"
|
||||
decodeTimeHeader = "Decode Time (ms)"
|
||||
sampleTimePerRunHeader = "Sample Time per Run (ms)"
|
||||
encodeTimePerRunHeader = "Encode Time per Run (ms)"
|
||||
decodeTimePerRunHeader = "Decode Time per Run (ms)"
|
||||
totalTimeHeader = "Total Time (ms)"
|
||||
|
||||
|
||||
def check_file_exists(file: str) -> bool:
|
||||
return os.path.isfile(file)
|
||||
|
||||
|
||||
def get_git_short_hash() -> str:
|
||||
try:
|
||||
return (
|
||||
subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
|
||||
.decode()
|
||||
.strip()
|
||||
)
|
||||
except subprocess.CalledProcessError as e:
|
||||
return ""
|
||||
|
||||
|
||||
def wav_file_length(file: str = sample_file) -> float:
|
||||
with contextlib.closing(wave.open(file, "r")) as f:
|
||||
frames = f.getnframes()
|
||||
rate = f.getframerate()
|
||||
duration = frames / float(rate)
|
||||
return duration
|
||||
|
||||
|
||||
def extract_metrics(output: str, label: str) -> tuple[float, float]:
|
||||
match = re.search(rf"{label} \s*=\s*(\d+\.\d+)\s*ms\s*/\s*(\d+)\s*runs", output)
|
||||
time = float(match.group(1)) if match else None
|
||||
runs = float(match.group(2)) if match else None
|
||||
return time, runs
|
||||
|
||||
|
||||
def extract_device(output: str) -> str:
|
||||
match = re.search(r"picking default device: (.*)", output)
|
||||
device = match.group(1) if match else "Not found"
|
||||
return device
|
||||
|
||||
|
||||
# Check if the sample file exists
|
||||
if not check_file_exists(sample_file):
|
||||
raise FileNotFoundError(f"Sample file {sample_file} not found")
|
||||
|
||||
recording_length = wav_file_length()
|
||||
|
||||
|
||||
# Check that all models exist
|
||||
# Filter out models from list that are not downloaded
|
||||
filtered_models = []
|
||||
for model in models:
|
||||
if check_file_exists(f"models/{model}"):
|
||||
filtered_models.append(model)
|
||||
else:
|
||||
print(f"Model {model} not found, removing from list")
|
||||
|
||||
models = filtered_models
|
||||
|
||||
# Loop over each combination of parameters
|
||||
for model in filtered_models:
|
||||
for thread in threads:
|
||||
for processor_count in processors:
|
||||
# Construct the command to run
|
||||
cmd = f"./main -m models/{model} -t {thread} -p {processor_count} -f {sample_file}"
|
||||
# Run the command and get the output
|
||||
process = subprocess.Popen(
|
||||
cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
|
||||
)
|
||||
|
||||
output = ""
|
||||
while process.poll() is None:
|
||||
output += process.stdout.read().decode()
|
||||
|
||||
# Parse the output
|
||||
load_time_match = re.search(r"load time\s*=\s*(\d+\.\d+)\s*ms", output)
|
||||
load_time = float(load_time_match.group(1)) if load_time_match else None
|
||||
|
||||
metal_device = extract_device(output)
|
||||
sample_time, sample_runs = extract_metrics(output, "sample time")
|
||||
encode_time, encode_runs = extract_metrics(output, "encode time")
|
||||
decode_time, decode_runs = extract_metrics(output, "decode time")
|
||||
|
||||
total_time_match = re.search(r"total time\s*=\s*(\d+\.\d+)\s*ms", output)
|
||||
total_time = float(total_time_match.group(1)) if total_time_match else None
|
||||
|
||||
model_name = model.replace("ggml-", "").replace(".bin", "")
|
||||
|
||||
print(
|
||||
f"Ran model={model_name} threads={thread} processor_count={processor_count}, took {total_time}ms"
|
||||
)
|
||||
# Store the times in the results dictionary
|
||||
results[(model_name, thread, processor_count)] = {
|
||||
loadTimeHeader: load_time,
|
||||
sampleTimeHeader: sample_time,
|
||||
encodeTimeHeader: encode_time,
|
||||
decodeTimeHeader: decode_time,
|
||||
sampleTimePerRunHeader: round(sample_time / sample_runs, 2),
|
||||
encodeTimePerRunHeader: round(encode_time / encode_runs, 2),
|
||||
decodeTimePerRunHeader: round(decode_time / decode_runs, 2),
|
||||
totalTimeHeader: total_time,
|
||||
}
|
||||
|
||||
# Write the results to a CSV file
|
||||
with open("benchmark_results.csv", "w", newline="") as csvfile:
|
||||
fieldnames = [
|
||||
gitHashHeader,
|
||||
modelHeader,
|
||||
hardwareHeader,
|
||||
recordingLengthHeader,
|
||||
threadHeader,
|
||||
processorCountHeader,
|
||||
loadTimeHeader,
|
||||
sampleTimeHeader,
|
||||
encodeTimeHeader,
|
||||
decodeTimeHeader,
|
||||
sampleTimePerRunHeader,
|
||||
encodeTimePerRunHeader,
|
||||
decodeTimePerRunHeader,
|
||||
totalTimeHeader,
|
||||
]
|
||||
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
||||
|
||||
writer.writeheader()
|
||||
|
||||
shortHash = get_git_short_hash()
|
||||
# Sort the results by total time in ascending order
|
||||
sorted_results = sorted(results.items(), key=lambda x: x[1].get(totalTimeHeader, 0))
|
||||
for params, times in sorted_results:
|
||||
row = {
|
||||
gitHashHeader: shortHash,
|
||||
modelHeader: params[0],
|
||||
hardwareHeader: metal_device,
|
||||
recordingLengthHeader: recording_length,
|
||||
threadHeader: params[1],
|
||||
processorCountHeader: params[2],
|
||||
}
|
||||
row.update(times)
|
||||
writer.writerow(row)
|
@ -5745,7 +5745,6 @@ inline void ggml_cuda_op_rope(
|
||||
(void) dst;
|
||||
(void) src0_ddq_i;
|
||||
(void) src1_ddf_i;
|
||||
(void) i02;
|
||||
(void) i1;
|
||||
}
|
||||
|
||||
@ -5781,7 +5780,6 @@ inline void ggml_cuda_op_alibi(
|
||||
(void) src1;
|
||||
(void) src0_ddq_i;
|
||||
(void) src1_ddf_i;
|
||||
(void) i02;
|
||||
(void) i1;
|
||||
}
|
||||
|
||||
|
16
ggml-metal.m
16
ggml-metal.m
@ -78,7 +78,6 @@ struct ggml_metal_context {
|
||||
GGML_METAL_DECL_KERNEL(get_rows_q6_K);
|
||||
GGML_METAL_DECL_KERNEL(rms_norm);
|
||||
GGML_METAL_DECL_KERNEL(norm);
|
||||
GGML_METAL_DECL_KERNEL(mul_mat_f32_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
|
||||
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4);
|
||||
@ -90,7 +89,6 @@ struct ggml_metal_context {
|
||||
GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
|
||||
@ -239,7 +237,6 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(get_rows_q6_K);
|
||||
GGML_METAL_ADD_KERNEL(rms_norm);
|
||||
GGML_METAL_ADD_KERNEL(norm);
|
||||
GGML_METAL_ADD_KERNEL(mul_mat_f32_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
|
||||
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4);
|
||||
@ -251,7 +248,6 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
|
||||
@ -313,7 +309,6 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
||||
GGML_METAL_DEL_KERNEL(get_rows_q6_K);
|
||||
GGML_METAL_DEL_KERNEL(rms_norm);
|
||||
GGML_METAL_DEL_KERNEL(norm);
|
||||
GGML_METAL_DEL_KERNEL(mul_mat_f32_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
|
||||
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4);
|
||||
@ -325,7 +320,6 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
||||
GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
|
||||
@ -891,7 +885,6 @@ void ggml_metal_graph_compute(
|
||||
ne00%32 == 0 &&
|
||||
ne11 > 1) {
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
|
||||
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
|
||||
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
|
||||
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
|
||||
@ -926,18 +919,15 @@ void ggml_metal_graph_compute(
|
||||
|
||||
// use custom matrix x vector kernel
|
||||
switch (src0t) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f32_f32];
|
||||
nrows = 4;
|
||||
} break;
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
nth0 = 32;
|
||||
nth1 = 1;
|
||||
if (ne11 * ne12 < 4) {
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
|
||||
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
||||
//} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
||||
} else if (false) {
|
||||
// TODO: with ggml_mul_mat_pad this kernel no longer seems to be needed
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4];
|
||||
nrows = ne11;
|
||||
} else {
|
||||
|
@ -523,79 +523,6 @@ kernel void kernel_mul_mat_q8_0_f32(
|
||||
}
|
||||
}
|
||||
|
||||
#define N_F32_F32 4
|
||||
|
||||
kernel void kernel_mul_mat_f32_f32(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant int64_t & ne10,
|
||||
constant int64_t & ne11,
|
||||
constant int64_t & ne12,
|
||||
constant uint64_t & nb10,
|
||||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb12,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]]) {
|
||||
|
||||
const int64_t r0 = tgpig.x;
|
||||
const int64_t rb = tgpig.y*N_F32_F32;
|
||||
const int64_t im = tgpig.z;
|
||||
|
||||
device const float * x = (device const float *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
|
||||
|
||||
if (ne00 < 128) {
|
||||
for (int row = 0; row < N_F32_F32; ++row) {
|
||||
int r1 = rb + row;
|
||||
if (r1 >= ne11) {
|
||||
break;
|
||||
}
|
||||
|
||||
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
||||
|
||||
float sumf = 0;
|
||||
for (int i = tiisg; i < ne00; i += 32) {
|
||||
sumf += (float) x[i] * (float) y[i];
|
||||
}
|
||||
|
||||
float all_sum = simd_sum(sumf);
|
||||
if (tiisg == 0) {
|
||||
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
device const float4 * x4 = (device const float4 *)x;
|
||||
for (int row = 0; row < N_F32_F32; ++row) {
|
||||
int r1 = rb + row;
|
||||
if (r1 >= ne11) {
|
||||
break;
|
||||
}
|
||||
|
||||
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
||||
device const float4 * y4 = (device const float4 *) y;
|
||||
|
||||
float sumf = 0;
|
||||
for (int i = tiisg; i < ne00/4; i += 32) {
|
||||
for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
|
||||
}
|
||||
|
||||
float all_sum = simd_sum(sumf);
|
||||
if (tiisg == 0) {
|
||||
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
|
||||
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_mul_mat_f16_f32_1row(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
@ -1472,13 +1399,13 @@ kernel void kernel_mul_mat_q4_K_f32(
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01 [[buffer(4)]],
|
||||
constant int64_t & ne02 [[buffer(5)]],
|
||||
constant int64_t & ne10 [[buffer(9)]],
|
||||
constant int64_t & ne12 [[buffer(11)]],
|
||||
constant int64_t & ne0 [[buffer(15)]],
|
||||
constant int64_t & ne1 [[buffer(16)]],
|
||||
constant uint & gqa [[buffer(17)]],
|
||||
constant int64_t & ne01[[buffer(4)]],
|
||||
constant int64_t & ne02[[buffer(5)]],
|
||||
constant int64_t & ne10[[buffer(9)]],
|
||||
constant int64_t & ne12[[buffer(11)]],
|
||||
constant int64_t & ne0[[buffer(15)]],
|
||||
constant int64_t & ne1[[buffer(16)]],
|
||||
constant uint & gqa[[buffer(17)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
@ -2341,7 +2268,6 @@ typedef void (mat_mm_t)(
|
||||
constant uint & gqa,
|
||||
threadgroup uchar *, uint3, uint, uint);
|
||||
|
||||
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
|
||||
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
|
||||
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
|
||||
|
72
ggml.c
72
ggml.c
@ -20119,27 +20119,27 @@ const char * gguf_type_name(enum gguf_type type) {
|
||||
return GGUF_TYPE_NAME[type];
|
||||
}
|
||||
|
||||
int gguf_get_version(const struct gguf_context * ctx) {
|
||||
int gguf_get_version(struct gguf_context * ctx) {
|
||||
return ctx->header.version;
|
||||
}
|
||||
|
||||
size_t gguf_get_alignment(const struct gguf_context * ctx) {
|
||||
size_t gguf_get_alignment(struct gguf_context * ctx) {
|
||||
return ctx->alignment;
|
||||
}
|
||||
|
||||
size_t gguf_get_data_offset(const struct gguf_context * ctx) {
|
||||
size_t gguf_get_data_offset(struct gguf_context * ctx) {
|
||||
return ctx->offset;
|
||||
}
|
||||
|
||||
void * gguf_get_data(const struct gguf_context * ctx) {
|
||||
void * gguf_get_data(struct gguf_context * ctx) {
|
||||
return ctx->data;
|
||||
}
|
||||
|
||||
int gguf_get_n_kv(const struct gguf_context * ctx) {
|
||||
int gguf_get_n_kv(struct gguf_context * ctx) {
|
||||
return ctx->header.n_kv;
|
||||
}
|
||||
|
||||
int gguf_find_key(const struct gguf_context * ctx, const char * key) {
|
||||
int gguf_find_key(struct gguf_context * ctx, const char * key) {
|
||||
// return -1 if key not found
|
||||
int keyfound = -1;
|
||||
|
||||
@ -20155,85 +20155,85 @@ int gguf_find_key(const struct gguf_context * ctx, const char * key) {
|
||||
return keyfound;
|
||||
}
|
||||
|
||||
const char * gguf_get_key(const struct gguf_context * ctx, int i) {
|
||||
const char * gguf_get_key(struct gguf_context * ctx, int i) {
|
||||
return ctx->kv[i].key.data;
|
||||
}
|
||||
|
||||
enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int i) {
|
||||
enum gguf_type gguf_get_kv_type(struct gguf_context * ctx, int i) {
|
||||
return ctx->kv[i].type;
|
||||
}
|
||||
|
||||
enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int i) {
|
||||
enum gguf_type gguf_get_arr_type(struct gguf_context * ctx, int i) {
|
||||
return ctx->kv[i].value.arr.type;
|
||||
}
|
||||
|
||||
const void * gguf_get_arr_data(const struct gguf_context * ctx, int i) {
|
||||
const void * gguf_get_arr_data(struct gguf_context * ctx, int i) {
|
||||
return ctx->kv[i].value.arr.data;
|
||||
}
|
||||
|
||||
const char * gguf_get_arr_str(const struct gguf_context * ctx, int key_id, int i) {
|
||||
const char * gguf_get_arr_str(struct gguf_context * ctx, int key_id, int i) {
|
||||
struct gguf_kv * kv = &ctx->kv[key_id];
|
||||
struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[i];
|
||||
return str->data;
|
||||
}
|
||||
|
||||
int gguf_get_arr_n(const struct gguf_context * ctx, int i) {
|
||||
int gguf_get_arr_n(struct gguf_context * ctx, int i) {
|
||||
return ctx->kv[i].value.arr.n;
|
||||
}
|
||||
|
||||
uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int i) {
|
||||
uint8_t gguf_get_val_u8(struct gguf_context * ctx, int i) {
|
||||
return ctx->kv[i].value.uint8;
|
||||
}
|
||||
|
||||
int8_t gguf_get_val_i8(const struct gguf_context * ctx, int i) {
|
||||
int8_t gguf_get_val_i8(struct gguf_context * ctx, int i) {
|
||||
return ctx->kv[i].value.int8;
|
||||
}
|
||||
|
||||
uint16_t gguf_get_val_u16(const struct gguf_context * ctx, int i) {
|
||||
uint16_t gguf_get_val_u16(struct gguf_context * ctx, int i) {
|
||||
return ctx->kv[i].value.uint16;
|
||||
}
|
||||
|
||||
int16_t gguf_get_val_i16(const struct gguf_context * ctx, int i) {
|
||||
int16_t gguf_get_val_i16(struct gguf_context * ctx, int i) {
|
||||
return ctx->kv[i].value.int16;
|
||||
}
|
||||
|
||||
uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int i) {
|
||||
uint32_t gguf_get_val_u32(struct gguf_context * ctx, int i) {
|
||||
return ctx->kv[i].value.uint32;
|
||||
}
|
||||
|
||||
int32_t gguf_get_val_i32(const struct gguf_context * ctx, int i) {
|
||||
int32_t gguf_get_val_i32(struct gguf_context * ctx, int i) {
|
||||
return ctx->kv[i].value.int32;
|
||||
}
|
||||
|
||||
float gguf_get_val_f32(const struct gguf_context * ctx, int i) {
|
||||
float gguf_get_val_f32(struct gguf_context * ctx, int i) {
|
||||
return ctx->kv[i].value.float32;
|
||||
}
|
||||
|
||||
uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int i) {
|
||||
uint64_t gguf_get_val_u64(struct gguf_context * ctx, int i) {
|
||||
return ctx->kv[i].value.uint64;
|
||||
}
|
||||
|
||||
int64_t gguf_get_val_i64(const struct gguf_context * ctx, int i) {
|
||||
int64_t gguf_get_val_i64(struct gguf_context * ctx, int i) {
|
||||
return ctx->kv[i].value.int64;
|
||||
}
|
||||
|
||||
double gguf_get_val_f64(const struct gguf_context * ctx, int i) {
|
||||
double gguf_get_val_f64(struct gguf_context * ctx, int i) {
|
||||
return ctx->kv[i].value.float64;
|
||||
}
|
||||
|
||||
bool gguf_get_val_bool(const struct gguf_context * ctx, int i) {
|
||||
bool gguf_get_val_bool(struct gguf_context * ctx, int i) {
|
||||
return ctx->kv[i].value.bool_;
|
||||
}
|
||||
|
||||
const char * gguf_get_val_str (const struct gguf_context * ctx, int i) {
|
||||
const char * gguf_get_val_str (struct gguf_context * ctx, int i) {
|
||||
return ctx->kv[i].value.str.data;
|
||||
}
|
||||
|
||||
int gguf_get_n_tensors(const struct gguf_context * ctx) {
|
||||
int gguf_get_n_tensors(struct gguf_context * ctx) {
|
||||
return ctx->header.n_tensors;
|
||||
}
|
||||
|
||||
int gguf_find_tensor(const struct gguf_context * ctx, const char * name) {
|
||||
int gguf_find_tensor(struct gguf_context * ctx, const char * name) {
|
||||
// return -1 if tensor not found
|
||||
int tensorfound = -1;
|
||||
|
||||
@ -20249,11 +20249,11 @@ int gguf_find_tensor(const struct gguf_context * ctx, const char * name) {
|
||||
return tensorfound;
|
||||
}
|
||||
|
||||
size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int i) {
|
||||
size_t gguf_get_tensor_offset(struct gguf_context * ctx, int i) {
|
||||
return ctx->infos[i].offset;
|
||||
}
|
||||
|
||||
char * gguf_get_tensor_name(const struct gguf_context * ctx, int i) {
|
||||
char * gguf_get_tensor_name(struct gguf_context * ctx, int i) {
|
||||
return ctx->infos[i].name.data;
|
||||
}
|
||||
|
||||
@ -20536,7 +20536,7 @@ static void gguf_bwrite_el(struct gguf_buf * buf, const void * val, size_t el_si
|
||||
buf->offset += el_size;
|
||||
}
|
||||
|
||||
static void gguf_write_to_buf(const struct gguf_context * ctx, struct gguf_buf * buf, bool only_meta) {
|
||||
static void gguf_write_to_buf(struct gguf_context * ctx, struct gguf_buf * buf, bool only_meta) {
|
||||
// write header
|
||||
gguf_bwrite_el(buf, &ctx->header.magic, sizeof(ctx->header.magic));
|
||||
gguf_bwrite_el(buf, &ctx->header.version, sizeof(ctx->header.version));
|
||||
@ -20651,7 +20651,7 @@ static void gguf_write_to_buf(const struct gguf_context * ctx, struct gguf_buf *
|
||||
}
|
||||
}
|
||||
|
||||
void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta) {
|
||||
void gguf_write_to_file(struct gguf_context * ctx, const char * fname, bool only_meta) {
|
||||
FILE * file = fopen(fname, "wb");
|
||||
if (!file) {
|
||||
GGML_ASSERT(false && "failed to open file for writing");
|
||||
@ -20668,7 +20668,7 @@ void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, boo
|
||||
fclose(file);
|
||||
}
|
||||
|
||||
size_t gguf_get_meta_size(const struct gguf_context * ctx) {
|
||||
size_t gguf_get_meta_size(struct gguf_context * ctx) {
|
||||
// no allocs - only compute size
|
||||
struct gguf_buf buf = gguf_buf_init(0);
|
||||
|
||||
@ -20677,7 +20677,7 @@ size_t gguf_get_meta_size(const struct gguf_context * ctx) {
|
||||
return buf.offset;
|
||||
}
|
||||
|
||||
void gguf_get_meta_data(const struct gguf_context * ctx, void * data) {
|
||||
void gguf_get_meta_data(struct gguf_context * ctx, void * data) {
|
||||
struct gguf_buf buf = gguf_buf_init(16*1024);
|
||||
|
||||
gguf_write_to_buf(ctx, &buf, true);
|
||||
@ -20753,14 +20753,6 @@ int ggml_cpu_has_arm_fma(void) {
|
||||
#endif
|
||||
}
|
||||
|
||||
int ggml_cpu_has_metal(void) {
|
||||
#if defined(GGML_USE_METAL)
|
||||
return 1;
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
int ggml_cpu_has_f16c(void) {
|
||||
#if defined(__F16C__)
|
||||
return 1;
|
||||
|
72
ggml.h
72
ggml.h
@ -195,14 +195,6 @@
|
||||
# define GGML_DEPRECATED(func, hint) func
|
||||
#endif
|
||||
|
||||
#ifndef __GNUC__
|
||||
# define GGML_ATTRIBUTE_FORMAT(...)
|
||||
#elif defined(__MINGW32__)
|
||||
# define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
|
||||
#else
|
||||
# define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
|
||||
#endif
|
||||
|
||||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include <stdbool.h>
|
||||
@ -693,7 +685,6 @@ extern "C" {
|
||||
|
||||
GGML_API const char * ggml_get_name (const struct ggml_tensor * tensor);
|
||||
GGML_API struct ggml_tensor * ggml_set_name ( struct ggml_tensor * tensor, const char * name);
|
||||
GGML_ATTRIBUTE_FORMAT(2, 3)
|
||||
GGML_API struct ggml_tensor * ggml_format_name( struct ggml_tensor * tensor, const char * fmt, ...);
|
||||
|
||||
//
|
||||
@ -1875,39 +1866,39 @@ extern "C" {
|
||||
|
||||
GGML_API const char * gguf_type_name(enum gguf_type type);
|
||||
|
||||
GGML_API int gguf_get_version (const struct gguf_context * ctx);
|
||||
GGML_API size_t gguf_get_alignment (const struct gguf_context * ctx);
|
||||
GGML_API size_t gguf_get_data_offset(const struct gguf_context * ctx);
|
||||
GGML_API void * gguf_get_data (const struct gguf_context * ctx);
|
||||
GGML_API int gguf_get_version (struct gguf_context * ctx);
|
||||
GGML_API size_t gguf_get_alignment (struct gguf_context * ctx);
|
||||
GGML_API size_t gguf_get_data_offset(struct gguf_context * ctx);
|
||||
GGML_API void * gguf_get_data (struct gguf_context * ctx);
|
||||
|
||||
GGML_API int gguf_get_n_kv(const struct gguf_context * ctx);
|
||||
GGML_API int gguf_find_key(const struct gguf_context * ctx, const char * key);
|
||||
GGML_API const char * gguf_get_key (const struct gguf_context * ctx, int i);
|
||||
GGML_API int gguf_get_n_kv(struct gguf_context * ctx);
|
||||
GGML_API int gguf_find_key(struct gguf_context * ctx, const char * key);
|
||||
GGML_API const char * gguf_get_key (struct gguf_context * ctx, int i);
|
||||
|
||||
GGML_API enum gguf_type gguf_get_kv_type (const struct gguf_context * ctx, int i);
|
||||
GGML_API enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int i);
|
||||
GGML_API enum gguf_type gguf_get_kv_type (struct gguf_context * ctx, int i);
|
||||
GGML_API enum gguf_type gguf_get_arr_type(struct gguf_context * ctx, int i);
|
||||
|
||||
// results are undefined if the wrong type is used for the key
|
||||
GGML_API uint8_t gguf_get_val_u8 (const struct gguf_context * ctx, int i);
|
||||
GGML_API int8_t gguf_get_val_i8 (const struct gguf_context * ctx, int i);
|
||||
GGML_API uint16_t gguf_get_val_u16 (const struct gguf_context * ctx, int i);
|
||||
GGML_API int16_t gguf_get_val_i16 (const struct gguf_context * ctx, int i);
|
||||
GGML_API uint32_t gguf_get_val_u32 (const struct gguf_context * ctx, int i);
|
||||
GGML_API int32_t gguf_get_val_i32 (const struct gguf_context * ctx, int i);
|
||||
GGML_API float gguf_get_val_f32 (const struct gguf_context * ctx, int i);
|
||||
GGML_API uint64_t gguf_get_val_u64 (const struct gguf_context * ctx, int i);
|
||||
GGML_API int64_t gguf_get_val_i64 (const struct gguf_context * ctx, int i);
|
||||
GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int i);
|
||||
GGML_API bool gguf_get_val_bool(const struct gguf_context * ctx, int i);
|
||||
GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int i);
|
||||
GGML_API int gguf_get_arr_n (const struct gguf_context * ctx, int i);
|
||||
GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int i);
|
||||
GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i);
|
||||
GGML_API uint8_t gguf_get_val_u8 (struct gguf_context * ctx, int i);
|
||||
GGML_API int8_t gguf_get_val_i8 (struct gguf_context * ctx, int i);
|
||||
GGML_API uint16_t gguf_get_val_u16 (struct gguf_context * ctx, int i);
|
||||
GGML_API int16_t gguf_get_val_i16 (struct gguf_context * ctx, int i);
|
||||
GGML_API uint32_t gguf_get_val_u32 (struct gguf_context * ctx, int i);
|
||||
GGML_API int32_t gguf_get_val_i32 (struct gguf_context * ctx, int i);
|
||||
GGML_API float gguf_get_val_f32 (struct gguf_context * ctx, int i);
|
||||
GGML_API uint64_t gguf_get_val_u64 (struct gguf_context * ctx, int i);
|
||||
GGML_API int64_t gguf_get_val_i64 (struct gguf_context * ctx, int i);
|
||||
GGML_API double gguf_get_val_f64 (struct gguf_context * ctx, int i);
|
||||
GGML_API bool gguf_get_val_bool(struct gguf_context * ctx, int i);
|
||||
GGML_API const char * gguf_get_val_str (struct gguf_context * ctx, int i);
|
||||
GGML_API int gguf_get_arr_n (struct gguf_context * ctx, int i);
|
||||
GGML_API const void * gguf_get_arr_data(struct gguf_context * ctx, int i);
|
||||
GGML_API const char * gguf_get_arr_str (struct gguf_context * ctx, int key_id, int i);
|
||||
|
||||
GGML_API int gguf_get_n_tensors (const struct gguf_context * ctx);
|
||||
GGML_API int gguf_find_tensor (const struct gguf_context * ctx, const char * name);
|
||||
GGML_API size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int i);
|
||||
GGML_API char * gguf_get_tensor_name (const struct gguf_context * ctx, int i);
|
||||
GGML_API int gguf_get_n_tensors (struct gguf_context * ctx);
|
||||
GGML_API int gguf_find_tensor (struct gguf_context * ctx, const char * name);
|
||||
GGML_API size_t gguf_get_tensor_offset(struct gguf_context * ctx, int i);
|
||||
GGML_API char * gguf_get_tensor_name (struct gguf_context * ctx, int i);
|
||||
|
||||
// overrides existing values or adds a new one
|
||||
GGML_API void gguf_set_val_u8 (struct gguf_context * ctx, const char * key, uint8_t val);
|
||||
@ -1952,11 +1943,11 @@ extern "C" {
|
||||
//
|
||||
|
||||
// write the entire context to a binary file
|
||||
GGML_API void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta);
|
||||
GGML_API void gguf_write_to_file(struct gguf_context * ctx, const char * fname, bool only_meta);
|
||||
|
||||
// get the size in bytes of the meta data (header, kv pairs, tensor info) including padding
|
||||
GGML_API size_t gguf_get_meta_size(const struct gguf_context * ctx);
|
||||
GGML_API void gguf_get_meta_data(const struct gguf_context * ctx, void * data);
|
||||
GGML_API size_t gguf_get_meta_size(struct gguf_context * ctx);
|
||||
GGML_API void gguf_get_meta_data(struct gguf_context * ctx, void * data);
|
||||
|
||||
//
|
||||
// system info
|
||||
@ -1970,7 +1961,6 @@ extern "C" {
|
||||
GGML_API int ggml_cpu_has_fma (void);
|
||||
GGML_API int ggml_cpu_has_neon (void);
|
||||
GGML_API int ggml_cpu_has_arm_fma (void);
|
||||
GGML_API int ggml_cpu_has_metal (void);
|
||||
GGML_API int ggml_cpu_has_f16c (void);
|
||||
GGML_API int ggml_cpu_has_fp16_va (void);
|
||||
GGML_API int ggml_cpu_has_wasm_simd (void);
|
||||
|
@ -1,117 +0,0 @@
|
||||
import argparse
|
||||
import importlib.util
|
||||
|
||||
spec = importlib.util.spec_from_file_location('whisper_to_coreml', 'models/convert-whisper-to-coreml.py')
|
||||
whisper_to_coreml = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(whisper_to_coreml)
|
||||
|
||||
from whisper import load_model
|
||||
|
||||
from copy import deepcopy
|
||||
import torch
|
||||
from transformers import WhisperForConditionalGeneration
|
||||
from huggingface_hub import metadata_update
|
||||
|
||||
# https://github.com/bayartsogt-ya/whisper-multiple-hf-datasets/blob/main/src/multiple_datasets/hub_default_utils.py
|
||||
WHISPER_MAPPING = {
|
||||
"layers": "blocks",
|
||||
"fc1": "mlp.0",
|
||||
"fc2": "mlp.2",
|
||||
"final_layer_norm": "mlp_ln",
|
||||
"layers": "blocks",
|
||||
".self_attn.q_proj": ".attn.query",
|
||||
".self_attn.k_proj": ".attn.key",
|
||||
".self_attn.v_proj": ".attn.value",
|
||||
".self_attn_layer_norm": ".attn_ln",
|
||||
".self_attn.out_proj": ".attn.out",
|
||||
".encoder_attn.q_proj": ".cross_attn.query",
|
||||
".encoder_attn.k_proj": ".cross_attn.key",
|
||||
".encoder_attn.v_proj": ".cross_attn.value",
|
||||
".encoder_attn_layer_norm": ".cross_attn_ln",
|
||||
".encoder_attn.out_proj": ".cross_attn.out",
|
||||
"decoder.layer_norm.": "decoder.ln.",
|
||||
"encoder.layer_norm.": "encoder.ln_post.",
|
||||
"embed_tokens": "token_embedding",
|
||||
"encoder.embed_positions.weight": "encoder.positional_embedding",
|
||||
"decoder.embed_positions.weight": "decoder.positional_embedding",
|
||||
"layer_norm": "ln_post",
|
||||
}
|
||||
|
||||
# https://github.com/bayartsogt-ya/whisper-multiple-hf-datasets/blob/main/src/multiple_datasets/hub_default_utils.py
|
||||
def rename_keys(s_dict):
|
||||
keys = list(s_dict.keys())
|
||||
for key in keys:
|
||||
new_key = key
|
||||
for k, v in WHISPER_MAPPING.items():
|
||||
if k in key:
|
||||
new_key = new_key.replace(k, v)
|
||||
|
||||
print(f"{key} -> {new_key}")
|
||||
|
||||
s_dict[new_key] = s_dict.pop(key)
|
||||
return s_dict
|
||||
|
||||
# https://github.com/bayartsogt-ya/whisper-multiple-hf-datasets/blob/main/src/multiple_datasets/hub_default_utils.py
|
||||
def convert_hf_whisper(hf_model_name_or_path: str, whisper_state_path: str):
|
||||
transformer_model = WhisperForConditionalGeneration.from_pretrained(hf_model_name_or_path)
|
||||
config = transformer_model.config
|
||||
|
||||
# first build dims
|
||||
dims = {
|
||||
'n_mels': config.num_mel_bins,
|
||||
'n_vocab': config.vocab_size,
|
||||
'n_audio_ctx': config.max_source_positions,
|
||||
'n_audio_state': config.d_model,
|
||||
'n_audio_head': config.encoder_attention_heads,
|
||||
'n_audio_layer': config.encoder_layers,
|
||||
'n_text_ctx': config.max_target_positions,
|
||||
'n_text_state': config.d_model,
|
||||
'n_text_head': config.decoder_attention_heads,
|
||||
'n_text_layer': config.decoder_layers
|
||||
}
|
||||
|
||||
state_dict = deepcopy(transformer_model.model.state_dict())
|
||||
state_dict = rename_keys(state_dict)
|
||||
|
||||
torch.save({"dims": dims, "model_state_dict": state_dict}, whisper_state_path)
|
||||
|
||||
# Ported from models/convert-whisper-to-coreml.py
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model-name", type=str, help="name of model to convert (e.g. tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large, large-v1)", required=True)
|
||||
parser.add_argument("--model-path", type=str, help="path to the model (e.g. if published on HuggingFace: Oblivion208/whisper-tiny-cantonese)", required=True)
|
||||
parser.add_argument("--encoder-only", type=bool, help="only convert encoder", default=False)
|
||||
parser.add_argument("--quantize", type=bool, help="quantize weights to F16", default=False)
|
||||
parser.add_argument("--optimize-ane", type=bool, help="optimize for ANE execution (currently broken)", default=False)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.model_name not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large", "large-v1"]:
|
||||
raise ValueError("Invalid model name")
|
||||
|
||||
pt_target_path = f"models/hf-{args.model_name}.pt"
|
||||
convert_hf_whisper(args.model_path, pt_target_path)
|
||||
|
||||
whisper = load_model(pt_target_path).cpu()
|
||||
hparams = whisper.dims
|
||||
print(hparams)
|
||||
|
||||
if args.optimize_ane:
|
||||
whisperANE = whisper_to_coreml.WhisperANE(hparams).eval()
|
||||
whisperANE.load_state_dict(whisper.state_dict())
|
||||
|
||||
encoder = whisperANE.encoder
|
||||
decoder = whisperANE.decoder
|
||||
else:
|
||||
encoder = whisper.encoder
|
||||
decoder = whisper.decoder
|
||||
|
||||
# Convert encoder
|
||||
encoder = whisper_to_coreml.convert_encoder(hparams, encoder, quantize=args.quantize)
|
||||
encoder.save(f"models/coreml-encoder-{args.model_name}.mlpackage")
|
||||
|
||||
if args.encoder_only is False:
|
||||
# Convert decoder
|
||||
decoder = whisper_to_coreml.convert_decoder(hparams, decoder, quantize=args.quantize)
|
||||
decoder.save(f"models/coreml-decoder-{args.model_name}.mlpackage")
|
||||
|
||||
print("done converting")
|
@ -29,7 +29,7 @@ def convert_encoder(hparams, encoder, mname):
|
||||
|
||||
# use model optimizer to convert onnx to OpenVINO IR format
|
||||
encoder_model = mo.convert_model(onnx_path, compress_to_fp16=True)
|
||||
serialize(encoder_model, xml_path=os.path.join(os.path.dirname(__file__),"ggml-" + mname + "-encoder-openvino.xml"))
|
||||
serialize(encoder_model, xml_path='ggml-' + mname + '-encoder-openvino.xml')
|
||||
|
||||
#cleanup
|
||||
if os.path.isdir(onnx_folder):
|
||||
|
@ -40,7 +40,7 @@ if exist "ggml-%model%.bin" (
|
||||
goto :eof
|
||||
)
|
||||
|
||||
PowerShell -NoProfile -ExecutionPolicy Bypass -Command "Start-BitsTransfer -Source https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-%model%.bin -Destination ggml-%model%.bin"
|
||||
PowerShell -NoProfile -ExecutionPolicy Bypass -Command "Invoke-WebRequest -Uri https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-%model%.bin -OutFile ggml-%model%.bin"
|
||||
|
||||
if %ERRORLEVEL% neq 0 (
|
||||
echo Failed to download ggml model %model%
|
||||
|
@ -1,15 +1,11 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Usage: ./generate-coreml-model.sh <model-name>
|
||||
if [ $# -eq 0 ]; then
|
||||
echo "No model name supplied"
|
||||
echo "Usage for Whisper models: ./generate-coreml-model.sh <model-name>"
|
||||
echo "Usage for HuggingFace models: ./generate-coreml-model.sh -h5 <model-name> <model-path>"
|
||||
exit 1
|
||||
elif [[ "$1" == "-h5" && $# != 3 ]]; then
|
||||
echo "No model name and model path supplied for a HuggingFace model"
|
||||
echo "Usage for HuggingFace models: ./generate-coreml-model.sh -h5 <model-name> <model-path>"
|
||||
exit 1
|
||||
if [ $# -eq 0 ]
|
||||
then
|
||||
echo "No model name supplied"
|
||||
echo "Usage: ./generate-coreml-model.sh <model-name>"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
mname="$1"
|
||||
@ -17,14 +13,7 @@ mname="$1"
|
||||
wd=$(dirname "$0")
|
||||
cd "$wd/../"
|
||||
|
||||
if [[ $mname == "-h5" ]]; then
|
||||
mname="$2"
|
||||
mpath="$3"
|
||||
echo $mpath
|
||||
python3 models/convert-h5-to-coreml.py --model-name $mname --model-path $mpath --encoder-only True
|
||||
else
|
||||
python3 models/convert-whisper-to-coreml.py --model $mname --encoder-only True
|
||||
fi
|
||||
python3 models/convert-whisper-to-coreml.py --model $mname --encoder-only True
|
||||
|
||||
xcrun coremlc compile models/coreml-encoder-${mname}.mlpackage models/
|
||||
rm -rf models/ggml-${mname}-encoder.mlmodelc
|
||||
|
73
whisper.cpp
73
whisper.cpp
@ -125,17 +125,9 @@ static void byteswap_tensor(ggml_tensor * tensor) {
|
||||
// ggml helpers
|
||||
//
|
||||
|
||||
static void ggml_graph_compute_helper(
|
||||
std::vector<uint8_t> & buf,
|
||||
ggml_cgraph * graph,
|
||||
int n_threads,
|
||||
whisper_abort_callback abort_callback,
|
||||
void * abort_callback_data) {
|
||||
static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
|
||||
struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
|
||||
|
||||
plan.abort_callback = abort_callback;
|
||||
plan.abort_callback_data = abort_callback_data;
|
||||
|
||||
if (plan.work_size > 0) {
|
||||
buf.resize(plan.work_size);
|
||||
plan.work_data = buf.data();
|
||||
@ -1930,9 +1922,7 @@ static bool whisper_encode_internal(
|
||||
whisper_context & wctx,
|
||||
whisper_state & wstate,
|
||||
const int mel_offset,
|
||||
const int n_threads,
|
||||
whisper_abort_callback abort_callback,
|
||||
void * abort_callback_data) {
|
||||
const int n_threads) {
|
||||
const int64_t t_start_us = ggml_time_us();
|
||||
|
||||
// conv
|
||||
@ -1946,7 +1936,7 @@ static bool whisper_encode_internal(
|
||||
ggml_allocr_alloc_graph(alloc, gf);
|
||||
|
||||
if (!whisper_encode_external(wstate)) {
|
||||
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
|
||||
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
|
||||
}
|
||||
}
|
||||
|
||||
@ -1965,10 +1955,10 @@ static bool whisper_encode_internal(
|
||||
ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
|
||||
ggml_metal_graph_compute(wstate.ctx_metal, gf);
|
||||
} else {
|
||||
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
|
||||
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
|
||||
}
|
||||
#else
|
||||
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
|
||||
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
|
||||
#endif
|
||||
}
|
||||
|
||||
@ -1987,10 +1977,10 @@ static bool whisper_encode_internal(
|
||||
ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
|
||||
ggml_metal_graph_compute(wstate.ctx_metal, gf);
|
||||
} else {
|
||||
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
|
||||
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
|
||||
}
|
||||
#else
|
||||
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
|
||||
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
|
||||
#endif
|
||||
}
|
||||
|
||||
@ -2356,9 +2346,7 @@ static bool whisper_decode_internal(
|
||||
const whisper_token * tokens,
|
||||
const int n_tokens,
|
||||
const int n_past,
|
||||
const int n_threads,
|
||||
whisper_abort_callback abort_callback,
|
||||
void * abort_callback_data) {
|
||||
const int n_threads) {
|
||||
const int64_t t_start_us = ggml_time_us();
|
||||
|
||||
const auto & model = wctx.model;
|
||||
@ -2387,10 +2375,10 @@ static bool whisper_decode_internal(
|
||||
ggml_metal_set_n_cb (wstate.ctx_metal, n_threads);
|
||||
ggml_metal_graph_compute(wstate.ctx_metal, gf);
|
||||
} else {
|
||||
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
|
||||
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
|
||||
}
|
||||
#else
|
||||
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
|
||||
ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
|
||||
#endif
|
||||
}
|
||||
|
||||
@ -3302,7 +3290,7 @@ int whisper_set_mel(
|
||||
}
|
||||
|
||||
int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) {
|
||||
if (!whisper_encode_internal(*ctx, *state, offset, n_threads, nullptr, nullptr)) {
|
||||
if (!whisper_encode_internal(*ctx, *state, offset, n_threads)) {
|
||||
log("%s: failed to eval\n", __func__);
|
||||
return -1;
|
||||
}
|
||||
@ -3311,7 +3299,7 @@ int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state
|
||||
}
|
||||
|
||||
int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
|
||||
if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads, nullptr, nullptr)) {
|
||||
if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads)) {
|
||||
log("%s: failed to eval\n", __func__);
|
||||
return -1;
|
||||
}
|
||||
@ -3322,7 +3310,7 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
|
||||
int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
|
||||
const int selected_decoder_id = 0;
|
||||
|
||||
if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
|
||||
if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
|
||||
log("%s: failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
@ -3339,7 +3327,7 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
|
||||
if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
|
||||
log("%s: failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
@ -3681,7 +3669,6 @@ const char * whisper_print_system_info(void) {
|
||||
s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | ";
|
||||
s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
|
||||
s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | ";
|
||||
s += "METAL = " + std::to_string(ggml_cpu_has_metal()) + " | ";
|
||||
s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | ";
|
||||
s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
|
||||
s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
|
||||
@ -3773,9 +3760,6 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
||||
/*.encoder_begin_callback =*/ nullptr,
|
||||
/*.encoder_begin_callback_user_data =*/ nullptr,
|
||||
|
||||
/*.abort_callback =*/ nullptr,
|
||||
/*.abort_callback_user_data =*/ nullptr,
|
||||
|
||||
/*.logits_filter_callback =*/ nullptr,
|
||||
/*.logits_filter_callback_user_data =*/ nullptr,
|
||||
};
|
||||
@ -3942,7 +3926,6 @@ static void whisper_process_logits(
|
||||
// suppress task tokens
|
||||
logits[vocab.token_translate] = -INFINITY;
|
||||
logits[vocab.token_transcribe] = -INFINITY;
|
||||
logits[vocab.token_prev] = -INFINITY;
|
||||
|
||||
if (params.logits_filter_callback) {
|
||||
params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
|
||||
@ -4536,7 +4519,7 @@ int whisper_full_with_state(
|
||||
|
||||
// initial prompt
|
||||
if (!params.prompt_tokens && params.initial_prompt) {
|
||||
prompt_tokens.resize(2048);
|
||||
prompt_tokens.resize(1024);
|
||||
prompt_tokens.resize(whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size()));
|
||||
params.prompt_tokens = prompt_tokens.data();
|
||||
params.prompt_n_tokens = prompt_tokens.size();
|
||||
@ -4561,7 +4544,6 @@ int whisper_full_with_state(
|
||||
|
||||
// these tokens determine the task that will be performed
|
||||
std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
|
||||
|
||||
if (whisper_is_multilingual(ctx)) {
|
||||
const int lang_id = whisper_lang_id(params.language);
|
||||
state->lang_id = lang_id;
|
||||
@ -4573,17 +4555,6 @@ int whisper_full_with_state(
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
const bool is_distil = ctx->model.hparams.n_text_layer == 2;
|
||||
|
||||
// distilled models require the "no_timestamps" token
|
||||
// TODO: add input parameter (#1229)
|
||||
if (is_distil) {
|
||||
log("%s: using distilled model - forcing no_timestamps\n", __func__);
|
||||
prompt_init.push_back(whisper_token_not(ctx));
|
||||
}
|
||||
}
|
||||
|
||||
int seek = seek_start;
|
||||
|
||||
std::vector<whisper_token> prompt;
|
||||
@ -4622,7 +4593,7 @@ int whisper_full_with_state(
|
||||
}
|
||||
|
||||
// encode audio features starting at offset seek
|
||||
if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
|
||||
if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads)) {
|
||||
log("%s: failed to encode\n", __func__);
|
||||
return -6;
|
||||
}
|
||||
@ -4705,7 +4676,7 @@ int whisper_full_with_state(
|
||||
}
|
||||
WHISPER_PRINT_DEBUG("\n\n");
|
||||
|
||||
if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
|
||||
if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
|
||||
log("%s: failed to decode\n", __func__);
|
||||
return -7;
|
||||
}
|
||||
@ -4929,7 +4900,7 @@ int whisper_full_with_state(
|
||||
|
||||
//WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
|
||||
|
||||
if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
|
||||
if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) {
|
||||
log("%s: failed to decode\n", __func__);
|
||||
return -8;
|
||||
}
|
||||
@ -5298,10 +5269,6 @@ int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment)
|
||||
return ctx->state->result_all[i_segment].t1;
|
||||
}
|
||||
|
||||
bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment) {
|
||||
return state->result_all[i_segment].speaker_turn_next;
|
||||
}
|
||||
|
||||
bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment) {
|
||||
return ctx->state->result_all[i_segment].speaker_turn_next;
|
||||
}
|
||||
@ -5501,12 +5468,12 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
|
||||
double tsum = 0.0;
|
||||
|
||||
// heat-up
|
||||
ggml_graph_compute_helper(work, &gf, n_threads, nullptr , nullptr);
|
||||
ggml_graph_compute_helper(work, &gf, n_threads);
|
||||
|
||||
for (int i = 0; i < n_max; ++i) {
|
||||
const int64_t t0 = ggml_time_us();
|
||||
|
||||
ggml_graph_compute_helper(work, &gf, n_threads, nullptr, nullptr);
|
||||
ggml_graph_compute_helper(work, &gf, n_threads);
|
||||
|
||||
const int64_t t1 = ggml_time_us();
|
||||
|
||||
|
10
whisper.h
10
whisper.h
@ -334,11 +334,6 @@ extern "C" {
|
||||
// If it returns false, the computation is aborted
|
||||
typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, struct whisper_state * state, void * user_data);
|
||||
|
||||
// Abort callback
|
||||
// If not NULL, called before ggml computation
|
||||
// If it returns true, the computation is aborted
|
||||
typedef bool (*whisper_abort_callback)(void * user_data);
|
||||
|
||||
// Logits filter callback
|
||||
// Can be used to modify the logits before sampling
|
||||
// If not NULL, called after applying temperature to logits
|
||||
@ -433,10 +428,6 @@ extern "C" {
|
||||
whisper_encoder_begin_callback encoder_begin_callback;
|
||||
void * encoder_begin_callback_user_data;
|
||||
|
||||
// called each time before ggml computation starts
|
||||
whisper_abort_callback abort_callback;
|
||||
void * abort_callback_user_data;
|
||||
|
||||
// called by each decoder to filter obtained logits
|
||||
whisper_logits_filter_callback logits_filter_callback;
|
||||
void * logits_filter_callback_user_data;
|
||||
@ -494,7 +485,6 @@ extern "C" {
|
||||
|
||||
// Get whether the next segment is predicted as a speaker turn
|
||||
WHISPER_API bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment);
|
||||
WHISPER_API bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment);
|
||||
|
||||
// Get the text of the specified segment
|
||||
WHISPER_API const char * whisper_full_get_segment_text (struct whisper_context * ctx, int i_segment);
|
||||
|
Reference in New Issue
Block a user