Compare commits

...

19 Commits

Author SHA1 Message Date
13c5446759 Update ggml-cuda/mmvq.cu
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
2024-06-11 17:37:32 +03:00
9df6298a91 cuda : fix bounds check for src0 rows in MMVQ kernel 2024-06-11 11:30:12 +03:00
20c542c713 whisper : auto-grow working areas for mel_calc_cuda (#2227)
* whisper : auto-grow working areas for mel_calc_cuda, fixes #2226

* whisper : only calculate mel spectrogram on GPU if audio is <= 5 min
2024-06-10 21:51:32 +03:00
c2bdb960cd whisper : free whisper_mel instances (#2220) 2024-06-10 11:00:15 +03:00
87acd6d629 whisper : whisper_state/backend fixes (#2217)
* whisper : fixes

* ci : WHISPER_CUBLAS -> WHISPER_CUDA
2024-06-06 18:51:36 +03:00
f842d31171 whisper : calculate mel spectrogram directly into a ggml_tensor (#2208)
* whisper : calculate mel spectrogram directly into a ggml_tensor

* whisper : remove unused temp buffer from state

* whisper : fix not initializing wstate.embd_enc
2024-06-06 16:20:46 +03:00
ffef323c4c whisper : add CUDA-specific computation mel spectrograms (#2206)
* whisper : use polymorphic class to calculate mel spectrogram

* whisper : add cuda-specific mel spectrogram calculation

* whisper : conditionally compile cufftGetErrorString to avoid warnings

* build : add new files to makefile

* ruby : add new files to conf script

* build : fix typo in makefile

* whisper : suppress cub warning for deprecated C++ std in whisper-mel-cuda
2024-06-04 09:32:23 +03:00
af5833e298 whisper : remove speed_up and phase_vocoder* functions (#2198)
* whisper : fix cast warning

* whisper : remove phase_vocoder functions, ref #2195

* whisper : remove speed_up from whisper_full_params, closes #2195
2024-05-31 11:37:29 +03:00
b87494bb8f readme : add conan badge (#2196)
* Add conan badge

* Fix markdown formating
2024-05-30 15:43:28 +03:00
ad130431aa readme : add install instructions for Conan (#2189) 2024-05-30 15:06:15 +03:00
e130b66642 whisper: use global cache for sin/cos vals and Hann window (#2194)
- also rename Hanning to Hann as it's named after Julius von Hann
 as per Wikipedia
2024-05-29 19:09:21 +03:00
c7b6988678 release : v1.6.2 2024-05-27 10:35:09 +03:00
05042a782d Revert "whisper : remove extra backend instance (huh?)" (#2182)
This reverts commit 4caa64b73e.
2024-05-27 10:20:25 +03:00
a7dc2aab16 server : fix typo (#2181)
A simple comment typo, PR can be dismissed
2024-05-25 10:46:22 +03:00
22d46b7ba4 ruby : update bindings (#2154)
* update library files

* update whispercpp

* not needed for gem
2024-05-22 23:02:52 +03:00
c10db6ea28 release : v1.6.1 2024-05-21 18:44:37 +03:00
1b51fdf170 examples : add support for decoding input with ffmpeg (Linux) (#2133)
- search for ffmpeg libs/headers at cmake time
- added ffmpeg-transcode.cpp into libcommon if ffmpeg on
- hooked ffmpeg trancoding in common read_wav(...)
- passed test:
./main -m ggml-base.en.bin -f samples/jfk.mp3
2024-05-21 18:31:41 +03:00
adee3f9c1f node : add flash_attn param (#2170) 2024-05-20 09:08:48 +03:00
4798be1f9a ci: Update build.yml to suppress warnings about node.js versions (#2166)
* Update actions to suppress warnings about old node.js

https://github.blog/changelog/2023-09-22-github-actions-transitioning-from-node-16-to-node-20/

* Update actions/upload-artifact, specify android cmdline-tools-version

* Use java 20

gradle 8.1 complains against 21
https://docs.gradle.org/current/userguide/compatibility.html
2024-05-19 11:49:26 +03:00
56 changed files with 11668 additions and 1932 deletions

View File

@ -15,10 +15,10 @@ jobs:
steps: steps:
- name: Clone - name: Clone
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Set up QEMU - name: Set up QEMU
uses: docker/setup-qemu-action@v2 uses: docker/setup-qemu-action@v3
- name: Build ${{ matrix.arch }} - name: Build ${{ matrix.arch }}
run: | run: |
@ -36,7 +36,7 @@ jobs:
steps: steps:
- name: Clone - name: Clone
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Dependencies - name: Dependencies
run: | run: |
@ -53,10 +53,10 @@ jobs:
steps: steps:
- name: Clone - name: Clone
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Build - name: Build
uses: cross-platform-actions/action@v0.15.0 uses: cross-platform-actions/action@v0.24.0
with: with:
operating_system: freebsd operating_system: freebsd
version: '13.2' version: '13.2'
@ -77,10 +77,10 @@ jobs:
steps: steps:
- name: Clone - name: Clone
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Set up QEMU - name: Set up QEMU
uses: docker/setup-qemu-action@v2 uses: docker/setup-qemu-action@v3
- name: Build ${{ matrix.arch }} - name: Build ${{ matrix.arch }}
run: | run: |
@ -105,10 +105,10 @@ jobs:
steps: steps:
- name: Clone - name: Clone
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Set up QEMU - name: Set up QEMU
uses: docker/setup-qemu-action@v2 uses: docker/setup-qemu-action@v3
- name: Build ${{ matrix.arch }} - name: Build ${{ matrix.arch }}
run: | run: |
@ -133,10 +133,10 @@ jobs:
steps: steps:
- name: Clone - name: Clone
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Set up QEMU - name: Set up QEMU
uses: docker/setup-qemu-action@v2 uses: docker/setup-qemu-action@v3
- name: Build ${{ matrix.arch }} - name: Build ${{ matrix.arch }}
run: | run: |
@ -165,7 +165,7 @@ jobs:
steps: steps:
- name: Clone - name: Clone
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: add oneAPI to apt - name: add oneAPI to apt
shell: bash shell: bash
@ -189,7 +189,7 @@ jobs:
- name: Clone - name: Clone
id: checkout id: checkout
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Build - name: Build
id: cmake_build id: cmake_build
@ -215,7 +215,7 @@ jobs:
steps: steps:
- name: Clone - name: Clone
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: add oneAPI to apt - name: add oneAPI to apt
shell: bash shell: bash
@ -239,7 +239,7 @@ jobs:
- name: Clone - name: Clone
id: checkout id: checkout
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Build - name: Build
id: cmake_build id: cmake_build
@ -262,7 +262,7 @@ jobs:
steps: steps:
- name: Clone - name: Clone
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Setup ${{ matrix.sys }} - name: Setup ${{ matrix.sys }}
uses: msys2/setup-msys2@v2 uses: msys2/setup-msys2@v2
@ -328,10 +328,10 @@ jobs:
steps: steps:
- name: Clone - name: Clone
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Add msbuild to PATH - name: Add msbuild to PATH
uses: microsoft/setup-msbuild@v1 uses: microsoft/setup-msbuild@v2
- name: Fetch SDL2 and set SDL2_DIR - name: Fetch SDL2 and set SDL2_DIR
if: matrix.sdl2 == 'ON' if: matrix.sdl2 == 'ON'
@ -356,14 +356,14 @@ jobs:
run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }} run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
- name: Upload dll - name: Upload dll
uses: actions/upload-artifact@v3 uses: actions/upload-artifact@v4
with: with:
name: ${{ matrix.jnaPath }}_whisper.dll name: ${{ matrix.jnaPath }}_whisper.dll
path: build/bin/${{ matrix.build }}/whisper.dll path: build/bin/${{ matrix.build }}/whisper.dll
- name: Upload binaries - name: Upload binaries
if: matrix.sdl2 == 'ON' if: matrix.sdl2 == 'ON'
uses: actions/upload-artifact@v1 uses: actions/upload-artifact@v4
with: with:
name: whisper-bin-${{ matrix.arch }} name: whisper-bin-${{ matrix.arch }}
path: build/bin/${{ matrix.build }} path: build/bin/${{ matrix.build }}
@ -392,10 +392,10 @@ jobs:
steps: steps:
- name: Clone - name: Clone
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Add msbuild to PATH - name: Add msbuild to PATH
uses: microsoft/setup-msbuild@v1 uses: microsoft/setup-msbuild@v2
- name: Fetch OpenBLAS - name: Fetch OpenBLAS
if: matrix.blas == 'ON' if: matrix.blas == 'ON'
@ -453,7 +453,7 @@ jobs:
- name: Upload binaries - name: Upload binaries
if: matrix.blas == 'ON' && matrix.sdl2 == 'ON' if: matrix.blas == 'ON' && matrix.sdl2 == 'ON'
uses: actions/upload-artifact@v1 uses: actions/upload-artifact@v4
with: with:
name: whisper-blas${{ matrix.clblast == 'ON' && '-clblast' || ''}}-bin-${{ matrix.arch }} name: whisper-blas${{ matrix.clblast == 'ON' && '-clblast' || ''}}-bin-${{ matrix.arch }}
path: build/bin/${{ matrix.build }} path: build/bin/${{ matrix.build }}
@ -476,14 +476,14 @@ jobs:
steps: steps:
- name: Clone - name: Clone
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Add msbuild to PATH - name: Add msbuild to PATH
uses: microsoft/setup-msbuild@v1 uses: microsoft/setup-msbuild@v2
- name: Install CUDA Toolkit - name: Install CUDA Toolkit
id: cuda-toolkit id: cuda-toolkit
uses: Jimver/cuda-toolkit@v0.2.11 uses: Jimver/cuda-toolkit@v0.2.15
with: with:
cuda: '${{ matrix.cuda-toolkit }}' cuda: '${{ matrix.cuda-toolkit }}'
@ -498,7 +498,7 @@ jobs:
run: > run: >
cmake -S . -B ./build -A ${{ matrix.arch }} cmake -S . -B ./build -A ${{ matrix.arch }}
-DCMAKE_BUILD_TYPE=${{ matrix.build }} -DCMAKE_BUILD_TYPE=${{ matrix.build }}
-DWHISPER_CUBLAS=${{ matrix.cublas }} -DWHISPER_CUDA=${{ matrix.cublas }}
-DWHISPER_SDL2=${{ matrix.sdl2 }} -DWHISPER_SDL2=${{ matrix.sdl2 }}
- name: Build ${{ matrix.cuda-toolkit }} - name: Build ${{ matrix.cuda-toolkit }}
@ -519,7 +519,7 @@ jobs:
- name: Upload binaries - name: Upload binaries
if: matrix.sdl2 == 'ON' if: matrix.sdl2 == 'ON'
uses: actions/upload-artifact@v1 uses: actions/upload-artifact@v4
with: with:
name: whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }} name: whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}
path: build/bin/${{ matrix.build }} path: build/bin/${{ matrix.build }}
@ -533,10 +533,10 @@ jobs:
steps: steps:
- name: Clone - name: Clone
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Setup emsdk - name: Setup emsdk
uses: mymindstorm/setup-emsdk@v12 uses: mymindstorm/setup-emsdk@v14
- name: Verify - name: Verify
run: emcc -v run: emcc -v
@ -555,7 +555,7 @@ jobs:
steps: steps:
- name: Clone - name: Clone
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Configure - name: Configure
run: | run: |
@ -573,24 +573,24 @@ jobs:
steps: steps:
- name: Clone - name: Clone
uses: actions/checkout@v3 uses: actions/checkout@v4
with: with:
path: whisper path: whisper
- name: Clone - name: Clone
uses: actions/checkout@v3 uses: actions/checkout@v4
with: with:
repository: ggerganov/ggml repository: ggerganov/ggml
path: ggml path: ggml
- name: Install Java - name: Install Java
uses: actions/setup-java@v3 uses: actions/setup-java@v4
with: with:
distribution: zulu distribution: zulu
java-version: 17 java-version: 21
- name: Setup Android SDK - name: Setup Android SDK
uses: android-actions/setup-android@v2 uses: android-actions/setup-android@v3
- name: Build - name: Build
run: | run: |
@ -608,20 +608,19 @@ jobs:
steps: steps:
- name: Clone - name: Clone
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: set up JDK 11 - name: set up JDK 11
uses: actions/setup-java@v3 uses: actions/setup-java@v4
with: with:
java-version: '11' java-version: '11'
distribution: 'temurin' distribution: 'temurin'
cache: gradle cache: gradle
- name: Setup Android SDK - name: Setup Android SDK
uses: android-actions/setup-android@v2 uses: android-actions/setup-android@v3
with: with:
api-level: 30 cmdline-tools-version: 9.0
build-tools-version: 30.0.3
- name: Build - name: Build
run: | run: |
@ -633,15 +632,16 @@ jobs:
needs: [ 'windows' ] needs: [ 'windows' ]
runs-on: windows-latest runs-on: windows-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
- name: Install Java - name: Install Java
uses: actions/setup-java@v1 uses: actions/setup-java@v4
with: with:
java-version: 17 distribution: zulu
java-version: 20
- name: Download Windows lib - name: Download Windows lib
uses: actions/download-artifact@v3 uses: actions/download-artifact@v4
with: with:
name: win32-x86-64_whisper.dll name: win32-x86-64_whisper.dll
path: bindings/java/build/generated/resources/main/win32-x86-64 path: bindings/java/build/generated/resources/main/win32-x86-64
@ -654,7 +654,7 @@ jobs:
./gradlew build ./gradlew build
- name: Upload jar - name: Upload jar
uses: actions/upload-artifact@v3 uses: actions/upload-artifact@v4
with: with:
name: whispercpp.jar name: whispercpp.jar
path: bindings/java/build/libs/whispercpp-*.jar path: bindings/java/build/libs/whispercpp-*.jar
@ -676,7 +676,7 @@ jobs:
steps: steps:
- name: Clone - name: Clone
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Test quantize - name: Test quantize
run: | run: |

View File

@ -3,7 +3,7 @@ cmake_minimum_required (VERSION 3.5)
# Allow for the creation of solution folders. # Allow for the creation of solution folders.
set_property(GLOBAL PROPERTY USE_FOLDERS ON) set_property(GLOBAL PROPERTY USE_FOLDERS ON)
project(whisper.cpp VERSION 1.6.0) project(whisper.cpp VERSION 1.6.2)
set(SOVERSION 1) set(SOVERSION 1)
# Add path to modules # Add path to modules
@ -59,6 +59,10 @@ option(WHISPER_BUILD_EXAMPLES "whisper: build examples" ${WHISPER_STANDA
option(WHISPER_SDL2 "whisper: support for libSDL2" OFF) option(WHISPER_SDL2 "whisper: support for libSDL2" OFF)
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
option(WHISPER_FFMPEG "whisper: support building and linking with ffmpeg libs (avcodec, swresample, ...)" OFF)
endif()
option(WHISPER_NO_AVX "whisper: disable AVX" OFF) option(WHISPER_NO_AVX "whisper: disable AVX" OFF)
option(WHISPER_NO_AVX2 "whisper: disable AVX2" OFF) option(WHISPER_NO_AVX2 "whisper: disable AVX2" OFF)
option(WHISPER_NO_AVX512 "whisper: disable AVX512" ON) option(WHISPER_NO_AVX512 "whisper: disable AVX512" ON)
@ -125,6 +129,26 @@ else()
set(CMAKE_CXX_STANDARD 11) set(CMAKE_CXX_STANDARD 11)
endif() endif()
if (WHISPER_FFMPEG)
# As of cmake 3.27, there is no official cmake support for FindFFmpeg.
# Consequnelty we added a FindFFmpeg.cmake script the cmake subfolder:
# whisper.cpp does not need the full ffmpeg libs, just AVFORMAT AVCODEC AVUTIL SWRESAMPLE
# libswresample performs highly optimized audio resampling, rematrixing and sample format conversion operations
# libavcodec provides a generic encoding/decoding framework and contains multiple decoders and encoders for audio, video and subtitle streams, and several bitstream filters.
# libavformat provides a generic framework for multiplexing and demultiplexing (muxing and demuxing) audio, video and subtitle streams.
find_package(FFmpeg REQUIRED)
if (NOT ${FFMPEG_FOUND})
message(FATAL_ERROR "Cannot find ffmpeg libs/headers")
endif()
message(STATUS "Found ffmpeg libs: ${FFMPEG_LIBRARIES}")
message(STATUS "Found ffmpeg headers in: ${FFMPEG_INCLUDE_DIRS}")
message(STATUS "ffmpeg definitions: ${FFMPEG_DEFINITIONS}")
message(STATUS "Found avformat ${AVFORMAT_VERSION}")
include_directories(${FFMPEG_INCLUDE_DIRS})
add_compile_definitions(WHISPER_FFMPEG)
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ${FFMPEG_LIBRARIES})
endif()
# on APPLE # on APPLE
if (APPLE) if (APPLE)
# include Accelerate framework # include Accelerate framework
@ -340,12 +364,12 @@ if (WHISPER_CUDA)
if (WHISPER_STATIC) if (WHISPER_STATIC)
if (WIN32) if (WIN32)
# As of 12.3.1 CUDA Tookit for Windows does not offer a static cublas library # As of 12.3.1 CUDA Tookit for Windows does not offer a static cublas library
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt) set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt CUDA::cufft)
else () else ()
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static) set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static CUDA::cufft_static)
endif() endif()
else() else()
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt) set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt CUDA::cufft)
endif() endif()
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cuda_driver) set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cuda_driver)
@ -655,6 +679,10 @@ add_library(${TARGET}
whisper.cpp whisper.cpp
) )
if (WHISPER_CUDA)
target_sources(${TARGET} PRIVATE whisper-mel-cuda.cu)
endif()
include_directories ( include_directories (
. .
) )

View File

@ -286,8 +286,8 @@ ifdef WHISPER_CUDA
CFLAGS += -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include CFLAGS += -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include
CXXFLAGS += -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include CXXFLAGS += -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include
LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L/usr/lib/wsl/lib LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lcufft -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L/usr/lib/wsl/lib
WHISPER_OBJ += ggml-cuda.o WHISPER_OBJ += ggml-cuda.o whisper-mel-cuda.o
WHISPER_OBJ += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/*.cu)) WHISPER_OBJ += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/*.cu))
NVCC = nvcc NVCC = nvcc
NVCCFLAGS = --forward-unknown-to-host-compiler -arch=$(CUDA_ARCH_FLAG) NVCCFLAGS = --forward-unknown-to-host-compiler -arch=$(CUDA_ARCH_FLAG)
@ -299,6 +299,9 @@ ggml-cuda.o: ggml-cuda.cu ggml-cuda.h ggml.h ggml-backend.h ggml-backend-impl.h
$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@ $(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
endif endif
whisper-mel-cuda.o: whisper-mel-cuda.cu whisper.h ggml.h ggml-backend.h whisper-mel.hpp whisper-mel-cuda.hpp
$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
ifdef WHISPER_HIPBLAS ifdef WHISPER_HIPBLAS
ROCM_PATH ?= /opt/rocm ROCM_PATH ?= /opt/rocm
HIPCC ?= $(ROCM_PATH)/bin/hipcc HIPCC ?= $(ROCM_PATH)/bin/hipcc
@ -404,7 +407,7 @@ ggml-quants.o: ggml-quants.c ggml.h ggml-quants.h
WHISPER_OBJ += ggml.o ggml-alloc.o ggml-backend.o ggml-quants.o WHISPER_OBJ += ggml.o ggml-alloc.o ggml-backend.o ggml-quants.o
whisper.o: whisper.cpp whisper.h ggml.h ggml-cuda.h whisper.o: whisper.cpp whisper.h whisper-mel.hpp ggml.h ggml-cuda.h
$(CXX) $(CXXFLAGS) -c $< -o $@ $(CXX) $(CXXFLAGS) -c $< -o $@
ifndef WHISPER_COREML ifndef WHISPER_COREML

View File

@ -4,9 +4,10 @@
[![Actions Status](https://github.com/ggerganov/whisper.cpp/workflows/CI/badge.svg)](https://github.com/ggerganov/whisper.cpp/actions) [![Actions Status](https://github.com/ggerganov/whisper.cpp/workflows/CI/badge.svg)](https://github.com/ggerganov/whisper.cpp/actions)
[![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](https://opensource.org/licenses/MIT) [![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](https://opensource.org/licenses/MIT)
[![Conan Center](https://shields.io/conan/v/whisper-cpp)](https://conan.io/center/whisper-cpp)
[![npm](https://img.shields.io/npm/v/whisper.cpp.svg)](https://www.npmjs.com/package/whisper.cpp/) [![npm](https://img.shields.io/npm/v/whisper.cpp.svg)](https://www.npmjs.com/package/whisper.cpp/)
Stable: [v1.6.0](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.6.0) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126) Stable: [v1.6.2](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.6.0) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model: High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:
@ -502,6 +503,16 @@ docker run -it --rm \
whisper.cpp:main "./main -m /models/ggml-base.bin -f ./samples/jfk.wav" whisper.cpp:main "./main -m /models/ggml-base.bin -f ./samples/jfk.wav"
``` ```
## Installing with Conan
You can install pre-built binaries for whisper.cpp or build it from source using [Conan](https://conan.io/). Use the following command:
```
conan install --requires="whisper-cpp/[*]" --build=missing
```
For detailed instructions on how to use Conan, please refer to the [Conan documentation](https://docs.conan.io/2/).
## Limitations ## Limitations
- Inference only - Inference only
@ -710,7 +721,7 @@ The [main](examples/main) example provides support for output of karaoke-style m
currently pronounced word is highlighted. Use the `-wts` argument and run the generated bash script. currently pronounced word is highlighted. Use the `-wts` argument and run the generated bash script.
This requires to have `ffmpeg` installed. This requires to have `ffmpeg` installed.
Here are a few *"typical"* examples: Here are a few _"typical"_ examples:
```bash ```bash
./main -m ./models/ggml-base.en.bin -f ./samples/jfk.wav -owts ./main -m ./models/ggml-base.en.bin -f ./samples/jfk.wav -owts

View File

@ -68,10 +68,6 @@ func (flags *Flags) GetOut() string {
return strings.ToLower(flags.Lookup("out").Value.String()) return strings.ToLower(flags.Lookup("out").Value.String())
} }
func (flags *Flags) IsSpeedup() bool {
return flags.Lookup("speedup").Value.String() == "true"
}
func (flags *Flags) IsTokens() bool { func (flags *Flags) IsTokens() bool {
return flags.Lookup("tokens").Value.String() == "true" return flags.Lookup("tokens").Value.String() == "true"
} }
@ -111,10 +107,6 @@ func (flags *Flags) SetParams(context whisper.Context) error {
fmt.Fprintf(flags.Output(), "Setting duration to %v\n", duration) fmt.Fprintf(flags.Output(), "Setting duration to %v\n", duration)
context.SetDuration(duration) context.SetDuration(duration)
} }
if flags.IsSpeedup() {
fmt.Fprintf(flags.Output(), "Setting speedup to true\n")
context.SetSpeedup(true)
}
if threads := flags.GetThreads(); threads != 0 { if threads := flags.GetThreads(); threads != 0 {
fmt.Fprintf(flags.Output(), "Setting threads to %d\n", threads) fmt.Fprintf(flags.Output(), "Setting threads to %d\n", threads)
context.SetThreads(threads) context.SetThreads(threads)
@ -146,7 +138,6 @@ func registerFlags(flag *Flags) {
flag.Duration("offset", 0, "Time offset") flag.Duration("offset", 0, "Time offset")
flag.Duration("duration", 0, "Duration of audio to process") flag.Duration("duration", 0, "Duration of audio to process")
flag.Uint("threads", 0, "Number of threads to use") flag.Uint("threads", 0, "Number of threads to use")
flag.Bool("speedup", false, "Enable speedup")
flag.Uint("max-len", 0, "Maximum segment length in characters") flag.Uint("max-len", 0, "Maximum segment length in characters")
flag.Uint("max-tokens", 0, "Maximum tokens per segment") flag.Uint("max-tokens", 0, "Maximum tokens per segment")
flag.Float64("word-thold", 0, "Maximum segment score") flag.Float64("word-thold", 0, "Maximum segment score")

View File

@ -47,10 +47,6 @@ func (p *Params) SetPrintTimestamps(v bool) {
p.print_timestamps = toBool(v) p.print_timestamps = toBool(v)
} }
func (p *Params) SetSpeedup(v bool) {
p.speed_up = toBool(v)
}
// Set language id // Set language id
func (p *Params) SetLanguage(lang int) error { func (p *Params) SetLanguage(lang int) error {
if lang == -1 { if lang == -1 {
@ -177,9 +173,6 @@ func (p *Params) String() string {
if p.token_timestamps { if p.token_timestamps {
str += " token_timestamps" str += " token_timestamps"
} }
if p.speed_up {
str += " speed_up"
}
return str + ">" return str + ">"
} }

View File

@ -76,11 +76,6 @@ func (context *context) SetTranslate(v bool) {
context.params.SetTranslate(v) context.params.SetTranslate(v)
} }
// Set speedup flag
func (context *context) SetSpeedup(v bool) {
context.params.SetSpeedup(v)
}
func (context *context) SetSplitOnWord(v bool) { func (context *context) SetSplitOnWord(v bool) {
context.params.SetSplitOnWord(v) context.params.SetSplitOnWord(v)
} }

View File

@ -41,7 +41,6 @@ type Context interface {
SetOffset(time.Duration) // Set offset SetOffset(time.Duration) // Set offset
SetDuration(time.Duration) // Set duration SetDuration(time.Duration) // Set duration
SetThreads(uint) // Set number of threads to use SetThreads(uint) // Set number of threads to use
SetSpeedup(bool) // Set speedup flag
SetSplitOnWord(bool) // Set split on word flag SetSplitOnWord(bool) // Set split on word flag
SetTokenThreshold(float32) // Set timestamp token probability threshold SetTokenThreshold(float32) // Set timestamp token probability threshold
SetTokenSumThreshold(float32) // Set timestamp token sum probability threshold SetTokenSumThreshold(float32) // Set timestamp token sum probability threshold

View File

@ -20,7 +20,7 @@ public interface WhisperCppJnaLibrary extends Library {
* @return Whisper context on success, null on failure * @return Whisper context on success, null on failure
*/ */
Pointer whisper_init_from_file(String path_model); Pointer whisper_init_from_file(String path_model);
/** /**
* Provides default params which can be used with `whisper_init_from_file_with_params()` etc. * Provides default params which can be used with `whisper_init_from_file_with_params()` etc.
* Because this function allocates memory for the params, the caller must call either: * Because this function allocates memory for the params, the caller must call either:
@ -304,14 +304,6 @@ public interface WhisperCppJnaLibrary extends Library {
/** Language id associated with the provided state */ /** Language id associated with the provided state */
int whisper_full_lang_id_from_state(Pointer state); int whisper_full_lang_id_from_state(Pointer state);
/**
* Convert RAW PCM audio to log mel spectrogram but applies a Phase Vocoder to speed up the audio x2.
* The resulting spectrogram is stored inside the default state of the provided whisper context.
* @return 0 on success
*/
int whisper_pcm_to_mel_phase_vocoder(Pointer ctx, final float[] samples, int n_samples, int n_threads);
int whisper_pcm_to_mel_phase_vocoder_with_state(Pointer ctx, Pointer state, final float[] samples, int n_samples, int n_threads);
/** Get the start time of the specified segment. */ /** Get the start time of the specified segment. */
long whisper_full_get_segment_t0(Pointer ctx, int i_segment); long whisper_full_get_segment_t0(Pointer ctx, int i_segment);

View File

@ -129,14 +129,6 @@ public class WhisperFullParams extends Structure {
/** Maximum tokens per segment (0, default = no limit) */ /** Maximum tokens per segment (0, default = no limit) */
public int max_tokens; public int max_tokens;
/** Flag to speed up the audio by 2x using Phase Vocoder. (default = false) */
public CBool speed_up;
/** Flag to speed up the audio by 2x using Phase Vocoder. (default = false) */
public void speedUp(boolean enable) {
speed_up = enable ? CBool.TRUE : CBool.FALSE;
}
/** Overwrite the audio context size (0 = use default). */ /** Overwrite the audio context size (0 = use default). */
public int audio_ctx; public int audio_ctx;
@ -321,7 +313,7 @@ public class WhisperFullParams extends Structure {
return Arrays.asList("strategy", "n_threads", "n_max_text_ctx", "offset_ms", "duration_ms", "translate", return Arrays.asList("strategy", "n_threads", "n_max_text_ctx", "offset_ms", "duration_ms", "translate",
"no_context", "single_segment", "no_timestamps", "no_context", "single_segment", "no_timestamps",
"print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps", "print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps",
"thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "speed_up", "audio_ctx", "thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "audio_ctx",
"tdrz_enable", "suppress_regex", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language", "tdrz_enable", "suppress_regex", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",
"suppress_blank", "suppress_non_speech_tokens", "temperature", "max_initial_ts", "length_penalty", "suppress_blank", "suppress_non_speech_tokens", "temperature", "max_initial_ts", "length_penalty",
"temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search", "temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search",

View File

@ -1,6 +1,6 @@
{ {
"name": "whisper.cpp", "name": "whisper.cpp",
"version": "1.6.0", "version": "1.6.2",
"description": "Whisper speech recognition", "description": "Whisper speech recognition",
"main": "whisper.js", "main": "whisper.js",
"scripts": { "scripts": {

12
bindings/ruby/Rakefile Normal file
View File

@ -0,0 +1,12 @@
require 'rake/clean'
require 'rubygems/package'
desc 'Build gem'
task :package do
spec_source = File.read File.join(File.dirname(__FILE__),'whispercpp.gemspec')
spec = nil
# see: http://gist.github.com/16215
Thread.new { spec = eval("#{spec_source}") }.join
spec.validate
Gem::Package.build(spec)
end

View File

@ -1,6 +1,7 @@
require 'mkmf' require 'mkmf'
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.cpp')} .") system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.cpp')} .")
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.h')} .") system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.h')} .")
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper-mel.hpp')} .")
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.h')} .") system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.h')} .")
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.c')} .") system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.c')} .")
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml-impl.h')} .") system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml-impl.h')} .")

View File

@ -12,31 +12,63 @@ extern "C" {
// Backend buffer // Backend buffer
// //
// buffer type
typedef void * ggml_backend_buffer_type_context_t;
struct ggml_backend_buffer_type_i {
const char * (*GGML_CALL get_name) (ggml_backend_buffer_type_t buft);
ggml_backend_buffer_t (*GGML_CALL alloc_buffer) (ggml_backend_buffer_type_t buft, size_t size);
size_t (*GGML_CALL get_alignment) (ggml_backend_buffer_type_t buft); // tensor alignment
size_t (*GGML_CALL get_max_size) (ggml_backend_buffer_type_t buft); // allocation max size
size_t (*GGML_CALL get_alloc_size) (ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor); // data size needed to allocate the tensor, including padding
bool (*GGML_CALL supports_backend)(ggml_backend_buffer_type_t buft, ggml_backend_t backend); // check if the buffer type is usable by the backend
// check if tensor data is in host memory
// should be equivalent to supports_backend(buft, ggml_backend_cpu_init())
bool (*GGML_CALL is_host) (ggml_backend_buffer_type_t buft);
};
struct ggml_backend_buffer_type {
struct ggml_backend_buffer_type_i iface;
ggml_backend_buffer_type_context_t context;
};
// buffer
typedef void * ggml_backend_buffer_context_t; typedef void * ggml_backend_buffer_context_t;
struct ggml_backend_buffer_i { struct ggml_backend_buffer_i {
void (*free_buffer) (ggml_backend_buffer_t buffer); const char * (*GGML_CALL get_name) (ggml_backend_buffer_t buffer);
void * (*get_base) (ggml_backend_buffer_t buffer); // get base pointer void (*GGML_CALL free_buffer)(ggml_backend_buffer_t buffer);
size_t (*get_alloc_size)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-allocation callback void * (*GGML_CALL get_base) (ggml_backend_buffer_t buffer);
void (*init_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // post-allocation callback void (*GGML_CALL init_tensor)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
void (*free_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-free callback void (*GGML_CALL set_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
void (*GGML_CALL get_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
bool (*GGML_CALL cpy_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst); // dst is in the buffer, src may be in any buffer
void (*GGML_CALL clear) (ggml_backend_buffer_t buffer, uint8_t value);
void (*GGML_CALL reset) (ggml_backend_buffer_t buffer); // reset any internal state due to tensor initialization, such as tensor extras
}; };
struct ggml_backend_buffer { struct ggml_backend_buffer {
struct ggml_backend_buffer_i iface; struct ggml_backend_buffer_i iface;
ggml_backend_buffer_type_t buft;
ggml_backend_t backend;
ggml_backend_buffer_context_t context; ggml_backend_buffer_context_t context;
size_t size; size_t size;
enum ggml_backend_buffer_usage usage;
}; };
GGML_API ggml_backend_buffer_t ggml_backend_buffer_init( GGML_CALL ggml_backend_buffer_t ggml_backend_buffer_init(
struct ggml_backend * backend, ggml_backend_buffer_type_t buft,
struct ggml_backend_buffer_i iface, struct ggml_backend_buffer_i iface,
ggml_backend_buffer_context_t context, ggml_backend_buffer_context_t context,
size_t size); size_t size);
// do not use directly, use ggml_backend_tensor_copy instead
bool ggml_backend_buffer_copy_tensor(const struct ggml_tensor * src, struct ggml_tensor * dst);
// buffer that contains a collection of buffers
GGML_CALL ggml_backend_buffer_t ggml_backend_multi_buffer_alloc_buffer(ggml_backend_buffer_t * buffers, size_t n_buffers);
GGML_CALL bool ggml_backend_buffer_is_multi_buffer(ggml_backend_buffer_t buffer);
GGML_CALL void ggml_backend_multi_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage);
// //
// Backend // Backend
// //
@ -44,44 +76,66 @@ extern "C" {
typedef void * ggml_backend_context_t; typedef void * ggml_backend_context_t;
struct ggml_backend_i { struct ggml_backend_i {
const char * (*get_name)(ggml_backend_t backend); const char * (*GGML_CALL get_name)(ggml_backend_t backend);
void (*free)(ggml_backend_t backend); void (*GGML_CALL free)(ggml_backend_t backend);
// buffer allocation // buffer allocation
ggml_backend_buffer_t (*alloc_buffer)(ggml_backend_t backend, size_t size); ggml_backend_buffer_type_t (*GGML_CALL get_default_buffer_type)(ggml_backend_t backend);
// get buffer alignment // (optional) asynchronous tensor data access
size_t (*get_alignment)(ggml_backend_t backend); void (*GGML_CALL set_tensor_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
void (*GGML_CALL get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
bool (*GGML_CALL cpy_tensor_async)(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst);
// tensor data access // (optional) complete all pending operations
// these functions can be asynchronous, helper functions are provided for synchronous access that automatically call synchronize void (*GGML_CALL synchronize)(ggml_backend_t backend);
void (*set_tensor_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
void (*synchronize) (ggml_backend_t backend);
// (optional) copy tensor between different backends, allow for single-copy tranfers // compute graph with a plan (not used currently)
void (*cpy_tensor_from)(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst); ggml_backend_graph_plan_t (*GGML_CALL graph_plan_create) (ggml_backend_t backend, const struct ggml_cgraph * cgraph);
void (*cpy_tensor_to) (ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst); void (*GGML_CALL graph_plan_free) (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
// compute graph with a plan // compute graph with a plan
ggml_backend_graph_plan_t (*graph_plan_create) (ggml_backend_t backend, struct ggml_cgraph * cgraph); enum ggml_status (*GGML_CALL graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
void (*graph_plan_free) (ggml_backend_t backend, ggml_backend_graph_plan_t plan); // compute graph without a plan (async)
void (*graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan); enum ggml_status (*GGML_CALL graph_compute) (ggml_backend_t backend, struct ggml_cgraph * cgraph);
// compute graph without a plan
bool (*graph_compute)(ggml_backend_t backend, struct ggml_cgraph * cgraph);
// check if the backend supports an operation // check if the backend supports an operation
bool (*supports_op)(ggml_backend_t backend, const struct ggml_tensor * op); bool (*GGML_CALL supports_op)(ggml_backend_t backend, const struct ggml_tensor * op);
// check if the backend wants to run an operation, even if the weights are allocated in a CPU buffer
// these should be expensive operations with large batch sizes that may benefit from running on this backend
// even if the weight has to be copied from the CPU temporarily
bool (*GGML_CALL offload_op)(ggml_backend_t backend, const struct ggml_tensor * op);
// (optional) event synchronization
ggml_backend_event_t (*GGML_CALL event_new) (ggml_backend_t backend);
void (*GGML_CALL event_free) (ggml_backend_event_t event);
void (*GGML_CALL event_record) (ggml_backend_event_t event);
void (*GGML_CALL event_wait) (ggml_backend_t backend, ggml_backend_event_t event);
void (*GGML_CALL event_synchronize) (ggml_backend_event_t event);
}; };
struct ggml_backend { struct ggml_backend {
struct ggml_backend_i iface; ggml_guid_t guid;
struct ggml_backend_i iface;
ggml_backend_context_t context; ggml_backend_context_t context;
}; };
struct ggml_backend_event {
ggml_backend_t backend;
void * context;
};
//
// Backend registry
//
typedef ggml_backend_t (*GGML_CALL ggml_backend_init_fn)(const char * params, void * user_data);
GGML_CALL void ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

File diff suppressed because it is too large Load Diff

View File

@ -7,69 +7,123 @@
extern "C" { extern "C" {
#endif #endif
typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t;
typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
typedef struct ggml_backend_event * ggml_backend_event_t;
typedef struct ggml_backend * ggml_backend_t;
typedef void * ggml_backend_graph_plan_t;
// //
// Backend buffer // Backend buffer
// //
struct ggml_backend_buffer; // buffer type
typedef struct ggml_backend_buffer * ggml_backend_buffer_t; GGML_API const char * ggml_backend_buft_name (ggml_backend_buffer_type_t buft);
GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_buft_alloc_buffer (ggml_backend_buffer_type_t buft, size_t size);
GGML_API size_t ggml_backend_buft_get_alignment (ggml_backend_buffer_type_t buft);
GGML_API size_t ggml_backend_buft_get_max_size (ggml_backend_buffer_type_t buft);
GGML_API GGML_CALL size_t ggml_backend_buft_get_alloc_size (ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor);
GGML_API bool ggml_backend_buft_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend);
GGML_API bool ggml_backend_buft_is_host (ggml_backend_buffer_type_t buft);
// backend buffer functions // buffer
GGML_API void ggml_backend_buffer_free (ggml_backend_buffer_t buffer); enum ggml_backend_buffer_usage {
GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer); GGML_BACKEND_BUFFER_USAGE_ANY = 0,
GGML_API void * ggml_backend_buffer_get_base (ggml_backend_buffer_t buffer); GGML_BACKEND_BUFFER_USAGE_WEIGHTS = 1,
GGML_API size_t ggml_backend_buffer_get_size (ggml_backend_buffer_t buffer); };
GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
GGML_API void ggml_backend_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); GGML_API const char * ggml_backend_buffer_name (ggml_backend_buffer_t buffer);
GGML_API void ggml_backend_buffer_free_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); GGML_API void ggml_backend_buffer_free (ggml_backend_buffer_t buffer);
GGML_API void * ggml_backend_buffer_get_base (ggml_backend_buffer_t buffer);
GGML_API size_t ggml_backend_buffer_get_size (ggml_backend_buffer_t buffer);
GGML_API GGML_CALL void ggml_backend_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
GGML_API size_t ggml_backend_buffer_get_max_size (ggml_backend_buffer_t buffer);
GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
GGML_API void ggml_backend_buffer_clear (ggml_backend_buffer_t buffer, uint8_t value);
GGML_API bool ggml_backend_buffer_is_host (ggml_backend_buffer_t buffer);
GGML_API void ggml_backend_buffer_set_usage (ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage);
GGML_API ggml_backend_buffer_type_t ggml_backend_buffer_get_type (ggml_backend_buffer_t buffer);
GGML_API void ggml_backend_buffer_reset (ggml_backend_buffer_t buffer);
// //
// Backend // Backend
// //
struct ggml_backend; GGML_API ggml_guid_t ggml_backend_guid(ggml_backend_t backend);
typedef struct ggml_backend * ggml_backend_t;
typedef void * ggml_backend_graph_plan_t;
GGML_API ggml_backend_t ggml_get_backend(const struct ggml_tensor * tensor);
GGML_API const char * ggml_backend_name(ggml_backend_t backend); GGML_API const char * ggml_backend_name(ggml_backend_t backend);
GGML_API void ggml_backend_free(ggml_backend_t backend); GGML_API void ggml_backend_free(ggml_backend_t backend);
GGML_API ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size); GGML_API ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend);
GGML_API ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size);
GGML_API size_t ggml_backend_get_alignment(ggml_backend_t backend);
GGML_API size_t ggml_backend_get_max_size(ggml_backend_t backend);
GGML_API size_t ggml_backend_get_alignment(ggml_backend_t backend); GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
GGML_API void ggml_backend_tensor_set_async( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); GGML_API GGML_CALL void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
GGML_API void ggml_backend_tensor_get_async(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); GGML_API GGML_CALL void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
GGML_API void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
GGML_API void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
GGML_API void ggml_backend_synchronize(ggml_backend_t backend); GGML_API void ggml_backend_synchronize(ggml_backend_t backend);
GGML_API ggml_backend_graph_plan_t ggml_backend_graph_plan_create (ggml_backend_t backend, struct ggml_cgraph * cgraph); GGML_API ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph);
GGML_API void ggml_backend_graph_plan_free (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
GGML_API void ggml_backend_graph_plan_free (ggml_backend_t backend, ggml_backend_graph_plan_t plan); GGML_API enum ggml_status ggml_backend_graph_plan_compute (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
GGML_API void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan); GGML_API enum ggml_status ggml_backend_graph_compute (ggml_backend_t backend, struct ggml_cgraph * cgraph);
GGML_API bool ggml_backend_graph_compute (ggml_backend_t backend, struct ggml_cgraph * cgraph); GGML_API enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph);
GGML_API bool ggml_backend_supports_op (ggml_backend_t backend, const struct ggml_tensor * op); GGML_API bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op);
GGML_API bool ggml_backend_offload_op(ggml_backend_t backend, const struct ggml_tensor * op);
// tensor copy between different backends // tensor copy between different backends
GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst); GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst);
// asynchronous copy
// the copy is performed after all the currently queued operations in backend_src
// backend_dst will wait for the copy to complete before performing other operations
// automatic fallback to sync copy if async is not supported
GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, struct ggml_tensor * src, struct ggml_tensor * dst);
// events
GGML_API ggml_backend_event_t ggml_backend_event_new (ggml_backend_t backend);
GGML_API void ggml_backend_event_free (ggml_backend_event_t event);
GGML_API void ggml_backend_event_record (ggml_backend_event_t event);
GGML_API void ggml_backend_event_synchronize(ggml_backend_event_t event);
GGML_API void ggml_backend_event_wait (ggml_backend_t backend, ggml_backend_event_t event); // wait async on event
// //
// CPU backend // CPU backend
// //
GGML_API ggml_backend_t ggml_backend_cpu_init(void); GGML_API ggml_backend_t ggml_backend_cpu_init(void);
GGML_API bool ggml_backend_is_cpu(ggml_backend_t backend); GGML_API GGML_CALL bool ggml_backend_is_cpu (ggml_backend_t backend);
GGML_API void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads); GGML_API void ggml_backend_cpu_set_n_threads (ggml_backend_t backend_cpu, int n_threads);
GGML_API void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data);
// Create a backend buffer from an existing pointer // Create a backend buffer from an existing pointer
GGML_API ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(ggml_backend_t backend_cpu, void * ptr, size_t size); GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size);
GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void);
#ifdef GGML_USE_CPU_HBM
GGML_API ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void);
#endif
//
// Backend registry
//
// The backend registry is a registry of all the available backends, and allows initializing backends in a generic way
GGML_API size_t ggml_backend_reg_get_count(void);
GGML_API size_t ggml_backend_reg_find_by_name(const char * name);
GGML_API ggml_backend_t ggml_backend_reg_init_backend_from_str(const char * backend_str); // str is name[:params]
GGML_API const char * ggml_backend_reg_get_name(size_t i);
GGML_API ggml_backend_t ggml_backend_reg_init_backend(size_t i, const char * params); // params is backend-specific
GGML_API ggml_backend_buffer_type_t ggml_backend_reg_get_default_buffer_type(size_t i);
GGML_API ggml_backend_buffer_t ggml_backend_reg_alloc_buffer(size_t i, size_t size);
// //
// Backend scheduler // Backend scheduler
@ -83,53 +137,96 @@ extern "C" {
/* /*
Example usage: Example usage:
sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, num_backends); // operations that use tensors allocated in a buffer with USAGE_WEIGHTS will be assigned
// sched is initialized with measure allocators and cannot be used until allocated with a measure graph // preferrably to run on the same backend as the buffer
ggml_backend_buffer_set_usage(buf_weights, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
// initialize buffers from a measure graph sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, NULL, num_backends, GGML_DEFAULT_GRAPH_SIZE, false);
measure_graph = build_graph(sched); // use the allocr to allocate inputs as needed
// in build_graph: // initialize buffers from a max size graph (optional)
build_graph(...) { reserve_graph = build_graph(sched, max_batch_size);
// allocating tensors in a specific backend (optional, recommended: pre-allocate inputs in a different buffer)
alloc_cpu = ggml_backend_sched_get_allocr(sched, backend_cpu);
ggml_allocr_alloc(alloc_cpu, tensor);
// manually assigning nodes to a backend (optional, shouldn't be needed in most cases) // manually assign nodes to a backend (optional, should not be needed in most cases)
struct ggml_tensor * node = ggml_mul_mat(ctx, ...); struct ggml_tensor * node = ggml_mul_mat(ctx, ...);
ggml_backend_sched_set_node_backend(sched, node, backend_gpu); ggml_backend_sched_set_tensor_backend(sched, node, backend_gpu);
}
// allocate backend buffers from measure graph ggml_backend_sched_reserve(sched, reserve_graph);
ggml_backend_sched_init_measure(sched, measure_graph);
// the scheduler is now ready to compute graphs
// compute // compute
graph = build_graph(sched); graph = build_graph(sched);
ggml_backend_sched_graph_compute(sched, graph); ggml_backend_sched_graph_compute(sched, graph);
// if there are graph inputs:
ggml_backend_sched_reset(sched);
ggml_backend_sched_alloc_graph(sched, graph);
ggml_backend_tensor_set(input_tensor, ...);
ggml_backend_sched_graph_compute(sched, graph);
}
*/ */
struct ggml_backend_sched; struct ggml_backend_sched;
typedef struct ggml_backend_sched * ggml_backend_sched_t; typedef struct ggml_backend_sched * ggml_backend_sched_t;
// Initialize a backend scheduler // when ask == true, the scheduler wants to know if the user wants to observe this node
GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, int n_backends); // this allows the scheduler to batch nodes together in order to evaluate them in a single call
//
// when ask == false, the scheduler is passing the node tensor to the user for observation
// if the user returns false, the scheduler will cancel the graph compute
//
typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data);
GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched); // Initialize a backend scheduler
GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel);
GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched);
// Initialize backend buffers from a measure graph // Initialize backend buffers from a measure graph
GGML_API void ggml_backend_sched_init_measure(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph); GGML_API bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph);
GGML_API ggml_tallocr_t ggml_backend_sched_get_tallocr(ggml_backend_sched_t sched, ggml_backend_t backend); // Get the number of splits of the last graph
GGML_API ggml_backend_buffer_t ggml_backend_sched_get_buffer (ggml_backend_sched_t sched, ggml_backend_t backend); GGML_API int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched);
GGML_API int ggml_backend_sched_get_n_copies(ggml_backend_sched_t sched);
GGML_API void ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend); GGML_API size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend);
GGML_API void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
GGML_API ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);
// Allocate and compute graph on the backend scheduler
GGML_API bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
GGML_API enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
GGML_API enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
GGML_API void ggml_backend_sched_synchronize(ggml_backend_sched_t sched);
// Reset all assignments and allocators - must be called before changing the node backends
GGML_API void ggml_backend_sched_reset(ggml_backend_sched_t sched);
// Set a callback to be called for each resulting node during graph compute
GGML_API void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data);
//
// Utils
//
struct ggml_backend_graph_copy {
ggml_backend_buffer_t buffer;
struct ggml_context * ctx_allocated;
struct ggml_context * ctx_unallocated;
struct ggml_cgraph * graph;
};
// Copy a graph to a different backend
GGML_API struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph);
GGML_API void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy);
typedef bool (*GGML_CALL ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data);
// Compare the output of two backends
GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data);
// Tensor initialization
GGML_API void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
GGML_API void ggml_backend_view_init(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
// Allocate a graph on the backend scheduler
GGML_API void ggml_backend_sched_graph_compute(
ggml_backend_sched_t sched,
struct ggml_cgraph * graph);
#ifdef __cplusplus #ifdef __cplusplus
} }

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,43 @@
#pragma once
#include "ggml.h"
#include "ggml-backend.h"
#ifdef GGML_USE_HIPBLAS
#define GGML_CUDA_NAME "ROCm"
#define GGML_CUBLAS_NAME "hipBLAS"
#else
#define GGML_CUDA_NAME "CUDA"
#define GGML_CUBLAS_NAME "cuBLAS"
#endif
#ifdef __cplusplus
extern "C" {
#endif
#define GGML_CUDA_MAX_DEVICES 16
// backend API
GGML_API GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device);
GGML_API GGML_CALL bool ggml_backend_is_cuda(ggml_backend_t backend);
// device buffer
GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device);
// split tensor buffer that splits matrices by rows across multiple devices
GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(const float * tensor_split);
// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void);
GGML_API GGML_CALL int ggml_backend_cuda_get_device_count(void);
GGML_API GGML_CALL void ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size);
GGML_API GGML_CALL void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total);
GGML_API GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size);
GGML_API GGML_CALL void ggml_backend_cuda_unregister_host_buffer(void * buffer);
#ifdef __cplusplus
}
#endif

View File

@ -5,6 +5,7 @@
// GGML internal header // GGML internal header
#include <assert.h> #include <assert.h>
#include <stdlib.h> // load `stdlib.h` before other headers to work around MinGW bug: https://sourceforge.net/p/mingw-w64/bugs/192/
#include <stddef.h> #include <stddef.h>
#include <stdbool.h> #include <stdbool.h>
#include <string.h> // memcpy #include <string.h> // memcpy
@ -18,6 +19,7 @@ extern "C" {
// fall back to the _Static_assert C11 keyword. // fall back to the _Static_assert C11 keyword.
// if C99 - static_assert is noop // if C99 - static_assert is noop
// ref: https://stackoverflow.com/a/53923785/4039976 // ref: https://stackoverflow.com/a/53923785/4039976
#ifndef __cplusplus
#ifndef static_assert #ifndef static_assert
#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L) #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L)
#define static_assert(cond, msg) _Static_assert(cond, msg) #define static_assert(cond, msg) _Static_assert(cond, msg)
@ -25,6 +27,7 @@ extern "C" {
#define static_assert(cond, msg) struct global_scope_noop_trick #define static_assert(cond, msg) struct global_scope_noop_trick
#endif #endif
#endif #endif
#endif
// __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512 // __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
#if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__)) #if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))
@ -34,16 +37,17 @@ extern "C" {
#ifndef __F16C__ #ifndef __F16C__
#define __F16C__ #define __F16C__
#endif #endif
#endif
// __SSE3__ and __SSSE3__ are not defined in MSVC, but SSE3/SSSE3 are present when AVX/AVX2/AVX512 are available
#if defined(_MSC_VER) && (defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__))
#ifndef __SSE3__ #ifndef __SSE3__
#define __SSE3__ #define __SSE3__
#endif #endif
#ifndef __SSSE3__
#define __SSSE3__
#endif
#endif #endif
#undef MIN
#undef MAX
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))
// 16-bit float // 16-bit float
// on Arm, we use __fp16 // on Arm, we use __fp16
@ -56,14 +60,30 @@ extern "C" {
// //
#include <arm_neon.h> #include <arm_neon.h>
#define GGML_COMPUTE_FP16_TO_FP32(x) ((float) (x)) typedef __fp16 ggml_fp16_internal_t;
#define GGML_COMPUTE_FP32_TO_FP16(x) (x)
#define GGML_FP16_TO_FP32(x) ((float) (x)) #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
#define GGML_FP32_TO_FP16(x) (x) #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
#define GGML_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
ggml_fp16_internal_t tmp;
memcpy(&tmp, &h, sizeof(ggml_fp16_t));
return (float)tmp;
}
static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
ggml_fp16_t res;
ggml_fp16_internal_t tmp = f;
memcpy(&res, &tmp, sizeof(ggml_fp16_t));
return res;
}
#else #else
typedef uint16_t ggml_fp16_internal_t;
#ifdef __wasm_simd128__ #ifdef __wasm_simd128__
#include <wasm_simd128.h> #include <wasm_simd128.h>
#else #else
@ -217,8 +237,7 @@ extern float ggml_table_f32_f16[1 << 16];
// On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32, // On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
// so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON. // so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON.
// This is also true for POWER9. // This is also true for POWER9.
#if !defined(GGML_FP16_TO_FP32) || !defined(GGML_FP32_TO_FP16) #if !defined(GGML_FP16_TO_FP32)
inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) { inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
uint16_t s; uint16_t s;
memcpy(&s, &f, sizeof(uint16_t)); memcpy(&s, &f, sizeof(uint16_t));
@ -226,19 +245,23 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
} }
#define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x) #define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x)
#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x) #endif
#if !defined(GGML_FP32_TO_FP16)
#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
#endif #endif
#define GGML_HASHTABLE_FULL ((size_t)-1) #define GGML_HASHTABLE_FULL ((size_t)-1)
#define GGML_HASHTABLE_ALREADY_EXISTS ((size_t)-2) #define GGML_HASHTABLE_ALREADY_EXISTS ((size_t)-2)
struct ggml_hash_set ggml_hash_set_new(size_t size);
bool ggml_hash_contains (const struct ggml_hash_set hash_set, struct ggml_tensor * key); bool ggml_hash_contains (const struct ggml_hash_set hash_set, struct ggml_tensor * key);
// returns GGML_HASHTABLE_FULL if table is full, otherwise the current index of the key or where it should be inserted // returns GGML_HASHTABLE_FULL if table is full, otherwise the current index of the key or where it should be inserted
size_t ggml_hash_find (const struct ggml_hash_set hash_set, struct ggml_tensor * key); size_t ggml_hash_find (const struct ggml_hash_set hash_set, struct ggml_tensor * key);
// returns GGML_HAHSHTABLE_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full // returns GGML_HASHTABLE_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full
size_t ggml_hash_insert ( struct ggml_hash_set hash_set, struct ggml_tensor * key); size_t ggml_hash_insert ( struct ggml_hash_set hash_set, struct ggml_tensor * key);
// return index, asserts if table is full // return index, asserts if table is full

View File

@ -0,0 +1,46 @@
#pragma once
#include "ggml.h"
#include "ggml-backend.h"
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#ifdef __cplusplus
extern "C" {
#endif
struct ggml_vk_device {
int index;
int type; // same as VkPhysicalDeviceType
size_t heapSize;
const char * name;
const char * vendor;
int subgroupSize;
uint64_t bufferAlignment;
uint64_t maxAlloc;
};
struct ggml_vk_device * ggml_vk_available_devices(size_t memoryRequired, size_t * count);
bool ggml_vk_get_device(struct ggml_vk_device * device, size_t memoryRequired, const char * name);
bool ggml_vk_has_vulkan(void);
bool ggml_vk_has_device(void);
struct ggml_vk_device ggml_vk_current_device(void);
//
// backend API
//
// forward declaration
typedef struct ggml_backend * ggml_backend_t;
GGML_API ggml_backend_t ggml_backend_kompute_init(int device);
GGML_API bool ggml_backend_is_kompute(ggml_backend_t backend);
GGML_API ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device);
#ifdef __cplusplus
}
#endif

View File

@ -0,0 +1,66 @@
// An interface allowing to compute ggml_cgraph with Metal
//
// This is a fully functional interface that extends ggml with GPU support for Apple devices.
// A similar interface can be created for other GPU backends (e.g. Vulkan, CUDA, OpenCL, etc.)
//
// How it works?
//
// As long as your program can create and evaluate a ggml_cgraph on the CPU, you can use this
// interface to evaluate the same graph on the GPU. Instead of using ggml_graph_compute(), you
// use ggml_metal_graph_compute() (or ggml_vulkan_graph_compute(), etc.)
//
// You only need to make sure that all memory buffers that you used during the graph creation
// are mapped to the device memory with the ggml_metal_add_buffer() function. This mapping is
// used during the graph evaluation to determine the arguments of the compute kernels.
//
// Synchronization between device and host memory (for example for input and output tensors)
// is done with the ggml_metal_set_tensor() and ggml_metal_get_tensor() functions.
//
#pragma once
#include "ggml.h"
#include "ggml-backend.h"
#include <stddef.h>
#include <stdbool.h>
// max memory buffers that can be mapped to the device
#define GGML_METAL_MAX_BUFFERS 64
struct ggml_tensor;
struct ggml_cgraph;
#ifdef __cplusplus
extern "C" {
#endif
//
// backend API
// user-code should use only these functions
//
GGML_API void ggml_backend_metal_log_set_callback(ggml_log_callback log_callback, void * user_data);
GGML_API ggml_backend_t ggml_backend_metal_init(void);
GGML_API bool ggml_backend_is_metal(ggml_backend_t backend);
GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size);
GGML_API void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb);
GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
// helper to check if the device supports a specific family
// ideally, the user code should be doing these checks
// ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
GGML_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family);
// capture all command buffers committed the next time `ggml_backend_graph_compute` is called
GGML_API void ggml_backend_metal_capture_next_compute(ggml_backend_t backend);
#ifdef __cplusplus
}
#endif

View File

@ -0,0 +1,36 @@
#pragma once
#include "ggml.h"
#include "ggml-backend.h"
#ifdef __cplusplus
extern "C" {
#endif
GGML_API void ggml_cl_init(void);
GGML_API void ggml_cl_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
GGML_API void ggml_cl_add(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
GGML_API bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst);
GGML_API size_t ggml_cl_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
GGML_API void ggml_cl_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize);
// GGML_API void * ggml_cl_host_malloc(size_t size);
// GGML_API void ggml_cl_host_free(void * ptr);
GGML_API void ggml_cl_free_data(const struct ggml_tensor* tensor);
GGML_API void ggml_cl_transform_tensor(void * data, struct ggml_tensor * tensor);
// backend API
// GGML_API ggml_backend_t ggml_backend_opencl_init(void);
// GGML_API bool ggml_backend_is_opencl(ggml_backend_t backend);
GGML_API ggml_backend_buffer_type_t ggml_backend_opencl_buffer_type(void);
// GGML_API ggml_backend_buffer_type_t ggml_backend_opencl_host_buffer_type(void);
#ifdef __cplusplus
}
#endif

File diff suppressed because it is too large Load Diff

View File

@ -1,224 +1,133 @@
#pragma once #pragma once
#include "ggml-impl.h" #define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "ggml.h"
// GGML internal header // GGML internal header
#include <stdint.h> #ifdef __cplusplus
#include <stddef.h> extern "C" {
#define QK4_0 32
typedef struct {
ggml_fp16_t d; // delta
uint8_t qs[QK4_0 / 2]; // nibbles / quants
} block_q4_0;
static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding");
#define QK4_1 32
typedef struct {
ggml_fp16_t d; // delta
ggml_fp16_t m; // min
uint8_t qs[QK4_1 / 2]; // nibbles / quants
} block_q4_1;
static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_fp16_t) + QK4_1 / 2, "wrong q4_1 block size/padding");
#define QK5_0 32
typedef struct {
ggml_fp16_t d; // delta
uint8_t qh[4]; // 5-th bit of quants
uint8_t qs[QK5_0 / 2]; // nibbles / quants
} block_q5_0;
static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
#define QK5_1 32
typedef struct {
ggml_fp16_t d; // delta
ggml_fp16_t m; // min
uint8_t qh[4]; // 5-th bit of quants
uint8_t qs[QK5_1 / 2]; // nibbles / quants
} block_q5_1;
static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
#define QK8_0 32
typedef struct {
ggml_fp16_t d; // delta
int8_t qs[QK8_0]; // quants
} block_q8_0;
static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
#define QK8_1 32
typedef struct {
float d; // delta
float s; // d * sum(qs[i])
int8_t qs[QK8_1]; // quants
} block_q8_1;
static_assert(sizeof(block_q8_1) == 2*sizeof(float) + QK8_1, "wrong q8_1 block size/padding");
//
// Super-block quantization structures
//
// Super-block size
#ifdef GGML_QKK_64
#define QK_K 64
#define K_SCALE_SIZE 4
#else
#define QK_K 256
#define K_SCALE_SIZE 12
#endif #endif
// 2-bit quantization
// weight is represented as x = a * q + b
// 16 blocks of 16 elements each
// Effectively 2.5625 bits per weight
typedef struct {
uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
uint8_t qs[QK_K/4]; // quants
ggml_fp16_t d; // super-block scale for quantized scales
ggml_fp16_t dmin; // super-block scale for quantized mins
} block_q2_K;
static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
// 3-bit quantization
// weight is represented as x = a * q
// 16 blocks of 16 elements each
// Effectively 3.4375 bits per weight
#ifdef GGML_QKK_64
typedef struct {
uint8_t hmask[QK_K/8]; // quants - high bit
uint8_t qs[QK_K/4]; // quants - low 2 bits
uint8_t scales[2];
ggml_fp16_t d; // super-block scale
} block_q3_K;
static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 2, "wrong q3_K block size/padding");
#else
typedef struct {
uint8_t hmask[QK_K/8]; // quants - high bit
uint8_t qs[QK_K/4]; // quants - low 2 bits
uint8_t scales[12]; // scales, quantized with 6 bits
ggml_fp16_t d; // super-block scale
} block_q3_K;
static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 12, "wrong q3_K block size/padding");
#endif
// 4-bit quantization
// 8 blocks of 32 elements each
// weight is represented as x = a * q + b
// Effectively 4.5 bits per weight
#ifdef GGML_QKK_64
typedef struct {
ggml_fp16_t d[2]; // super-block scales/mins
uint8_t scales[2]; // 4-bit block scales/mins
uint8_t qs[QK_K/2]; // 4--bit quants
} block_q4_K;
static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding");
#else
typedef struct {
ggml_fp16_t d; // super-block scale for quantized scales
ggml_fp16_t dmin; // super-block scale for quantized mins
uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
uint8_t qs[QK_K/2]; // 4--bit quants
} block_q4_K;
static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2, "wrong q4_K block size/padding");
#endif
// 5-bit quantization
// 8 blocks of 32 elements each
// weight is represented as x = a * q + b
// Effectively 5.5 bits per weight
#ifdef GGML_QKK_64
typedef struct {
ggml_fp16_t d; // super-block scale
int8_t scales[QK_K/16]; // 8-bit block scales
uint8_t qh[QK_K/8]; // quants, high bit
uint8_t qs[QK_K/2]; // quants, low 4 bits
} block_q5_K;
static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding");
#else
typedef struct {
ggml_fp16_t d; // super-block scale for quantized scales
ggml_fp16_t dmin; // super-block scale for quantized mins
uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
uint8_t qh[QK_K/8]; // quants, high bit
uint8_t qs[QK_K/2]; // quants, low 4 bits
} block_q5_K;
static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
#endif
// 6-bit quantization
// weight is represented as x = a * q
// 16 blocks of 16 elements each
// Effectively 6.5625 bits per weight
typedef struct {
uint8_t ql[QK_K/2]; // quants, lower 4 bits
uint8_t qh[QK_K/4]; // quants, upper 2 bits
int8_t scales[QK_K/16]; // scales, quantized with 8 bits
ggml_fp16_t d; // super-block scale
} block_q6_K;
static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + QK_K / 16 + 3*QK_K/4, "wrong q6_K block size/padding");
// This is only used for intermediate quantization and dot products
typedef struct {
float d; // delta
int8_t qs[QK_K]; // quants
int16_t bsums[QK_K/16]; // sum of quants in groups of 16
} block_q8_K;
static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding");
// Quantization // Quantization
void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k); void quantize_row_q4_0_reference(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k);
void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k); void quantize_row_q4_1_reference(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k);
void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k); void quantize_row_q5_0_reference(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t k);
void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k); void quantize_row_q5_1_reference(const float * GGML_RESTRICT x, block_q5_1 * GGML_RESTRICT y, int64_t k);
void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k); void quantize_row_q8_0_reference(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k);
void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k); void quantize_row_q8_1_reference(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k);
void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int k); void quantize_row_q2_K_reference(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k);
void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k); void quantize_row_q3_K_reference(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k);
void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k); void quantize_row_q4_K_reference(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k);
void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k); void quantize_row_q5_K_reference(const float * GGML_RESTRICT x, block_q5_K * GGML_RESTRICT y, int64_t k);
void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k); void quantize_row_q6_K_reference(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int64_t k);
void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k); void quantize_row_q8_K_reference(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int64_t k);
void quantize_row_q4_0(const float * restrict x, void * restrict y, int k); void quantize_row_iq3_xxs_reference(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int64_t k);
void quantize_row_q4_1(const float * restrict x, void * restrict y, int k); void quantize_row_iq4_nl_reference (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int64_t k);
void quantize_row_q5_0(const float * restrict x, void * restrict y, int k); void quantize_row_iq4_xs_reference (const float * GGML_RESTRICT x, block_iq4_xs * GGML_RESTRICT y, int64_t k);
void quantize_row_q5_1(const float * restrict x, void * restrict y, int k); void quantize_row_iq3_s_reference (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int64_t k);
void quantize_row_q8_0(const float * restrict x, void * restrict y, int k); void quantize_row_iq2_s_reference (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k);
void quantize_row_q8_1(const float * restrict x, void * restrict y, int k);
void quantize_row_q2_K(const float * restrict x, void * restrict y, int k); void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q3_K(const float * restrict x, void * restrict y, int k); void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q4_K(const float * restrict x, void * restrict y, int k); void quantize_row_q5_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q5_K(const float * restrict x, void * restrict y, int k); void quantize_row_q5_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q6_K(const float * restrict x, void * restrict y, int k); void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q8_K(const float * restrict x, void * restrict y, int k); void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q5_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_iq3_xxs(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_iq3_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_iq2_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
// Dequantization // Dequantization
void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k); void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int k); void dequantize_row_q4_1(const block_q4_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int k); void dequantize_row_q5_0(const block_q5_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int k); void dequantize_row_q5_1(const block_q5_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int k); void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
//void dequantize_row_q8_1(const block_q8_1 * restrict x, float * restrict y, int k); //void dequantize_row_q8_1(const block_q8_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int k); void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k); void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int k); void dequantize_row_q4_K(const block_q4_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k); void dequantize_row_q5_K(const block_q5_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k); void dequantize_row_q6_K(const block_q6_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k); void dequantize_row_q8_K(const block_q8_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_iq2_xxs(const block_iq2_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_iq2_xs (const block_iq2_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_iq2_s (const block_iq2_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_iq1_s (const block_iq1_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_iq1_m (const block_iq1_m * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_iq4_nl (const block_iq4_nl * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_iq4_xs (const block_iq4_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_iq3_s (const block_iq3_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
// Dot product // Dot product
void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy); void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, const void * restrict vx, const void * restrict vy); void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy); void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, const void * restrict vx, const void * restrict vy); void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy); void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_iq2_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_iq1_m_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
size_t quantize_iq2_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_iq2_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_iq2_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_iq3_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_iq1_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_iq1_m (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_iq4_nl (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_iq4_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_iq3_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_q2_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_q3_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_q4_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_q5_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_q6_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_q4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_q4_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
void iq2xs_init_impl(enum ggml_type type);
void iq2xs_free_impl(enum ggml_type type);
void iq3xs_init_impl(int grid_size);
void iq3xs_free_impl(int grid_size);
#ifdef __cplusplus
}
#endif
void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);

View File

@ -0,0 +1,49 @@
//
// MIT license
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: MIT
//
#pragma once
#include "ggml.h"
#include "ggml-backend.h"
#ifdef __cplusplus
extern "C" {
#endif
#define GGML_SYCL_MAX_DEVICES 48
#define GGML_SYCL_NAME "SYCL"
// backend API
GGML_API ggml_backend_t ggml_backend_sycl_init(int device);
// devide buffer
GGML_API ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device);
// split tensor buffer that splits matrices by rows across multiple devices
GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split);
// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
GGML_API ggml_backend_buffer_type_t ggml_backend_sycl_host_buffer_type(void);
GGML_API void ggml_backend_sycl_print_sycl_devices(void);
GGML_API GGML_CALL void ggml_sycl_get_gpu_list(int *id_list, int max_len);
GGML_API GGML_CALL void ggml_sycl_get_device_description(int device, char *description, size_t description_size);
GGML_API GGML_CALL int ggml_backend_sycl_get_device_count();
GGML_API GGML_CALL void ggml_backend_sycl_get_device_memory(int device, size_t *free, size_t *total);
GGML_API GGML_CALL int ggml_backend_sycl_get_device_index(int device_id);
// TODO: these are temporary
// ref: https://github.com/ggerganov/llama.cpp/pull/6022#issuecomment-1992615670
GGML_API GGML_CALL int ggml_backend_sycl_get_device_id(int device_index);
GGML_API GGML_CALL void ggml_backend_sycl_set_single_device_mode(int main_gpu_id);
GGML_API GGML_CALL void ggml_backend_sycl_set_mul_device_mode();
// SYCL doesn't support registering host memory, keep here for reference
// GGML_API GGML_CALL bool ggml_backend_sycl_register_host_buffer(void * buffer, size_t size);
// GGML_API GGML_CALL void ggml_backend_sycl_unregister_host_buffer(void * buffer);
#ifdef __cplusplus
}
#endif

View File

@ -0,0 +1,29 @@
#pragma once
#include "ggml.h"
#include "ggml-backend.h"
#ifdef __cplusplus
extern "C" {
#endif
#define GGML_VK_NAME "Vulkan"
#define GGML_VK_MAX_DEVICES 16
GGML_API void ggml_vk_instance_init(void);
// backend API
GGML_API GGML_CALL ggml_backend_t ggml_backend_vk_init(size_t dev_num);
GGML_API GGML_CALL bool ggml_backend_is_vk(ggml_backend_t backend);
GGML_API GGML_CALL int ggml_backend_vk_get_device_count(void);
GGML_API GGML_CALL void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size);
GGML_API GGML_CALL void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total);
GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num);
// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type(void);
#ifdef __cplusplus
}
#endif

View File

@ -311,12 +311,6 @@ static VALUE ruby_whisper_params_get_split_on_word(VALUE self) {
static VALUE ruby_whisper_params_set_split_on_word(VALUE self, VALUE value) { static VALUE ruby_whisper_params_set_split_on_word(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, split_on_word, value) BOOL_PARAMS_SETTER(self, split_on_word, value)
} }
static VALUE ruby_whisper_params_get_speed_up(VALUE self) {
BOOL_PARAMS_GETTER(self, speed_up)
}
static VALUE ruby_whisper_params_set_speed_up(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, speed_up, value)
}
static VALUE ruby_whisper_params_get_diarize(VALUE self) { static VALUE ruby_whisper_params_get_diarize(VALUE self) {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
@ -408,8 +402,6 @@ void Init_whisper() {
rb_define_method(cParams, "token_timestamps=", ruby_whisper_params_set_token_timestamps, 1); rb_define_method(cParams, "token_timestamps=", ruby_whisper_params_set_token_timestamps, 1);
rb_define_method(cParams, "split_on_word", ruby_whisper_params_get_split_on_word, 0); rb_define_method(cParams, "split_on_word", ruby_whisper_params_get_split_on_word, 0);
rb_define_method(cParams, "split_on_word=", ruby_whisper_params_set_split_on_word, 1); rb_define_method(cParams, "split_on_word=", ruby_whisper_params_set_split_on_word, 1);
rb_define_method(cParams, "speed_up", ruby_whisper_params_get_speed_up, 0);
rb_define_method(cParams, "speed_up=", ruby_whisper_params_set_speed_up, 1);
rb_define_method(cParams, "diarize", ruby_whisper_params_get_diarize, 0); rb_define_method(cParams, "diarize", ruby_whisper_params_get_diarize, 0);
rb_define_method(cParams, "diarize=", ruby_whisper_params_set_diarize, 1); rb_define_method(cParams, "diarize=", ruby_whisper_params_set_diarize, 1);

View File

@ -117,13 +117,6 @@ class TestWhisper < Test::Unit::TestCase
assert !@params.split_on_word assert !@params.split_on_word
end end
def test_speed_up
@params.speed_up = true
assert @params.speed_up
@params.speed_up = false
assert !@params.speed_up
end
def test_whisper def test_whisper
@whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin'))
params = Whisper::Params.new params = Whisper::Params.new

View File

@ -0,0 +1,28 @@
Gem::Specification.new do |s|
s.name = "whispercpp"
s.authors = ["Georgi Gerganov", "Todd A. Fisher"]
s.version = '1.3.0'
s.date = '2024-05-14'
s.description = %q{High-performance inference of OpenAI's Whisper automatic speech recognition (ASR) model via Ruby}
s.email = 'todd.fisher@gmail.com'
s.extra_rdoc_files = ['LICENSE', 'README.md']
s.files = ["LICENSE", "README.md", "Rakefile", "ext/extconf.rb", "ext/ggml.c", "ext/ruby_whisper.cpp", "ext/whisper.cpp", "ext/dr_wav.h", "ext/ggml.h", "ext/ruby_whisper.h", "ext/whisper.h"]
#### Load-time details
s.require_paths = ['lib','ext']
s.summary = %q{Ruby whisper.cpp bindings}
s.test_files = ["tests/test_whisper.rb"]
s.extensions << 'ext/extconf.rb'
#### Documentation and testing.
s.homepage = 'https://github.com/ggerganov/whisper.cpp'
s.rdoc_options = ['--main', '../../README.md']
s.platform = Gem::Platform::RUBY
s.licenses = ['MIT']
end

163
cmake/FindFFmpeg.cmake Normal file
View File

@ -0,0 +1,163 @@
# From
# https://github.com/snikulov/cmake-modules/blob/master/FindFFmpeg.cmake
#
# vim: ts=2 sw=2
# - Try to find the required ffmpeg components(default: AVFORMAT, AVUTIL, AVCODEC)
#
# Once done this will define
# FFMPEG_FOUND - System has the all required components.
# FFMPEG_INCLUDE_DIRS - Include directory necessary for using the required components headers.
# FFMPEG_LIBRARIES - Link these to use the required ffmpeg components.
# FFMPEG_DEFINITIONS - Compiler switches required for using the required ffmpeg components.
#
# For each of the components it will additionally set.
# - AVCODEC
# - AVDEVICE
# - AVFORMAT
# - AVFILTER
# - AVUTIL
# - POSTPROC
# - SWSCALE
# the following variables will be defined
# <component>_FOUND - System has <component>
# <component>_INCLUDE_DIRS - Include directory necessary for using the <component> headers
# <component>_LIBRARIES - Link these to use <component>
# <component>_DEFINITIONS - Compiler switches required for using <component>
# <component>_VERSION - The components version
#
# Copyright (c) 2006, Matthias Kretz, <kretz@kde.org>
# Copyright (c) 2008, Alexander Neundorf, <neundorf@kde.org>
# Copyright (c) 2011, Michael Jansen, <kde@michael-jansen.biz>
#
# Redistribution and use is allowed according to the terms of the BSD license.
# For details see the accompanying COPYING-CMAKE-SCRIPTS file.
include(FindPackageHandleStandardArgs)
# The default components were taken from a survey over other FindFFMPEG.cmake files
if (NOT FFmpeg_FIND_COMPONENTS)
set(FFmpeg_FIND_COMPONENTS AVFORMAT AVCODEC AVUTIL SWRESAMPLE)
endif()
#
### Macro: set_component_found
#
# Marks the given component as found if both *_LIBRARIES AND *_INCLUDE_DIRS is present.
#
macro(set_component_found _component )
if (${_component}_LIBRARIES AND ${_component}_INCLUDE_DIRS)
message(DEBUG " - ${_component} found.")
set(${_component}_FOUND TRUE)
else ()
message(DEBUG " - ${_component} not found.")
endif ()
endmacro()
#
### Macro: find_component
#
# Checks for the given component by invoking pkgconfig and then looking up the libraries and
# include directories.
#
macro(find_component _component _pkgconfig _library _header)
if (NOT WIN32)
# use pkg-config to get the directories and then use these values
# in the FIND_PATH() and FIND_LIBRARY() calls
find_package(PkgConfig)
if (PKG_CONFIG_FOUND)
pkg_check_modules(PC_${_component} ${_pkgconfig})
message(STATUS "Pkgconfig found: ${PC_${_component}_INCLUDEDIR}")
message(STATUS "Pkgconfig found: ${PC_${_component}_INCLUDE_DIRS}")
message(STATUS "${PC_${_component}_CFLAGS}")
endif ()
endif (NOT WIN32)
find_path(${_component}_INCLUDE_DIRS ${_header}
HINTS
${PC_${_component}_INCLUDEDIR}
${PC_${_component}_INCLUDE_DIRS}
PATH_SUFFIXES
ffmpeg
)
# CMake's default is to search first for shared libraries and then for static libraries.
# Todo later: add option to prefer static libs over dynamic:
find_library(${_component}_LIBRARIES NAMES ${_library} lib${_library}.a
HINTS
${PC_${_component}_LIBDIR}
${PC_${_component}_LIBRARY_DIRS}
)
set(${_component}_DEFINITIONS ${PC_${_component}_CFLAGS_OTHER} CACHE STRING "The ${_component} CFLAGS.")
set(${_component}_VERSION ${PC_${_component}_VERSION} CACHE STRING "The ${_component} version number.")
set_component_found(${_component})
mark_as_advanced(
${_component}_INCLUDE_DIRS
${_component}_LIBRARIES
${_component}_DEFINITIONS
${_component}_VERSION)
endmacro()
# Check for cached results. If there are skip the costly part.
if (NOT FFMPEG_LIBRARIES)
# Check for all possible component.
find_component(AVCODEC libavcodec avcodec libavcodec/avcodec.h)
find_component(AVFORMAT libavformat avformat libavformat/avformat.h)
find_component(AVDEVICE libavdevice avdevice libavdevice/avdevice.h)
#find_component(AVRESAMPLE libavresample avresample libavresample/avresample.h) # old name for swresample
find_component(AVUTIL libavutil avutil libavutil/avutil.h)
find_component(AVFILTER libavfilter avfilter libavfilter/avfilter.h)
find_component(SWSCALE libswscale swscale libswscale/swscale.h)
find_component(POSTPROC libpostproc postproc libpostproc/postprocess.h)
find_component(SWRESAMPLE libswresample swresample libswresample/swresample.h)
# Check if the required components were found and add their stuff to the FFMPEG_* vars.
foreach (_component ${FFmpeg_FIND_COMPONENTS})
if (${_component}_FOUND)
# message(STATUS "Required component ${_component} present.")
set(FFMPEG_LIBRARIES ${FFMPEG_LIBRARIES} ${${_component}_LIBRARIES})
set(FFMPEG_DEFINITIONS ${FFMPEG_DEFINITIONS} ${${_component}_DEFINITIONS})
list(APPEND FFMPEG_INCLUDE_DIRS ${${_component}_INCLUDE_DIRS})
else ()
# message(STATUS "Required component ${_component} missing.")
endif ()
endforeach ()
# Build the include path with duplicates removed.
if (FFMPEG_INCLUDE_DIRS)
list(REMOVE_DUPLICATES FFMPEG_INCLUDE_DIRS)
endif ()
# cache the vars.
set(FFMPEG_INCLUDE_DIRS ${FFMPEG_INCLUDE_DIRS} CACHE STRING "The FFmpeg include directories." FORCE)
set(FFMPEG_LIBRARIES ${FFMPEG_LIBRARIES} CACHE STRING "The FFmpeg libraries." FORCE)
set(FFMPEG_DEFINITIONS ${FFMPEG_DEFINITIONS} CACHE STRING "The FFmpeg cflags." FORCE)
mark_as_advanced(FFMPEG_INCLUDE_DIRS
FFMPEG_LIBRARIES
FFMPEG_DEFINITIONS)
endif ()
# Now set the noncached _FOUND vars for the components.
# whisper.cpp does not need SWSCALE
foreach (_component AVCODEC AVDEVICE AVFORMAT AVRESAMPLE AVUTIL POSTPROCESS)
set_component_found(${_component})
endforeach ()
# Compile the list of required vars
set(_FFmpeg_REQUIRED_VARS FFMPEG_LIBRARIES FFMPEG_INCLUDE_DIRS)
foreach (_component ${FFmpeg_FIND_COMPONENTS})
list(APPEND _FFmpeg_REQUIRED_VARS ${_component}_LIBRARIES ${_component}_INCLUDE_DIRS)
endforeach ()
# Give a nice error message if some of the required vars are missing.
find_package_handle_standard_args(FFmpeg DEFAULT_MSG ${_FFmpeg_REQUIRED_VARS})

View File

@ -22,6 +22,10 @@ endif()
set(TARGET common) set(TARGET common)
if (WHISPER_FFMPEG)
set(COMMON_SOURCES_FFMPEG ffmpeg-transcode.cpp)
endif()
add_library(${TARGET} STATIC add_library(${TARGET} STATIC
common.h common.h
common.cpp common.cpp
@ -29,6 +33,7 @@ add_library(${TARGET} STATIC
common-ggml.cpp common-ggml.cpp
grammar-parser.h grammar-parser.h
grammar-parser.cpp grammar-parser.cpp
${COMMON_SOURCES_FFMPEG}
) )
include(DefaultTargetOptions) include(DefaultTargetOptions)

View File

@ -12,6 +12,7 @@ const whisperParamsMock = {
model: path.join(__dirname, "../../../models/ggml-base.en.bin"), model: path.join(__dirname, "../../../models/ggml-base.en.bin"),
fname_inp: path.join(__dirname, "../../../samples/jfk.wav"), fname_inp: path.join(__dirname, "../../../samples/jfk.wav"),
use_gpu: true, use_gpu: true,
flash_attn: false,
no_prints: true, no_prints: true,
comma_in_time: false, comma_in_time: false,
translate: true, translate: true,

View File

@ -25,7 +25,6 @@ struct whisper_params {
float entropy_thold = 2.4f; float entropy_thold = 2.4f;
float logprob_thold = -1.0f; float logprob_thold = -1.0f;
bool speed_up = false;
bool translate = false; bool translate = false;
bool diarize = false; bool diarize = false;
bool output_txt = false; bool output_txt = false;
@ -39,6 +38,7 @@ struct whisper_params {
bool no_timestamps = false; bool no_timestamps = false;
bool no_prints = false; bool no_prints = false;
bool use_gpu = true; bool use_gpu = true;
bool flash_attn = false;
bool comma_in_time = true; bool comma_in_time = true;
std::string language = "en"; std::string language = "en";
@ -146,6 +146,7 @@ int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
struct whisper_context_params cparams = whisper_context_default_params(); struct whisper_context_params cparams = whisper_context_default_params();
cparams.use_gpu = params.use_gpu; cparams.use_gpu = params.use_gpu;
cparams.flash_attn = params.flash_attn;
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams); struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
if (ctx == nullptr) { if (ctx == nullptr) {
@ -230,8 +231,6 @@ int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len; wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
wparams.audio_ctx = params.audio_ctx; wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
wparams.greedy.best_of = params.best_of; wparams.greedy.best_of = params.best_of;
wparams.beam_search.beam_size = params.beam_size; wparams.beam_search.beam_size = params.beam_size;
@ -326,6 +325,7 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
std::string model = whisper_params.Get("model").As<Napi::String>(); std::string model = whisper_params.Get("model").As<Napi::String>();
std::string input = whisper_params.Get("fname_inp").As<Napi::String>(); std::string input = whisper_params.Get("fname_inp").As<Napi::String>();
bool use_gpu = whisper_params.Get("use_gpu").As<Napi::Boolean>(); bool use_gpu = whisper_params.Get("use_gpu").As<Napi::Boolean>();
bool flash_attn = whisper_params.Get("flash_attn").As<Napi::Boolean>();
bool no_prints = whisper_params.Get("no_prints").As<Napi::Boolean>(); bool no_prints = whisper_params.Get("no_prints").As<Napi::Boolean>();
bool no_timestamps = whisper_params.Get("no_timestamps").As<Napi::Boolean>(); bool no_timestamps = whisper_params.Get("no_timestamps").As<Napi::Boolean>();
int32_t audio_ctx = whisper_params.Get("audio_ctx").As<Napi::Number>(); int32_t audio_ctx = whisper_params.Get("audio_ctx").As<Napi::Number>();
@ -346,6 +346,7 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
params.model = model; params.model = model;
params.fname_inp.emplace_back(input); params.fname_inp.emplace_back(input);
params.use_gpu = use_gpu; params.use_gpu = use_gpu;
params.flash_attn = flash_attn;
params.no_prints = no_prints; params.no_prints = no_prints;
params.no_timestamps = no_timestamps; params.no_timestamps = no_timestamps;
params.audio_ctx = audio_ctx; params.audio_ctx = audio_ctx;

View File

@ -12,6 +12,7 @@ const whisperParams = {
model: path.join(__dirname, "../../models/ggml-base.en.bin"), model: path.join(__dirname, "../../models/ggml-base.en.bin"),
fname_inp: path.join(__dirname, "../../samples/jfk.wav"), fname_inp: path.join(__dirname, "../../samples/jfk.wav"),
use_gpu: true, use_gpu: true,
flash_attn: false,
no_prints: true, no_prints: true,
comma_in_time: false, comma_in_time: false,
translate: true, translate: true,

View File

@ -38,7 +38,6 @@ struct whisper_params {
grammar_parser::parse_state grammar_parsed; grammar_parser::parse_state grammar_parsed;
bool speed_up = false;
bool translate = false; bool translate = false;
bool print_special = false; bool print_special = false;
bool print_energy = false; bool print_energy = false;
@ -76,7 +75,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); } else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); } else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); } else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; } else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
@ -115,7 +113,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx); fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold); fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold); fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false"); fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
@ -165,7 +162,6 @@ std::string transcribe(
wparams.n_threads = params.n_threads; wparams.n_threads = params.n_threads;
wparams.audio_ctx = params.audio_ctx; wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
wparams.temperature = 0.4f; wparams.temperature = 0.4f;
wparams.temperature_inc = 1.0f; wparams.temperature_inc = 1.0f;
@ -371,7 +367,6 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const
wparams.n_threads = params.n_threads; wparams.n_threads = params.n_threads;
wparams.audio_ctx = params.audio_ctx; wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
wparams.prompt_tokens = k_tokens.data(); wparams.prompt_tokens = k_tokens.data();
wparams.prompt_n_tokens = k_tokens.size(); wparams.prompt_n_tokens = k_tokens.size();

View File

@ -24,6 +24,11 @@
#include <io.h> #include <io.h>
#endif #endif
#ifdef WHISPER_FFMPEG
// as implemented in ffmpeg_trancode.cpp only embedded in common lib if whisper built with ffmpeg support
extern bool ffmpeg_decode_audio(const std::string & ifname, std::vector<uint8_t> & wav_data);
#endif
// Function to check if the next argument exists // Function to check if the next argument exists
std::string get_next_arg(int& i, int argc, char** argv, const std::string& flag, gpt_params& params) { std::string get_next_arg(int& i, int argc, char** argv, const std::string& flag, gpt_params& params) {
if (i + 1 < argc && argv[i + 1][0] != '-') { if (i + 1 < argc && argv[i + 1][0] != '-') {
@ -637,7 +642,7 @@ bool is_wav_buffer(const std::string buf) {
bool read_wav(const std::string & fname, std::vector<float>& pcmf32, std::vector<std::vector<float>>& pcmf32s, bool stereo) { bool read_wav(const std::string & fname, std::vector<float>& pcmf32, std::vector<std::vector<float>>& pcmf32s, bool stereo) {
drwav wav; drwav wav;
std::vector<uint8_t> wav_data; // used for pipe input from stdin std::vector<uint8_t> wav_data; // used for pipe input from stdin or ffmpeg decoding output
if (fname == "-") { if (fname == "-") {
{ {
@ -670,8 +675,19 @@ bool read_wav(const std::string & fname, std::vector<float>& pcmf32, std::vector
} }
} }
else if (drwav_init_file(&wav, fname.c_str(), nullptr) == false) { else if (drwav_init_file(&wav, fname.c_str(), nullptr) == false) {
#if defined(WHISPER_FFMPEG)
if (ffmpeg_decode_audio(fname, wav_data) != 0) {
fprintf(stderr, "error: failed to ffmpeg decode '%s' \n", fname.c_str());
return false;
}
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
fprintf(stderr, "error: failed to read wav data as wav \n");
return false;
}
#else
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname.c_str()); fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname.c_str());
return false; return false;
#endif
} }
if (wav.channels != 1 && wav.channels != 2) { if (wav.channels != 1 && wav.channels != 2) {

View File

@ -185,7 +185,7 @@ private:
// It is assumed that PCM data is normalized to a range from -1 to 1 // It is assumed that PCM data is normalized to a range from -1 to 1
bool write_audio(const float * data, size_t length) { bool write_audio(const float * data, size_t length) {
for (size_t i = 0; i < length; ++i) { for (size_t i = 0; i < length; ++i) {
const int16_t intSample = data[i] * 32767; const int16_t intSample = int16_t(data[i] * 32767);
file.write(reinterpret_cast<const char *>(&intSample), sizeof(int16_t)); file.write(reinterpret_cast<const char *>(&intSample), sizeof(int16_t));
dataSize += sizeof(int16_t); dataSize += sizeof(int16_t);
} }

View File

@ -0,0 +1,350 @@
/* SPDX-License-Identifier: GPL-2.0 */
/*
* transcode.c - convert audio file to WAVE
*
* Copyright (C) 2019 Andrew Clayton <andrew@digital-domain.net>
* Copyright (C) 2024 William Tambellini <william.tambellini@gmail.com>
*/
// Just for conveninent C++ API
#include <vector>
#include <string>
// C
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <stdbool.h>
#include <stdint.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <fcntl.h>
#include <unistd.h>
#include <sys/mman.h>
extern "C" {
#include <libavutil/opt.h>
#include <libavcodec/avcodec.h>
#include <libavformat/avformat.h>
#include <libswresample/swresample.h>
}
typedef uint64_t u64;
typedef int64_t s64;
typedef uint32_t u32;
typedef int32_t s32;
typedef uint16_t u16;
typedef int16_t s16;
typedef uint8_t u8;
typedef int8_t s8;
#define WAVE_SAMPLE_RATE 16000
#define AVIO_CTX_BUF_SZ 4096
static const char* ffmpegLog = getenv("FFMPEG_LOG");
// Todo: add __FILE__ __LINE__
#define LOG(...) \
do { if (ffmpegLog) fprintf(stderr, __VA_ARGS__); } while(0) // C99
/*
* WAVE file header based on definition from
* https://gist.github.com/Jon-Schneider/8b7c53d27a7a13346a643dac9c19d34f
*
* We must ensure this structure doesn't have any holes or
* padding so we can just map it straight to the WAVE data.
*/
struct wave_hdr {
/* RIFF Header: "RIFF" */
char riff_header[4];
/* size of audio data + sizeof(struct wave_hdr) - 8 */
int wav_size;
/* "WAVE" */
char wav_header[4];
/* Format Header */
/* "fmt " (includes trailing space) */
char fmt_header[4];
/* Should be 16 for PCM */
int fmt_chunk_size;
/* Should be 1 for PCM. 3 for IEEE Float */
s16 audio_format;
s16 num_channels;
int sample_rate;
/*
* Number of bytes per second
* sample_rate * num_channels * bit_depth/8
*/
int byte_rate;
/* num_channels * bytes per sample */
s16 sample_alignment;
/* bits per sample */
s16 bit_depth;
/* Data Header */
/* "data" */
char data_header[4];
/*
* size of audio
* number of samples * num_channels * bit_depth/8
*/
int data_bytes;
} __attribute__((__packed__));
struct audio_buffer {
u8 *ptr;
int size; /* size left in the buffer */
};
static void set_wave_hdr(wave_hdr& wh, size_t size) {
memcpy(&wh.riff_header, "RIFF", 4);
wh.wav_size = size + sizeof(struct wave_hdr) - 8;
memcpy(&wh.wav_header, "WAVE", 4);
memcpy(&wh.fmt_header, "fmt ", 4);
wh.fmt_chunk_size = 16;
wh.audio_format = 1;
wh.num_channels = 1;
wh.sample_rate = WAVE_SAMPLE_RATE;
wh.sample_alignment = 2;
wh.bit_depth = 16;
wh.byte_rate = wh.sample_rate * wh.sample_alignment;
memcpy(&wh.data_header, "data", 4);
wh.data_bytes = size;
}
static void write_wave_hdr(int fd, size_t size) {
struct wave_hdr wh;
set_wave_hdr(wh, size);
write(fd, &wh, sizeof(struct wave_hdr));
}
static int map_file(int fd, u8 **ptr, size_t *size)
{
struct stat sb;
fstat(fd, &sb);
*size = sb.st_size;
*ptr = (u8*)mmap(NULL, *size, PROT_READ|PROT_WRITE, MAP_PRIVATE, fd, 0);
if (*ptr == MAP_FAILED) {
perror("mmap");
return -1;
}
return 0;
}
static int read_packet(void *opaque, u8 *buf, int buf_size)
{
struct audio_buffer *audio_buf = (audio_buffer*)opaque;
buf_size = FFMIN(buf_size, audio_buf->size);
/* copy internal buffer data to buf */
memcpy(buf, audio_buf->ptr, buf_size);
audio_buf->ptr += buf_size;
audio_buf->size -= buf_size;
return buf_size;
}
static void convert_frame(struct SwrContext *swr, AVCodecContext *codec,
AVFrame *frame, s16 **data, int *size, bool flush)
{
int nr_samples;
s64 delay;
u8 *buffer;
delay = swr_get_delay(swr, codec->sample_rate);
nr_samples = av_rescale_rnd(delay + frame->nb_samples,
WAVE_SAMPLE_RATE, codec->sample_rate,
AV_ROUND_UP);
av_samples_alloc(&buffer, NULL, 1, nr_samples, AV_SAMPLE_FMT_S16, 0);
/*
* !flush is used to check if we are flushing any remaining
* conversion buffers...
*/
nr_samples = swr_convert(swr, &buffer, nr_samples,
!flush ? (const u8 **)frame->data : NULL,
!flush ? frame->nb_samples : 0);
*data = (s16*)realloc(*data, (*size + nr_samples) * sizeof(s16));
memcpy(*data + *size, buffer, nr_samples * sizeof(s16));
*size += nr_samples;
av_freep(&buffer);
}
static bool is_audio_stream(const AVStream *stream)
{
if (stream->codecpar->codec_type == AVMEDIA_TYPE_AUDIO)
return true;
return false;
}
// Return non zero on error, 0 on success
// audio_buffer: input memory
// data: decoded output audio data (wav file)
// size: size of output data
static int decode_audio(struct audio_buffer *audio_buf, s16 **data, int *size)
{
LOG("decode_audio: input size: %d\n", audio_buf->size);
AVFormatContext *fmt_ctx;
AVIOContext *avio_ctx;
AVStream *stream;
AVCodecContext *codec;
AVPacket packet;
AVFrame *frame;
struct SwrContext *swr;
u8 *avio_ctx_buffer;
unsigned int i;
int stream_index = -1;
int err;
const size_t errbuffsize = 1024;
char errbuff[errbuffsize];
av_register_all(); // from avformat. Still a must-have call for ffmpeg v3! (can be skipped for later versions)
fmt_ctx = avformat_alloc_context();
avio_ctx_buffer = (u8*)av_malloc(AVIO_CTX_BUF_SZ);
LOG("Creating an avio context: AVIO_CTX_BUF_SZ=%d\n", AVIO_CTX_BUF_SZ);
avio_ctx = avio_alloc_context(avio_ctx_buffer, AVIO_CTX_BUF_SZ, 0, audio_buf, &read_packet, NULL, NULL);
fmt_ctx->pb = avio_ctx;
// open the input stream and read header
err = avformat_open_input(&fmt_ctx, NULL, NULL, NULL);
if (err) {
LOG("Could not read audio buffer: %d: %s\n", err, av_make_error_string(errbuff, errbuffsize, err));
return err;
}
err = avformat_find_stream_info(fmt_ctx, NULL);
if (err < 0) {
LOG("Could not retrieve stream info from audio buffer: %d\n", err);
return err;
}
for (i = 0; i < fmt_ctx->nb_streams; i++) {
if (is_audio_stream(fmt_ctx->streams[i])) {
stream_index = i;
break;
}
}
if (stream_index == -1) {
LOG("Could not retrieve audio stream from buffer\n");
return -1;
}
stream = fmt_ctx->streams[stream_index];
codec = avcodec_alloc_context3(
avcodec_find_decoder(stream->codecpar->codec_id));
avcodec_parameters_to_context(codec, stream->codecpar);
err = avcodec_open2(codec, avcodec_find_decoder(codec->codec_id),
NULL);
if (err) {
LOG("Failed to open decoder for stream #%d in audio buffer\n", stream_index);
return err;
}
/* prepare resampler */
swr = swr_alloc();
av_opt_set_int(swr, "in_channel_count", codec->channels, 0);
av_opt_set_int(swr, "out_channel_count", 1, 0);
av_opt_set_int(swr, "in_channel_layout", codec->channel_layout, 0);
av_opt_set_int(swr, "out_channel_layout", AV_CH_LAYOUT_MONO, 0);
av_opt_set_int(swr, "in_sample_rate", codec->sample_rate, 0);
av_opt_set_int(swr, "out_sample_rate", WAVE_SAMPLE_RATE, 0);
av_opt_set_sample_fmt(swr, "in_sample_fmt", codec->sample_fmt, 0);
av_opt_set_sample_fmt(swr, "out_sample_fmt", AV_SAMPLE_FMT_S16, 0);
swr_init(swr);
if (!swr_is_initialized(swr)) {
LOG("Resampler has not been properly initialized\n");
return -1;
}
av_init_packet(&packet);
frame = av_frame_alloc();
if (!frame) {
LOG("Error allocating the frame\n");
return -1;
}
/* iterate through frames */
*data = NULL;
*size = 0;
while (av_read_frame(fmt_ctx, &packet) >= 0) {
avcodec_send_packet(codec, &packet);
err = avcodec_receive_frame(codec, frame);
if (err == AVERROR(EAGAIN))
continue;
convert_frame(swr, codec, frame, data, size, false);
}
/* Flush any remaining conversion buffers... */
convert_frame(swr, codec, frame, data, size, true);
av_frame_free(&frame);
swr_free(&swr);
//avio_context_free(); // todo?
avcodec_close(codec);
avformat_close_input(&fmt_ctx);
avformat_free_context(fmt_ctx);
if (avio_ctx) {
av_freep(&avio_ctx->buffer);
av_freep(&avio_ctx);
}
return 0;
}
// in mem decoding/conversion/resampling:
// ifname: input file path
// owav_data: in mem wav file. Can be forwarded as it to whisper/drwav
// return 0 on success
int ffmpeg_decode_audio(const std::string &ifname, std::vector<uint8_t>& owav_data) {
LOG("ffmpeg_decode_audio: %s\n", ifname.c_str());
int ifd = open(ifname.c_str(), O_RDONLY);
if (ifd == -1) {
fprintf(stderr, "Couldn't open input file %s\n", ifname.c_str());
return -1;
}
u8 *ibuf = NULL;
size_t ibuf_size;
int err = map_file(ifd, &ibuf, &ibuf_size);
if (err) {
LOG("Couldn't map input file %s\n", ifname.c_str());
return err;
}
LOG("Mapped input file: %x size: %d\n", ibuf, ibuf_size);
struct audio_buffer inaudio_buf;
inaudio_buf.ptr = ibuf;
inaudio_buf.size = ibuf_size;
s16 *odata=NULL;
int osize=0;
err = decode_audio(&inaudio_buf, &odata, &osize);
LOG("decode_audio returned %d \n", err);
if (err != 0) {
LOG("decode_audio failed\n");
return err;
}
LOG("decode_audio output size: %d\n", osize);
wave_hdr wh;
const size_t outdatasize = osize * sizeof(s16);
set_wave_hdr(wh, outdatasize);
owav_data.resize(sizeof(wave_hdr) + outdatasize);
// header:
memcpy(owav_data.data(), &wh, sizeof(wave_hdr));
// the data:
memcpy(owav_data.data() + sizeof(wave_hdr), odata, osize* sizeof(s16));
return 0;
}

View File

@ -26,7 +26,6 @@ struct whisper_params {
float vad_thold = 0.6f; float vad_thold = 0.6f;
float freq_thold = 100.0f; float freq_thold = 100.0f;
bool speed_up = false;
bool translate = false; bool translate = false;
bool print_special = false; bool print_special = false;
bool print_energy = false; bool print_energy = false;
@ -70,7 +69,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); } else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); } else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); } else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; } else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
@ -102,7 +100,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx); fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold); fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold); fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false"); fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
@ -184,7 +181,6 @@ json unguided_transcription(struct whisper_context * ctx, audio_async &audio, js
wparams.n_threads = params.n_threads; wparams.n_threads = params.n_threads;
wparams.audio_ctx = params.audio_ctx; wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
wparams.suppress_non_speech_tokens = true; wparams.suppress_non_speech_tokens = true;
// run the transformer and a single decoding pass // run the transformer and a single decoding pass
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
@ -223,7 +219,6 @@ json guided_transcription(struct whisper_context * ctx, audio_async &audio, cons
wparams.n_threads = params.n_threads; wparams.n_threads = params.n_threads;
wparams.audio_ctx = params.audio_ctx; wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
// TODO: Do some time testing. Does an overly long prompt slow down processing? // TODO: Do some time testing. Does an overly long prompt slow down processing?
// Set up command sets/precompute prompts // Set up command sets/precompute prompts

View File

@ -3,4 +3,4 @@ add_executable(${TARGET} main.cpp)
include(DefaultTargetOptions) include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE common whisper ${CMAKE_THREAD_LIBS_INIT}) target_link_libraries(${TARGET} PRIVATE common whisper ${FFMPEG_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})

View File

@ -47,7 +47,6 @@ struct whisper_params {
float temperature = 0.0f; float temperature = 0.0f;
float temperature_inc = 0.2f; float temperature_inc = 0.2f;
bool speed_up = false;
bool debug_mode = false; bool debug_mode = false;
bool translate = false; bool translate = false;
bool detect_language = false; bool detect_language = false;
@ -138,7 +137,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); } else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
else if (arg == "-tp" || arg == "--temperature") { params.temperature = std::stof(argv[++i]); } else if (arg == "-tp" || arg == "--temperature") { params.temperature = std::stof(argv[++i]); }
else if (arg == "-tpi" || arg == "--temperature-inc") { params.temperature_inc = std::stof(argv[++i]); } else if (arg == "-tpi" || arg == "--temperature-inc") { params.temperature_inc = std::stof(argv[++i]); }
// else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; } else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
@ -206,7 +204,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold); fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
fprintf(stderr, " -tp, --temperature N [%-7.2f] The sampling temperature, between 0 and 1\n", params.temperature); fprintf(stderr, " -tp, --temperature N [%-7.2f] The sampling temperature, between 0 and 1\n", params.temperature);
fprintf(stderr, " -tpi, --temperature-inc N [%-7.2f] The increment of temperature, between 0 and 1\n",params.temperature_inc); fprintf(stderr, " -tpi, --temperature-inc N [%-7.2f] The increment of temperature, between 0 and 1\n",params.temperature_inc);
// fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false"); fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
@ -1106,7 +1103,6 @@ int main(int argc, char ** argv) {
wparams.split_on_word = params.split_on_word; wparams.split_on_word = params.split_on_word;
wparams.audio_ctx = params.audio_ctx; wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
wparams.debug_mode = params.debug_mode; wparams.debug_mode = params.debug_mode;
wparams.tdrz_enable = params.tinydiarize; // [TDRZ] wparams.tdrz_enable = params.tinydiarize; // [TDRZ]

View File

@ -61,7 +61,6 @@ struct whisper_params {
float temperature = 0.00f; float temperature = 0.00f;
float temperature_inc = 0.20f; float temperature_inc = 0.20f;
bool speed_up = false;
bool debug_mode = false; bool debug_mode = false;
bool translate = false; bool translate = false;
bool detect_language = false; bool detect_language = false;
@ -112,7 +111,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold); fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold); fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold); fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
// fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false"); fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
@ -159,7 +157,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); } else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); } else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); }
else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); } else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
// else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; } else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
@ -768,7 +765,6 @@ int main(int argc, char ** argv) {
wparams.split_on_word = params.split_on_word; wparams.split_on_word = params.split_on_word;
wparams.audio_ctx = params.audio_ctx; wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
wparams.debug_mode = params.debug_mode; wparams.debug_mode = params.debug_mode;
wparams.tdrz_enable = params.tinydiarize; // [TDRZ] wparams.tdrz_enable = params.tinydiarize; // [TDRZ]
@ -947,7 +943,7 @@ int main(int argc, char ** argv) {
"application/json"); "application/json");
} }
// reset params to thier defaults // reset params to their defaults
params = default_params; params = default_params;
}); });
svr.Post(sparams.request_path + "/load", [&](const Request &req, Response &res){ svr.Post(sparams.request_path + "/load", [&](const Request &req, Response &res){

View File

@ -27,7 +27,6 @@ struct whisper_params {
float vad_thold = 0.6f; float vad_thold = 0.6f;
float freq_thold = 100.0f; float freq_thold = 100.0f;
bool speed_up = false;
bool translate = false; bool translate = false;
bool no_fallback = false; bool no_fallback = false;
bool print_special = false; bool print_special = false;
@ -62,7 +61,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); } else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); } else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); } else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; } else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
@ -100,7 +98,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx); fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold); fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold); fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false"); fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
@ -314,7 +311,6 @@ int main(int argc, char ** argv) {
wparams.n_threads = params.n_threads; wparams.n_threads = params.n_threads;
wparams.audio_ctx = params.audio_ctx; wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
wparams.tdrz_enable = params.tinydiarize; // [TDRZ] wparams.tdrz_enable = params.tinydiarize; // [TDRZ]

View File

@ -59,7 +59,6 @@ struct whisper_params {
float vad_thold = 0.6f; float vad_thold = 0.6f;
float freq_thold = 100.0f; float freq_thold = 100.0f;
bool speed_up = false;
bool translate = false; bool translate = false;
bool print_special = false; bool print_special = false;
bool print_energy = false; bool print_energy = false;
@ -100,7 +99,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-ngl" || arg == "--n-gpu-layers") { params.n_gpu_layers = std::stoi(argv[++i]); } else if (arg == "-ngl" || arg == "--n-gpu-layers") { params.n_gpu_layers = std::stoi(argv[++i]); }
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); } else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); } else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; } else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
@ -149,7 +147,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -ngl N, --n-gpu-layers N [%-7d] number of layers to store in VRAM\n", params.n_gpu_layers); fprintf(stderr, " -ngl N, --n-gpu-layers N [%-7d] number of layers to store in VRAM\n", params.n_gpu_layers);
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold); fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold); fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false"); fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
@ -205,7 +202,6 @@ std::string transcribe(
wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size(); wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
wparams.audio_ctx = params.audio_ctx; wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
return ""; return "";

View File

@ -26,7 +26,6 @@ struct whisper_params {
float vad_thold = 0.6f; float vad_thold = 0.6f;
float freq_thold = 100.0f; float freq_thold = 100.0f;
bool speed_up = false;
bool translate = false; bool translate = false;
bool print_special = false; bool print_special = false;
bool print_energy = false; bool print_energy = false;
@ -60,7 +59,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); } else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); } else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); } else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; } else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
@ -96,7 +94,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx); fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold); fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold); fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false"); fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
@ -132,7 +129,6 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con
wparams.n_threads = params.n_threads; wparams.n_threads = params.n_threads;
wparams.audio_ctx = params.audio_ctx; wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
return ""; return "";

View File

@ -26,7 +26,6 @@ struct whisper_params {
float grammar_penalty = 100.0f; float grammar_penalty = 100.0f;
bool speed_up = false;
bool translate = false; bool translate = false;
bool print_special = false; bool print_special = false;
bool print_energy = false; bool print_energy = false;
@ -57,7 +56,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx); fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold); fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold); fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false"); fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
@ -89,7 +87,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); } else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); } else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); } else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; } else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }

View File

@ -75,7 +75,7 @@ static __global__ void mul_mat_vec_q(
tmp[j][i] = warp_reduce_sum(tmp[j][i]); tmp[j][i] = warp_reduce_sum(tmp[j][i]);
} }
if (threadIdx.x < rows_per_cuda_block) { if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x]; dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x];
} }
} }

3
samples/.gitignore vendored
View File

@ -1 +1,4 @@
* *
!jfk.wave
!jfk.mp3

BIN
samples/jfk.mp3 Normal file

Binary file not shown.

View File

@ -74,3 +74,14 @@ add_test(NAME ${TEST_TARGET}
-m ${PROJECT_SOURCE_DIR}/models/for-tests-ggml-large.bin -m ${PROJECT_SOURCE_DIR}/models/for-tests-ggml-large.bin
-f ${PROJECT_SOURCE_DIR}/samples/jfk.wav) -f ${PROJECT_SOURCE_DIR}/samples/jfk.wav)
set_tests_properties(${TEST_TARGET} PROPERTIES LABELS "large") set_tests_properties(${TEST_TARGET} PROPERTIES LABELS "large")
if (WHISPER_FFMPEG)
set(TEST_TARGET test-main-tiny-mp3)
# Check with reviewers: any way to check the output transcription via ctest (diff, ...)?
add_test(NAME ${TEST_TARGET}
COMMAND $<TARGET_FILE:main>
-m ${PROJECT_SOURCE_DIR}/models/for-tests-ggml-tiny.en.bin
-f ${PROJECT_SOURCE_DIR}/samples/jfk.mp3)
set_tests_properties(${TEST_TARGET} PROPERTIES LABELS "tiny;mp3")
endif()

363
whisper-mel-cuda.cu Normal file
View File

@ -0,0 +1,363 @@
#define CUB_IGNORE_DEPRECATED_CPP_DIALECT
#include "whisper-mel-cuda.hpp"
#include "whisper.h"
#include <cuda.h>
#include <cuda_runtime.h>
#include <cufft.h>
#include <cublas_v2.h>
#include <cuComplex.h>
#include <cub/device/device_reduce.cuh>
#include <device_launch_parameters.h>
#include <algorithm>
#if defined(_MSC_VER)
#pragma warning(disable: 4324) // added padding
#endif
#ifndef NDEBUG
# define DO_CHECKS 1
#else
# define DO_CHECKS 0
#endif
namespace {
#if DO_CHECKS
const char* cufftGetErrorString(cufftResult_t res) {
switch (res) {
case CUFFT_SUCCESS: return "The cuFFT operation was successful";
case CUFFT_INVALID_PLAN: return "cuFFT was passed an invalid plan handle";
case CUFFT_ALLOC_FAILED: return "cuFFT failed to allocate GPU or CPU memory";
case CUFFT_INVALID_TYPE: return "No longer used";
case CUFFT_INVALID_VALUE: return "User specified an invalid pointer or parameter";
case CUFFT_INTERNAL_ERROR: return "Driver or internal cuFFT library error";
case CUFFT_EXEC_FAILED: return "Failed to execute an FFT on the GPU";
case CUFFT_SETUP_FAILED: return "The cuFFT library failed to initialize";
case CUFFT_INVALID_SIZE: return "User specified an invalid transform size";
case CUFFT_UNALIGNED_DATA: return "No longer used";
case CUFFT_INCOMPLETE_PARAMETER_LIST: return "Missing parameters in call";
case CUFFT_INVALID_DEVICE: return "Execution of a plan was on different GPU than plan creation";
case CUFFT_PARSE_ERROR: return "Internal plan database error";
case CUFFT_NO_WORKSPACE: return "No workspace has been provided prior to plan execution";
case CUFFT_NOT_IMPLEMENTED: return "Function does not implement functionality for parameters given.";
case CUFFT_LICENSE_ERROR: return "Used in previous versions.";
case CUFFT_NOT_SUPPORTED: return "Operation is not supported for parameters given.";
default: return "Unknown error";
}
}
# define CUDA_CHECK_GEN(err, success, error_fn) \
do { \
auto err_ = (err); \
if (err_ != (success)) { \
fprintf(stderr, "%s %s:%d - %s\n", #err, __FILE__, __LINE__, error_fn(err_)); \
} \
} while (0)
#else
# define CUDA_CHECK_GEN(err, success, error_fn) err
#endif
#define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)
#define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublasGetStatusString)
#define CUFFT_CHECK(err) CUDA_CHECK_GEN(err, CUFFT_SUCCESS, cufftGetErrorString)
__global__ void k_fill_stft_input(
const float * padded_samples,
const int n_frames,
const float * hann_window,
float * stft_in
) {
auto y = blockIdx.y * blockDim.y + threadIdx.y;
// if (y >= n_frames) return;
auto x = blockIdx.x * blockDim.x + threadIdx.x;
// if (x >= WHISPER_N_FFT) return;
auto line = padded_samples + y * WHISPER_HOP_LENGTH;
auto outLine = stft_in + y * WHISPER_N_FFT;
outLine[x] = line[x] * hann_window[x];
}
__global__ void k_calc_magnitudes(
const cuComplex* stft_out,
const int n_frames,
float * magnitudes
) {
auto y = blockIdx.y * blockDim.y + threadIdx.y;
// if (y >= n_frames) return;
auto x = blockIdx.x * blockDim.x + threadIdx.x;
// if (x >= WHISPER_N_FFT_HALF) return;
auto idx = y * WHISPER_N_FFT_HALF + x;
auto r = stft_out[idx].x;
auto i = stft_out[idx].y;
magnitudes[idx] = r * r + i * i;
}
__global__ void k_calc_log_mel(
const float * mel_data,
const int n_mel,
const float * max_val,
float * log_mel
) {
auto x = blockIdx.x * blockDim.x + threadIdx.x;
if (x >= n_mel) return;
float val = mel_data[x];
constexpr float e = 1e-10f;
if (val < e) val = e;
val = log10(val);
const float max = log10(*max_val) - 8.f;
if (val < max) val = max;
log_mel[x] = (val + 4) / 4;
}
void fill_stft_input(
const float * padded_samples,
int n_frames,
const float * hann_window,
float * stft_in,
cudaStream_t stream
) {
dim3 block(WHISPER_N_FFT, 1);
dim3 grid(1, n_frames);
k_fill_stft_input<<<grid, block, 0, stream>>>(padded_samples, n_frames, hann_window, stft_in);
}
void calc_magnitudes(
const cuComplex* stft_out,
int n_frames,
float * magnitudes,
cudaStream_t stream
) {
dim3 block(WHISPER_N_FFT_HALF, 1);
dim3 grid(1, n_frames);
k_calc_magnitudes<<<grid, block, 0, stream>>>(stft_out, n_frames, magnitudes);
}
constexpr auto LOG_MEL_PREFIX_SIZE = 256;
void calc_log_mel(
const float * mel_data,
int n_mel,
void * tempStorage,
int tempStorageSize,
float * log_mel,
cudaStream_t stream
) {
float * max_val = reinterpret_cast<float *>(tempStorage);
void * maxTemp = reinterpret_cast<char*>(tempStorage) + LOG_MEL_PREFIX_SIZE;
size_t nbytes = size_t(tempStorageSize - LOG_MEL_PREFIX_SIZE);
cub::DeviceReduce::Max(maxTemp, nbytes, mel_data, max_val, n_mel, stream);
int block = 256;
int grid = (n_mel + block - 1) / block;
k_calc_log_mel<<<grid, block, 0, stream>>>(mel_data, n_mel, max_val, log_mel);
}
class mel_calc_cuda : public whisper_mel_calc {
const int m_n_mel;
ggml_backend_t m_backend = nullptr;
cudaStream_t m_stream = nullptr;
cublasHandle_t m_cublas_handle = nullptr;
float * m_hann_window = nullptr;
float * m_filters = nullptr;
// max samples for which we have allocated memory for the temp working areas below (cufft, log_mel)
int m_n_max_samples = 0;
size_t m_cufft_workspace_size = 0;
void * m_cufft_workspace = nullptr;
size_t m_log_mel_temp_storage_size = 0;
void * m_log_mel_temp_storage = nullptr;
public:
mel_calc_cuda(ggml_backend_t backend, const whisper_filters & filters)
: m_n_mel(filters.n_mel)
, m_backend(backend)
{
if (filters.n_fft != WHISPER_N_FFT_HALF) {
throw std::invalid_argument("MelFilters n_frames must be WHISPER_N_FFT_HALF");
}
assert(filters.data.size() == filters.n_mel * WHISPER_N_FFT_HALF);
CUDA_CHECK(cudaStreamCreate(&m_stream));
CUBLAS_CHECK(cublasCreate(&m_cublas_handle));
CUBLAS_CHECK(cublasSetMathMode(m_cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH));
CUBLAS_CHECK(cublasSetStream(m_cublas_handle, m_stream));
// create Hann window
{
auto hw = whisper_mel_calc::hann_window();
CUDA_CHECK(cudaMallocAsync(&m_hann_window, hw.len * sizeof(float), m_stream));
CUDA_CHECK(cudaMemcpyAsync(m_hann_window, hw.data, hw.len * sizeof(float), cudaMemcpyHostToDevice, m_stream));
}
// fill filters
{
auto& f = filters.data;
CUDA_CHECK(cudaMallocAsync(&m_filters, f.size() * sizeof(float), m_stream));
CUDA_CHECK(cudaMemcpyAsync(m_filters, f.data(), f.size() * sizeof(float), cudaMemcpyHostToDevice, m_stream));
}
// preallocate working areas enough for the most common cases (<= 30s)
ensure_working_areas(WHISPER_N_SAMPLES);
}
~mel_calc_cuda() {
CUDA_CHECK(cudaStreamSynchronize(m_stream));
CUDA_CHECK(cudaStreamDestroy(m_stream));
CUDA_CHECK(cudaFree(m_hann_window));
CUDA_CHECK(cudaFree(m_cufft_workspace));
CUDA_CHECK(cudaFree(m_filters));
CUDA_CHECK(cudaFree(m_log_mel_temp_storage));
}
void ensure_working_areas(int n_samples) {
if (n_samples <= m_n_max_samples) {
return;
}
const auto max_padded_samples = n_samples + WHISPER_N_SAMPLES + WHISPER_N_FFT;
const auto max_frames = 1 + (max_padded_samples - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
// cufft workspace
{
if (m_cufft_workspace) {
CUDA_CHECK(cudaFree(m_cufft_workspace));
m_cufft_workspace_size = 0;
m_cufft_workspace = nullptr;
}
CUFFT_CHECK(cufftEstimate1d(WHISPER_N_FFT, CUFFT_R2C, max_frames, &m_cufft_workspace_size));
CUDA_CHECK(cudaMallocAsync(&m_cufft_workspace, m_cufft_workspace_size, m_stream));
}
// device reduce working area
{
if (m_log_mel_temp_storage) {
CUDA_CHECK(cudaFree(m_log_mel_temp_storage));
m_log_mel_temp_storage_size = 0;
m_log_mel_temp_storage = nullptr;
}
const auto max_mels = 160;
size_t nbytes = 0;
float* temp = nullptr;
cub::DeviceReduce::Max(nullptr, nbytes, temp, temp, max_frames * max_mels);
m_log_mel_temp_storage_size = nbytes + LOG_MEL_PREFIX_SIZE;
CUDA_CHECK(cudaMallocAsync(&m_log_mel_temp_storage, m_log_mel_temp_storage_size, m_stream));
}
m_n_max_samples = n_samples;
}
virtual whisper_mel calculate(whisper_span<const float> samples, int /*n_threads*/) override {
ensure_working_areas(samples.len);
const size_t mirror_pad = WHISPER_N_FFT / 2;
const size_t padded_size = samples.len + WHISPER_N_SAMPLES + WHISPER_N_FFT;
// pad
std::vector<float> padded_samples(padded_size);
std::reverse_copy(samples.data + 1, samples.data + 1 + mirror_pad, padded_samples.begin()); // reflect
std::copy(samples.data, samples.data + samples.len, padded_samples.begin() + mirror_pad); // copy
// fill the rest of the data
// it should canonically be mirrored at the end as well,
// but we just assume the last MEL_FRAME_SIZE/2 samples are zeros
std::fill(padded_samples.begin() + mirror_pad + samples.len, padded_samples.end(), 0.f);
const auto n_frames = 1 + (padded_samples.size() - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
float * cu_padded_samples = nullptr;
CUDA_CHECK(cudaMallocAsync(&cu_padded_samples, padded_samples.size() * sizeof(float), m_stream));
CUDA_CHECK(cudaMemcpyAsync(cu_padded_samples, padded_samples.data(), padded_samples.size() * sizeof(float), cudaMemcpyHostToDevice, m_stream));
float * stft_in = nullptr; // contiguous buffer for stft input
CUDA_CHECK(cudaMallocAsync(&stft_in, n_frames * WHISPER_N_FFT * sizeof(float), m_stream));
fill_stft_input(cu_padded_samples, int(n_frames), m_hann_window, stft_in, m_stream);
cufftComplex* stft_out;
CUDA_CHECK(cudaMallocAsync(&stft_out, n_frames * WHISPER_N_FFT_HALF * sizeof(cufftComplex), m_stream));
cufftHandle plan;
CUFFT_CHECK(cufftCreate(&plan));
CUFFT_CHECK(cufftSetAutoAllocation(plan, 0));
{
size_t waSize;
CUFFT_CHECK(cufftMakePlan1d(plan, WHISPER_N_FFT, CUFFT_R2C, int(n_frames), &waSize));
assert(waSize <= m_cufft_workspace_size);
CUFFT_CHECK(cufftSetWorkArea(plan, m_cufft_workspace));
CUFFT_CHECK(cufftSetStream(plan, m_stream));
}
CUFFT_CHECK(cufftExecR2C(plan, stft_in, stft_out));
const auto n_mag_frames = n_frames - 1; // drop last frame
float * magnitudes;
CUDA_CHECK(cudaMallocAsync(&magnitudes, n_mag_frames * WHISPER_N_FFT_HALF * sizeof(float), m_stream));
calc_magnitudes(stft_out, int(n_mag_frames), magnitudes, m_stream);
float * mel_data = nullptr;
CUDA_CHECK(cudaMallocAsync(&mel_data, m_n_mel * n_mag_frames * sizeof(float), m_stream));
const float fone = 1.0f, fzero = 0.0f;
CUBLAS_CHECK(cublasSgemm(m_cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N,
int(n_mag_frames), m_n_mel, WHISPER_N_FFT_HALF,
&fone,
magnitudes, WHISPER_N_FFT_HALF,
m_filters, WHISPER_N_FFT_HALF,
&fzero,
mel_data, int(n_mag_frames)));
whisper_mel ret;
// Calculate semi-padded sample length to ensure compatibility
int n_len_org = 1 + int(samples.len + mirror_pad - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
whisper_mel_init(ret, m_backend, int(n_mag_frames), n_len_org, m_n_mel);
assert(ggml_nbytes(ret.tensor) == m_n_mel * n_mag_frames * sizeof(float));
float* log_mels = reinterpret_cast<float*>(ret.tensor->data);
calc_log_mel(
mel_data, int(m_n_mel * n_mag_frames),
m_log_mel_temp_storage , int(m_log_mel_temp_storage_size),
log_mels, m_stream);
CUDA_CHECK(cudaStreamSynchronize(m_stream));
// cleanup
CUFFT_CHECK(cufftDestroy(plan));
CUDA_CHECK(cudaFreeAsync(mel_data, m_stream));
CUDA_CHECK(cudaFreeAsync(magnitudes, m_stream));
CUDA_CHECK(cudaFreeAsync(stft_out, m_stream));
CUDA_CHECK(cudaFreeAsync(stft_in, m_stream));
CUDA_CHECK(cudaFreeAsync(cu_padded_samples, m_stream));
return ret;
}
};
}
whisper_mel_calc * whisper_mel_calc_create_cuda(ggml_backend_t backend, const whisper_filters & filters) {
if (filters.n_fft != WHISPER_N_FFT_HALF) {
return nullptr;
}
return new mel_calc_cuda(backend, filters);
}

3
whisper-mel-cuda.hpp Normal file
View File

@ -0,0 +1,3 @@
#include "whisper-mel.hpp"
whisper_mel_calc * whisper_mel_calc_create_cuda(ggml_backend_t backend, const whisper_filters & filters);

34
whisper-mel.hpp Normal file
View File

@ -0,0 +1,34 @@
#pragma once
#include "ggml-backend.h"
#include <vector>
struct whisper_mel {
int n_len_org = 0;
ggml_context * ctx = nullptr;
ggml_tensor * tensor = nullptr;
ggml_backend_buffer_t buffer = nullptr;
};
void whisper_mel_init(whisper_mel & mel, ggml_backend_t backend, int n_len, int n_len_org, int n_mel);
void whisper_mel_free(whisper_mel & mel);
struct whisper_filters {
int32_t n_mel;
int32_t n_fft;
std::vector<float> data;
};
template <typename T>
struct whisper_span {
T * data;
int len;
};
struct whisper_mel_calc {
virtual ~whisper_mel_calc();
virtual whisper_mel calculate(whisper_span<const float> samples, int n_threads) = 0;
static whisper_span<const float> hann_window();
};

View File

@ -10,6 +10,7 @@
#ifdef GGML_USE_CUDA #ifdef GGML_USE_CUDA
#include "ggml-cuda.h" #include "ggml-cuda.h"
#include "whisper-mel-cuda.hpp"
#endif #endif
#ifdef GGML_USE_SYCL #ifdef GGML_USE_SYCL
@ -24,6 +25,8 @@
#include "ggml-alloc.h" #include "ggml-alloc.h"
#include "ggml-backend.h" #include "ggml-backend.h"
#include "whisper-mel.hpp"
#include <atomic> #include <atomic>
#include <algorithm> #include <algorithm>
#include <cassert> #include <cassert>
@ -380,21 +383,6 @@ static const std::map<whisper_alignment_heads_preset, whisper_aheads> g_aheads {
static std::vector<uint32_t> get_alignment_heads_by_layer(const whisper_context_params & cparams, int il, int32_t n_text_layer, int32_t n_head); static std::vector<uint32_t> get_alignment_heads_by_layer(const whisper_context_params & cparams, int il, int32_t n_text_layer, int32_t n_head);
struct whisper_mel {
int n_len;
int n_len_org;
int n_mel;
std::vector<float> data;
};
struct whisper_filters {
int32_t n_mel;
int32_t n_fft;
std::vector<float> data;
};
struct whisper_vocab { struct whisper_vocab {
using id = int32_t; using id = int32_t;
using token = std::string; using token = std::string;
@ -813,11 +801,15 @@ struct whisper_state {
whisper_kv_cache kv_pad; whisper_kv_cache kv_pad;
whisper_mel mel; whisper_mel mel;
whisper_mel_calc * mel_calc = nullptr;
whisper_mel_calc * mel_calc_fallback = nullptr;
whisper_batch batch; whisper_batch batch;
whisper_decoder decoders[WHISPER_MAX_DECODERS]; whisper_decoder decoders[WHISPER_MAX_DECODERS];
ggml_backend_t backend = nullptr;
// ggml-alloc: // ggml-alloc:
// - stores meta info about the intermediate tensors into the `meta` buffers // - stores meta info about the intermediate tensors into the `meta` buffers
// - stores the actual tensor data into the `data` buffers // - stores the actual tensor data into the `data` buffers
@ -831,7 +823,6 @@ struct whisper_state {
struct ggml_tensor * embd_enc = nullptr; struct ggml_tensor * embd_enc = nullptr;
// helpers for GPU offloading // helpers for GPU offloading
std::vector<float> inp_mel;
std::vector<float> inp_mask; std::vector<float> inp_mask;
// decode output (2-dimensional array: [n_tokens][n_vocab]) // decode output (2-dimensional array: [n_tokens][n_vocab])
@ -902,7 +893,7 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
BYTESWAP_VALUE(dest); BYTESWAP_VALUE(dest);
} }
static bool kv_cache_init( static bool whisper_kv_cache_init(
struct whisper_kv_cache & cache, struct whisper_kv_cache & cache,
ggml_backend_t backend, ggml_backend_t backend,
ggml_type wtype, ggml_type wtype,
@ -945,7 +936,7 @@ static bool kv_cache_init(
return true; return true;
} }
static void kv_cache_free(struct whisper_kv_cache & cache) { static void whisper_kv_cache_free(struct whisper_kv_cache & cache) {
ggml_free(cache.ctx); ggml_free(cache.ctx);
ggml_backend_buffer_free(cache.buffer); ggml_backend_buffer_free(cache.buffer);
cache.ctx = nullptr; cache.ctx = nullptr;
@ -1259,9 +1250,12 @@ static ggml_backend_t whisper_backend_init(const whisper_context_params & params
} }
#endif #endif
GGML_UNUSED(params);
if (backend_gpu) { if (backend_gpu) {
return backend_gpu; return backend_gpu;
} }
return ggml_backend_cpu_init(); return ggml_backend_cpu_init();
} }
@ -1823,7 +1817,8 @@ static bool whisper_encode_external(const whisper_state & wstate) {
static struct ggml_cgraph * whisper_build_graph_conv( static struct ggml_cgraph * whisper_build_graph_conv(
whisper_context & wctx, whisper_context & wctx,
whisper_state & wstate) { whisper_state & wstate,
const int mel_offset) {
const auto & model = wctx.model; const auto & model = wctx.model;
const auto & hparams = model.hparams; const auto & hparams = model.hparams;
@ -1842,9 +1837,32 @@ static struct ggml_cgraph * whisper_build_graph_conv(
ggml_cgraph * gf = ggml_new_graph(ctx0); ggml_cgraph * gf = ggml_new_graph(ctx0);
struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels); ggml_tensor * mel_inp = wstate.mel.tensor;
ggml_set_name(mel, "mel"); ggml_tensor * mel;
ggml_set_input(mel); if (mel_inp) {
const int n_len = int(mel_inp->ne[0]);
const int out_s = 2 * n_ctx;
const int i0 = std::min(mel_offset, n_len);
const int i1 = std::min(mel_offset + out_s, n_len);
const int mel_s = i1 - i0;
assert(mel_inp->type == GGML_TYPE_F32);
assert(mel_inp->ne[1] == n_mels);
ggml_tensor * cur = ggml_view_2d(ctx0, mel_inp, out_s, n_mels, mel_inp->nb[1], ggml_row_size(mel_inp->type, i0));
if (mel_s < out_s) {
mel = ggml_pad(ctx0, cur, out_s - mel_s, 0, 0, 0);
}
else {
mel = ggml_cont(ctx0, cur);
}
}
else {
// just create some tensor so that the graph/buffer size estimation is correct
mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2 * n_ctx, n_mels);
}
ggml_set_name(mel, "mel"); // used with external encoding
struct ggml_tensor * cur = nullptr; struct ggml_tensor * cur = nullptr;
@ -2226,45 +2244,21 @@ static bool whisper_encode_internal(
{ {
auto & alloc = wstate.alloc_conv.alloc; auto & alloc = wstate.alloc_conv.alloc;
ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate); ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate, mel_offset);
if (!ggml_gallocr_alloc_graph(alloc, gf)) { if (!ggml_gallocr_alloc_graph(alloc, gf)) {
// should never happen as we pre-allocate the memory // should never happen as we pre-allocate the memory
return false; return false;
} }
struct ggml_tensor * mel = ggml_graph_get_tensor(gf, "mel"); if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
return false;
// set the input
{
const auto & mel_inp = wstate.mel;
const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : wctx.model.hparams.n_audio_ctx;
assert(mel->type == GGML_TYPE_F32);
assert(mel_inp.n_mel == wctx.model.hparams.n_mels);
wstate.inp_mel.resize(ggml_nelements(mel));
float * dst = wstate.inp_mel.data();
memset(dst, 0, ggml_nbytes(mel));
const int i0 = std::min(mel_offset, mel_inp.n_len);
const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
for (int j = 0; j < mel_inp.n_mel; ++j) {
for (int i = i0; i < i1; ++i) {
dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
}
}
ggml_backend_tensor_set(mel, wstate.inp_mel.data(), 0, ggml_nelements(mel)*sizeof(float));
} }
if (!whisper_encode_external(wstate)) { if (whisper_encode_external(wstate)) {
if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) { ggml_tensor * mel = ggml_graph_get_tensor(gf, "mel");
return false; assert(mel->ne[1] == wctx.model.hparams.n_mels);
} GGML_UNUSED(mel);
} else {
#if defined(WHISPER_USE_COREML) #if defined(WHISPER_USE_COREML)
whisper_coreml_encode(wstate.ctx_coreml, mel->ne[0], mel->ne[1], (float *) mel->data, (float *) wstate.embd_enc->data); whisper_coreml_encode(wstate.ctx_coreml, mel->ne[0], mel->ne[1], (float *) mel->data, (float *) wstate.embd_enc->data);
#elif defined(WHISPER_USE_OPENVINO) #elif defined(WHISPER_USE_OPENVINO)
@ -2284,7 +2278,7 @@ static bool whisper_encode_internal(
return false; return false;
} }
if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) { if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
return false; return false;
} }
} }
@ -2300,7 +2294,7 @@ static bool whisper_encode_internal(
return false; return false;
} }
if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) { if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
return false; return false;
} }
} }
@ -2801,7 +2795,7 @@ static bool whisper_decode_internal(
logits = gf->nodes[gf->n_nodes - 1]; logits = gf->nodes[gf->n_nodes - 1];
if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) { if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
return false; return false;
} }
} }
@ -2855,20 +2849,70 @@ static std::string to_timestamp(int64_t t, bool comma = false) {
} }
#define SIN_COS_N_COUNT WHISPER_N_FFT #define SIN_COS_N_COUNT WHISPER_N_FFT
static float sin_vals[SIN_COS_N_COUNT]; namespace {
static float cos_vals[SIN_COS_N_COUNT]; struct whisper_global_cache {
// In FFT, we frequently use sine and cosine operations with the same values.
// We can use precalculated values to speed up the process.
float sin_vals[SIN_COS_N_COUNT];
float cos_vals[SIN_COS_N_COUNT];
// In FFT, we frequently use sine and cosine operations with the same values. // Hann window (Use cosf to eliminate difference)
// We can use precalculated values to speed up the process. // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
static void fill_sin_cos_table() { // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
static bool is_filled = false; float hann_window[WHISPER_N_FFT];
if (is_filled) return;
for (int i = 0; i < SIN_COS_N_COUNT; i++) { whisper_global_cache() {
double theta = (2*M_PI*i)/SIN_COS_N_COUNT; fill_sin_cos_table();
sin_vals[i] = sinf(theta); fill_hann_window(sizeof(hann_window)/sizeof(hann_window[0]), true, hann_window);
cos_vals[i] = cosf(theta);
} }
is_filled = true;
void fill_sin_cos_table() {
for (int i = 0; i < SIN_COS_N_COUNT; i++) {
double theta = (2 * M_PI * i) / SIN_COS_N_COUNT;
sin_vals[i] = sinf(theta);
cos_vals[i] = cosf(theta);
}
}
void fill_hann_window(int length, bool periodic, float * output) {
int offset = -1;
if (periodic) {
offset = 0;
}
for (int i = 0; i < length; i++) {
output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
}
}
} global_cache;
}
// Mel spectrogram
void whisper_mel_init(whisper_mel & mel, ggml_backend_t backend, int n_len, int n_len_org, int n_mel) {
WHISPER_LOG_INFO("%s: n_len = %d, n_len_org = %d, n_mel = %d\n", __func__, n_len, n_len_org, n_mel);
mel.n_len_org = n_len_org;
assert(!mel.ctx);
mel.ctx = ggml_init({ggml_tensor_overhead(), nullptr, true});
mel.tensor = ggml_new_tensor_2d(mel.ctx, GGML_TYPE_F32, n_len, n_mel);
mel.buffer = ggml_backend_alloc_buffer(backend, ggml_nbytes(mel.tensor) + ggml_backend_get_alignment(backend));
auto alloc = ggml_tallocr_new(mel.buffer);
ggml_tallocr_alloc(&alloc, mel.tensor);
}
void whisper_mel_free(whisper_mel & mel) {
ggml_free(mel.ctx);
ggml_backend_buffer_free(mel.buffer);
mel.n_len_org = 0;
mel.ctx = nullptr;
mel.tensor = nullptr;
mel.buffer = nullptr;
}
whisper_mel_calc::~whisper_mel_calc() = default; // export vtable
whisper_span<const float> whisper_mel_calc::hann_window() {
return {global_cache.hann_window, WHISPER_N_FFT};
} }
// naive Discrete Fourier Transform // naive Discrete Fourier Transform
@ -2886,8 +2930,8 @@ static void dft(const std::vector<float> & in, std::vector<float> & out) {
for (int n = 0; n < N; n++) { for (int n = 0; n < N; n++) {
int idx = (k * n * sin_cos_step) % (SIN_COS_N_COUNT); // t = 2*M_PI*k*n/N int idx = (k * n * sin_cos_step) % (SIN_COS_N_COUNT); // t = 2*M_PI*k*n/N
re += in[n]*cos_vals[idx]; // cos(t) re += in[n]*global_cache.cos_vals[idx]; // cos(t)
im -= in[n]*sin_vals[idx]; // sin(t) im -= in[n]*global_cache.sin_vals[idx]; // sin(t)
} }
out[k*2 + 0] = re; out[k*2 + 0] = re;
@ -2938,8 +2982,8 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
const int sin_cos_step = SIN_COS_N_COUNT / N; const int sin_cos_step = SIN_COS_N_COUNT / N;
for (int k = 0; k < N/2; k++) { for (int k = 0; k < N/2; k++) {
int idx = k * sin_cos_step; // t = 2*M_PI*k/N int idx = k * sin_cos_step; // t = 2*M_PI*k/N
float re = cos_vals[idx]; // cos(t) float re = global_cache.cos_vals[idx]; // cos(t)
float im = -sin_vals[idx]; // sin(t) float im = -global_cache.sin_vals[idx]; // sin(t)
float re_odd = odd_fft[2*k + 0]; float re_odd = odd_fft[2*k + 0];
float im_odd = odd_fft[2*k + 1]; float im_odd = odd_fft[2*k + 1];
@ -2952,24 +2996,20 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
} }
} }
static bool hann_window(int length, bool periodic, std::vector<float> & output) { namespace {
if (output.size() < static_cast<size_t>(length)) {
output.resize(length);
}
int offset = -1;
if (periodic) {
offset = 0;
}
for (int i = 0; i < length; i++) {
output[i] = 0.5*(1.0 - cosf((2.0*M_PI*i)/(length + offset)));
}
return true; struct whisper_mel_data {
} int n_len;
int n_len_org;
int n_mel;
float * data;
};
static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float> & hann, const std::vector<float> & samples, void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
int n_samples, int frame_size, int frame_step, int n_threads, int n_samples, int n_threads,
const whisper_filters & filters, whisper_mel & mel) { const whisper_filters & filters, whisper_mel_data & mel) {
const auto frame_size = WHISPER_N_FFT;
const auto frame_step = WHISPER_HOP_LENGTH;
std::vector<float> fft_in(frame_size, 0.0); std::vector<float> fft_in(frame_size, 0.0);
std::vector<float> fft_out(2 * frame_size); std::vector<float> fft_out(2 * frame_size);
int n_fft = filters.n_fft; int n_fft = filters.n_fft;
@ -2982,7 +3022,7 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) { for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) {
const int offset = i * frame_step; const int offset = i * frame_step;
// apply Hanning window (~10% faster) // apply Hann window (~10% faster)
for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) { for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) {
fft_in[j] = hann[j] * samples[offset + j]; fft_in[j] = hann[j] * samples[offset + j];
} }
@ -3034,101 +3074,109 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
} }
} }
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157 struct mel_calc_cpu : public whisper_mel_calc {
static bool log_mel_spectrogram( ggml_backend_t m_backend;
whisper_state & wstate, const whisper_filters & m_filters;
const float * samples, mel_calc_cpu(ggml_backend_t backend, const whisper_filters & filters) : m_backend(backend), m_filters(filters) {}
const int n_samples,
const int /*sample_rate*/,
const int frame_size,
const int frame_step,
const int n_mel,
const int n_threads,
const whisper_filters & filters,
const bool debug,
whisper_mel & mel) {
const int64_t t_start_us = ggml_time_us();
// Hanning window (Use cosf to eliminate difference) // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157
// ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html whisper_mel calculate(whisper_span<const float> ssamples, int n_threads) override {
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147 // Hann window
std::vector<float> hann; const float * hann = global_cache.hann_window;
hann_window(frame_size, true, hann);
// Calculate the length of padding
int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30;
int64_t stage_2_pad = WHISPER_N_FFT / 2;
// Calculate the length of padding const int n_samples = int(ssamples.len);
int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30; const float * samples = ssamples.data;
int64_t stage_2_pad = frame_size / 2;
// Initialize a vector and copy data from C array to it. // Initialize a vector and copy data from C array to it.
std::vector<float> samples_padded; std::vector<float> samples_padded;
samples_padded.resize(n_samples + stage_1_pad + stage_2_pad * 2); samples_padded.resize(n_samples + stage_1_pad + stage_2_pad * 2);
std::copy(samples, samples + n_samples, samples_padded.begin() + stage_2_pad); std::copy(samples, samples + n_samples, samples_padded.begin() + stage_2_pad);
// pad 30 seconds of zeros at the end of audio (480,000 samples) + reflective pad 200 samples at the end of audio // pad 30 seconds of zeros at the end of audio (480,000 samples) + reflective pad 200 samples at the end of audio
std::fill(samples_padded.begin() + n_samples + stage_2_pad, samples_padded.begin() + n_samples + stage_1_pad + 2 * stage_2_pad, 0); std::fill(samples_padded.begin() + n_samples + stage_2_pad, samples_padded.begin() + n_samples + stage_1_pad + 2 * stage_2_pad, 0);
// reflective pad 200 samples at the beginning of audio // reflective pad 200 samples at the beginning of audio
std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin()); std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin());
mel.n_mel = n_mel; whisper_mel_data mel;
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/SpectralOps.cpp#L936 mel.n_mel = m_filters.n_mel;
// Calculate number of frames + remove the last frame // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/SpectralOps.cpp#L936
mel.n_len = (samples_padded.size() - frame_size) / frame_step; // Calculate number of frames + remove the last frame
// Calculate semi-padded sample length to ensure compatibility mel.n_len = (samples_padded.size() - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
mel.n_len_org = 1 + (n_samples + stage_2_pad - frame_size) / frame_step; // Calculate semi-padded sample length to ensure compatibility
mel.data.resize(mel.n_mel * mel.n_len); mel.n_len_org = 1 + (n_samples + stage_2_pad - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
std::vector<float> host_mel_data;
{ whisper_mel ret;
std::vector<std::thread> workers(n_threads - 1); whisper_mel_init(ret, m_backend, mel.n_len, mel.n_len_org, mel.n_mel);
for (int iw = 0; iw < n_threads - 1; ++iw) { if (ggml_backend_buffer_is_host(ret.buffer)) {
workers[iw] = std::thread( mel.data = reinterpret_cast<float*>(ret.tensor->data);
log_mel_spectrogram_worker_thread, iw + 1, std::cref(hann), samples_padded, } else {
n_samples + stage_2_pad, frame_size, frame_step, n_threads, host_mel_data.resize(mel.n_len * mel.n_mel);
std::cref(filters), std::ref(mel)); mel.data = host_mel_data.data();
} }
// main thread {
log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples + stage_2_pad, frame_size, frame_step, n_threads, filters, mel); std::vector<std::thread> workers(n_threads - 1);
for (int iw = 0; iw < n_threads - 1; ++iw) {
workers[iw] = std::thread(
log_mel_spectrogram_worker_thread, iw + 1, hann, samples_padded,
n_samples + stage_2_pad, n_threads,
std::cref(m_filters), std::ref(mel));
}
for (int iw = 0; iw < n_threads - 1; ++iw) { // main thread
workers[iw].join(); log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples + stage_2_pad, n_threads, m_filters, mel);
for (int iw = 0; iw < n_threads - 1; ++iw) {
workers[iw].join();
}
} }
// clamping and normalization
double mmax = -1e20;
for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
if (mel.data[i] > mmax) {
mmax = mel.data[i];
}
}
mmax -= 8.0;
for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
if (mel.data[i] < mmax) {
mel.data[i] = mmax;
}
mel.data[i] = (mel.data[i] + 4.0)/4.0;
}
if (!host_mel_data.empty()) {
// the ret buffer is not host-accessible so we used this temporary buffer and now we need to upload it
ggml_backend_tensor_set(ret.tensor, host_mel_data.data(), 0, ggml_nbytes(ret.tensor));
}
return ret;
} }
};
}
// clamping and normalization whisper_mel_calc * whisper_mel_calc_create(ggml_backend_t backend, const whisper_filters & filters) {
double mmax = -1e20; #if GGML_USE_CUDA
for (int i = 0; i < mel.n_mel*mel.n_len; i++) { if (ggml_backend_is_cuda(backend)) {
if (mel.data[i] > mmax) { auto ret = whisper_mel_calc_create_cuda(backend, filters);
mmax = mel.data[i]; // run a warmup to avoid the first kernel launch overhead (thus we get the best perf even on the first run)
} const float warmup[256] = {0};
} ret->calculate({warmup, 256}, 1);
return ret;
mmax -= 8.0; } else
#endif
for (int i = 0; i < mel.n_mel*mel.n_len; i++) { return new mel_calc_cpu(backend, filters);
if (mel.data[i] < mmax) {
mel.data[i] = mmax;
}
mel.data[i] = (mel.data[i] + 4.0)/4.0;
}
wstate.t_mel_us += ggml_time_us() - t_start_us;
// Dump log_mel_spectrogram
if (debug) {
std::ofstream outFile("log_mel_spectrogram.json");
outFile << "[";
for (uint64_t i = 0; i < mel.data.size() - 1; i++) {
outFile << mel.data[i] << ", ";
}
outFile << mel.data[mel.data.size() - 1] << "]";
outFile.close();
}
return true;
} }
// split text into tokens // split text into tokens
@ -3244,19 +3292,26 @@ static std::string whisper_openvino_get_path_cache(std::string path_bin) {
#endif #endif
struct whisper_state * whisper_init_state(whisper_context * ctx) { struct whisper_state * whisper_init_state(whisper_context * ctx) {
fill_sin_cos_table();
whisper_state * state = new whisper_state; whisper_state * state = new whisper_state;
state->backend = whisper_backend_init(ctx->params);
if (!state->backend) {
WHISPER_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__);
whisper_free_state(state);
return nullptr;
}
state->mel_calc = whisper_mel_calc_create(state->backend, ctx->model.filters);
// at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx // at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx
// in theory, there can be a case where this is not enough, but in practice it should always be enough // in theory, there can be a case where this is not enough, but in practice it should always be enough
const int factor = 3; const int factor = 3;
if (!kv_cache_init(state->kv_self, ctx->backend, ctx->itype, if (!whisper_kv_cache_init(state->kv_self, state->backend, ctx->itype,
ctx->model.hparams.n_text_state, ctx->model.hparams.n_text_state,
ctx->model.hparams.n_text_layer, ctx->model.hparams.n_text_layer,
GGML_PAD(ctx->model.hparams.n_text_ctx, 256)*factor)) { GGML_PAD(ctx->model.hparams.n_text_ctx, 256)*factor)) {
WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__); WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
whisper_free_state(state); whisper_free_state(state);
return nullptr; return nullptr;
} }
@ -3266,11 +3321,11 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1e6); WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1e6);
} }
if (!kv_cache_init(state->kv_cross, ctx->backend, ctx->itype, if (!whisper_kv_cache_init(state->kv_cross, state->backend, ctx->itype,
ctx->model.hparams.n_text_state, ctx->model.hparams.n_text_state,
ctx->model.hparams.n_text_layer, ctx->model.hparams.n_text_layer,
GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) { GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
WHISPER_LOG_ERROR("%s: kv_cache_init() failed for cross-attention cache\n", __func__); WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for cross-attention cache\n", __func__);
whisper_free_state(state); whisper_free_state(state);
return nullptr; return nullptr;
} }
@ -3280,11 +3335,11 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6); WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6);
} }
if (!kv_cache_init(state->kv_pad, ctx->backend, ctx->itype, if (!whisper_kv_cache_init(state->kv_pad, state->backend, ctx->itype,
ctx->model.hparams.n_audio_state, ctx->model.hparams.n_audio_state,
1, 1,
GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) { GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__); WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
whisper_free_state(state); whisper_free_state(state);
return nullptr; return nullptr;
} }
@ -3296,7 +3351,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
// [EXPERIMENTAL] Token-level timestamps with DTW // [EXPERIMENTAL] Token-level timestamps with DTW
if (ctx->params.dtw_token_timestamps) { if (ctx->params.dtw_token_timestamps) {
if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, ctx->backend)) { if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, state->backend)) {
WHISPER_LOG_ERROR("%s: aheads_masks_init() failed for alignment heads masks\n", __func__); WHISPER_LOG_ERROR("%s: aheads_masks_init() failed for alignment heads masks\n", __func__);
whisper_free_state(state); whisper_free_state(state);
return nullptr; return nullptr;
@ -3339,9 +3394,9 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
// conv allocator // conv allocator
{ {
bool ok = whisper_allocr_graph_init(state->alloc_conv, ctx->backend, bool ok = whisper_allocr_graph_init(state->alloc_conv, state->backend,
[&]() { [&]() {
return whisper_build_graph_conv(*ctx, *state); return whisper_build_graph_conv(*ctx, *state, 0);
}); });
if (!ok) { if (!ok) {
@ -3355,7 +3410,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
// encoder allocator // encoder allocator
if (!whisper_encode_external(*state)) { if (!whisper_encode_external(*state)) {
bool ok = whisper_allocr_graph_init(state->alloc_encode, ctx->backend, bool ok = whisper_allocr_graph_init(state->alloc_encode, state->backend,
[&]() { [&]() {
return whisper_build_graph_encoder(*ctx, *state); return whisper_build_graph_encoder(*ctx, *state);
}); });
@ -3371,7 +3426,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
// cross allocator // cross allocator
{ {
bool ok = whisper_allocr_graph_init(state->alloc_cross, ctx->backend, bool ok = whisper_allocr_graph_init(state->alloc_cross, state->backend,
[&]() { [&]() {
return whisper_build_graph_cross(*ctx, *state); return whisper_build_graph_cross(*ctx, *state);
}); });
@ -3387,7 +3442,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
// decoder allocator // decoder allocator
{ {
bool ok = whisper_allocr_graph_init(state->alloc_decode, ctx->backend, bool ok = whisper_allocr_graph_init(state->alloc_decode, state->backend,
[&]() { [&]() {
const auto & hparams = ctx->model.hparams; const auto & hparams = ctx->model.hparams;
@ -3659,9 +3714,16 @@ struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loa
void whisper_free_state(struct whisper_state * state) { void whisper_free_state(struct whisper_state * state) {
if (state) { if (state) {
kv_cache_free(state->kv_self); whisper_kv_cache_free(state->kv_self);
kv_cache_free(state->kv_cross); whisper_kv_cache_free(state->kv_cross);
kv_cache_free(state->kv_pad); whisper_kv_cache_free(state->kv_pad);
whisper_mel_free(state->mel);
delete state->mel_calc;
state->mel_calc = nullptr;
delete state->mel_calc_fallback;
state->mel_calc_fallback = nullptr;
#ifdef WHISPER_USE_COREML #ifdef WHISPER_USE_COREML
if (state->ctx_coreml != nullptr) { if (state->ctx_coreml != nullptr) {
@ -3684,6 +3746,8 @@ void whisper_free_state(struct whisper_state * state) {
ggml_gallocr_free(state->alloc_cross.alloc); ggml_gallocr_free(state->alloc_cross.alloc);
ggml_gallocr_free(state->alloc_decode.alloc); ggml_gallocr_free(state->alloc_decode.alloc);
ggml_backend_free(state->backend);
// [EXPERIMENTAL] Token-level timestamps with DTW // [EXPERIMENTAL] Token-level timestamps with DTW
aheads_masks_free(state->aheads_masks); aheads_masks_free(state->aheads_masks);
@ -3718,11 +3782,37 @@ void whisper_free_params(struct whisper_full_params * params) {
} }
int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) { int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) { const int64_t t_start_us = ggml_time_us();
WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__);
return -1; whisper_mel_free(state->mel);
if (n_samples <= 5 * 60 * WHISPER_SAMPLE_RATE) {
// calculate mel spectrogram for lengths up to 5 minutes on the most optimal mel calculator
state->mel = state->mel_calc->calculate({samples, n_samples}, n_threads);
} else {
// calcuate mel spectrogram for longer audios on the CPU
// 1. gpu calculations may use hundreds of megabytes of memory for longer audios so we're being conservative
// with our gpu demands
// 2. the time to transcribe audios this long will be dominated by the decoding time, so the mel calculation
// taking longer is not a major concern
if (!state->mel_calc_fallback) {
state->mel_calc_fallback = new mel_calc_cpu(state->backend, ctx->model.filters);
}
state->mel = state->mel_calc_fallback->calculate({samples, n_samples}, n_threads);
} }
state->t_mel_us += ggml_time_us() - t_start_us;
// Dump log_mel_spectrogram
//{
// auto& mel = state->mel;
// std::ofstream outFile("log_mel_spectrogram.json");
// outFile << "[";
// for (uint64_t i = 0; i < mel.data.size() - 1; i++) {
// outFile << mel.data[i] << ", ";
// }
// outFile << mel.data[mel.data.size() - 1] << "]";
// outFile.close();
//}
return 0; return 0;
} }
@ -3730,30 +3820,6 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int
return whisper_pcm_to_mel_with_state(ctx, ctx->state, samples, n_samples, n_threads); return whisper_pcm_to_mel_with_state(ctx, ctx->state, samples, n_samples, n_threads);
} }
// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) {
WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__);
return -1;
}
return 0;
}
// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
return whisper_pcm_to_mel_phase_vocoder_with_state(ctx, ctx->state, samples, n_samples, n_threads);
}
// same as whisper_pcm_to_mel, but applies WSOLA to speed up the audio x2
// TODO
// same as whisper_pcm_to_mel, but applies HPTSM to speed up the audio x2
// TODO
// same as whisper_pcm_to_mel, but applies PV (with phase lock) to speed up the audio x2
// TODO
int whisper_set_mel_with_state( int whisper_set_mel_with_state(
struct whisper_context * ctx, struct whisper_context * ctx,
struct whisper_state * state, struct whisper_state * state,
@ -3765,12 +3831,10 @@ int whisper_set_mel_with_state(
return -1; return -1;
} }
state->mel.n_len = n_len; whisper_mel_free(state->mel);
state->mel.n_len_org = n_len; whisper_mel_init(state->mel, ctx->backend, n_len, n_len, n_mel);
state->mel.n_mel = n_mel;
state->mel.data.resize(n_len*n_mel); ggml_backend_tensor_set(state->mel.tensor, data, 0, ggml_nbytes(state->mel.tensor));
memcpy(state->mel.data.data(), data, n_len*n_mel*sizeof(float));
return 0; return 0;
} }
@ -4654,7 +4718,6 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.split_on_word =*/ false, /*.split_on_word =*/ false,
/*.max_tokens =*/ 0, /*.max_tokens =*/ 0,
/*.speed_up =*/ false,
/*.debug_mode =*/ false, /*.debug_mode =*/ false,
/*.audio_ctx =*/ 0, /*.audio_ctx =*/ 0,
@ -5328,15 +5391,9 @@ int whisper_full_with_state(
if (n_samples > 0) { if (n_samples > 0) {
// compute log mel spectrogram // compute log mel spectrogram
if (params.speed_up) { if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
// TODO: Replace PV with more advanced algorithm
WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__); WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
return -1; return -2;
} else {
if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
return -2;
}
} }
} }
@ -5373,7 +5430,7 @@ int whisper_full_with_state(
// if length of spectrogram is less than 1.0s (100 frames), then return // if length of spectrogram is less than 1.0s (100 frames), then return
// basically don't process anything that is less than 1.0s // basically don't process anything that is less than 1.0s
// see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39 // see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
if (seek_end < seek_start + (params.speed_up ? 50 : 100)) { if (seek_end < seek_start + 100) {
WHISPER_LOG_WARN("%s: input is too short - %d ms < 1000 ms. consider padding the input audio with silence\n", __func__, (seek_end - seek_start)*10); WHISPER_LOG_WARN("%s: input is too short - %d ms < 1000 ms. consider padding the input audio with silence\n", __func__, (seek_end - seek_start)*10);
return 0; return 0;
} }
@ -6085,8 +6142,8 @@ int whisper_full_with_state(
const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx)); const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
if (!text.empty()) { if (!text.empty()) {
const auto tt0 = params.speed_up ? 2*t0 : t0; const auto tt0 = t0;
const auto tt1 = params.speed_up ? 2*t1 : t1; const auto tt1 = t1;
if (params.print_realtime) { if (params.print_realtime) {
if (params.print_timestamps) { if (params.print_timestamps) {
@ -6132,8 +6189,8 @@ int whisper_full_with_state(
if (!text.empty()) { if (!text.empty()) {
const auto t1 = seek + seek_delta; const auto t1 = seek + seek_delta;
const auto tt0 = params.speed_up ? 2*t0 : t0; const auto tt0 = t0;
const auto tt1 = params.speed_up ? 2*t1 : t1; const auto tt1 = t1;
if (params.print_realtime) { if (params.print_realtime) {
if (params.print_timestamps) { if (params.print_timestamps) {
@ -7224,7 +7281,7 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
// operation (after median filter) // operation (after median filter)
// IN: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims // IN: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims
// OUT: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims // OUT: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
w = ggml_norm(gctx, w, 1e-9); w = ggml_norm(gctx, w, 1e-9f);
w = ggml_permute(gctx, ggml_permute(gctx, w, 2, 1, 0 ,3), 0, 2, 1, 3); w = ggml_permute(gctx, ggml_permute(gctx, w, 2, 1, 0 ,3), 0, 2, 1, 3);
// Pass median filter - this is done over AUDIO_TOKENS dimension. // Pass median filter - this is done over AUDIO_TOKENS dimension.

View File

@ -31,8 +31,10 @@
#define WHISPER_SAMPLE_RATE 16000 #define WHISPER_SAMPLE_RATE 16000
#define WHISPER_N_FFT 400 #define WHISPER_N_FFT 400
#define WHISPER_N_FFT_HALF (WHISPER_N_FFT / 2 + 1)
#define WHISPER_HOP_LENGTH 160 #define WHISPER_HOP_LENGTH 160
#define WHISPER_CHUNK_SIZE 30 #define WHISPER_CHUNK_SIZE 30
#define WHISPER_N_SAMPLES (WHISPER_SAMPLE_RATE * WHISPER_CHUNK_SIZE)
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
@ -266,22 +268,6 @@ extern "C" {
int n_samples, int n_samples,
int n_threads); int n_threads);
// Convert RAW PCM audio to log mel spectrogram but applies a Phase Vocoder to speed up the audio x2.
// The resulting spectrogram is stored inside the default state of the provided whisper context.
// Returns 0 on success
WHISPER_API int whisper_pcm_to_mel_phase_vocoder(
struct whisper_context * ctx,
const float * samples,
int n_samples,
int n_threads);
WHISPER_API int whisper_pcm_to_mel_phase_vocoder_with_state(
struct whisper_context * ctx,
struct whisper_state * state,
const float * samples,
int n_samples,
int n_threads);
// This can be used to set a custom log mel spectrogram inside the default state of the provided whisper context. // This can be used to set a custom log mel spectrogram inside the default state of the provided whisper context.
// Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram. // Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
// n_mel must be 80 // n_mel must be 80
@ -499,7 +485,6 @@ extern "C" {
// [EXPERIMENTAL] speed-up techniques // [EXPERIMENTAL] speed-up techniques
// note: these can significantly reduce the quality of the output // note: these can significantly reduce the quality of the output
bool speed_up; // speed-up the audio by 2x using Phase Vocoder
bool debug_mode; // enable debug_mode provides extra info (eg. Dump log_mel) bool debug_mode; // enable debug_mode provides extra info (eg. Dump log_mel)
int audio_ctx; // overwrite the audio context size (0 = use default) int audio_ctx; // overwrite the audio context size (0 = use default)