Compare commits

..

21 Commits

Author SHA1 Message Date
bff8dc248a talk-llama : sync llama.cpp
ggml-ci
2025-05-13 13:20:19 +03:00
69753804ed whisper : update to ggml-backend changes (#0)
ggml-ci
2025-05-13 13:11:24 +03:00
89970b9aaa sync : ggml
ggml-ci
2025-05-13 13:10:17 +03:00
79fb43e252 ggml : add mrope kernel for metal (llama/13457) 2025-05-13 13:10:08 +03:00
926e06dbfd metal : optimize MoE for large batches (llama/13388) 2025-05-13 13:09:20 +03:00
43a59eccf6 opencl: remove unnecessary assert for add (llama/13257) 2025-05-13 13:05:33 +03:00
fe0d52b9a2 llama/ggml: add LLM training support (llama/10544)
* llama/ggml: add LLM training support

more compact progress bar

llama_save_model_to_file

llama_opt_param_filter

ggml_graph_dup force_grads

refactor ggml_opt, fix test-opt

* remove logits_all

* refactor CUDA implementation for ACC

* reset graph at beginning of opt period
2025-05-13 13:05:33 +03:00
cb90cb0992 ggml-cpu: Integrate fp32=bf16xbf16 SME KleidiAI kernel (llama/13053)
* ggml-cpu: Integrate fp32=bf16xbf16 SME KleidiAI kernel

Signed-off-by: Dan Johansson <dan.johansson@arm.com>

* * code review fixes

Signed-off-by: Dan Johansson <dan.johansson@arm.com>

* * adds a comment that clarifies barrier usage

Signed-off-by: Dan Johansson <dan.johansson@arm.com>

---------

Signed-off-by: Dan Johansson <dan.johansson@arm.com>
Co-authored-by: Charles Xu <charles.xu@arm.com>
2025-05-13 13:05:33 +03:00
8264872b5d CUDA: fix misaligned synchronization in FA (llama/13469) 2025-05-13 13:05:33 +03:00
882d975729 enable dpcpp nightly builds with libraries (llama/13406) 2025-05-13 13:05:33 +03:00
c426829771 CUDA: fix crash with partial offloading of MoE (llama/13439) 2025-05-13 13:05:33 +03:00
0b1962a181 Add --no-op-offload to improve -ot pp perf in MoE models like llama4 400B (llama/13386) 2025-05-13 13:05:33 +03:00
86dece9c7c CUDA: fix race conditions FlashAttention kernels (llama/13438) 2025-05-13 13:05:32 +03:00
04445664b4 CUDA: fix FlashAttention on Turing (llama/13415) 2025-05-13 13:05:32 +03:00
22f4997dd8 vulkan: scalar flash attention implementation (llama/13324)
* vulkan: scalar flash attention implementation

* vulkan: always use fp32 for scalar flash attention

* vulkan: use vector loads in scalar flash attention shader

* vulkan: remove PV matrix, helps with register usage

* vulkan: reduce register usage in scalar FA, but perf may be slightly worse

* vulkan: load each Q value once. optimize O reduction. more tuning

* vulkan: support q4_0/q8_0 KV in scalar FA

* CI: increase timeout to accommodate newly-supported tests

* vulkan: for scalar FA, select between 1 and 8 rows

* vulkan: avoid using Float16 capability in scalar FA
2025-05-13 13:05:32 +03:00
b493e03b90 sycl : implementation of reordered Q4_0 MMVQ for Intel GPUs (llama/12858)
* sycl : Implemented reorder Q4_0 mmvq

Signed-off-by: Alberto Cabrera <alberto.cabrera@codeplay.com>

* sycl : Fixed mmvq being called when reorder is disabled

* sycl : Improved comments in the quants header

Signed-off-by: Alberto Cabrera <alberto.cabrera@codeplay.com>

* Use static_assert

* safe_div -> ceil_div

* Clarify qi comment

* change the reorder tensor from init to execute OP

* dbg

* Undo changes to test-backend-ops

* Refactor changes on top of q4_0 reorder fix

* Missing Reverts

* Refactored opt_for_reorder logic to simplify code path

* Explicit inlining and unroll

* Renamed mul_mat_algo enum for consistency

---------

Signed-off-by: Alberto Cabrera <alberto.cabrera@codeplay.com>
Co-authored-by: romain.biessy <romain.biessy@codeplay.com>
2025-05-13 13:05:32 +03:00
aef59f4851 CUDA: FA support for Deepseek (Ampere or newer) (llama/13306)
* CUDA: FA support for Deepseek (Ampere or newer)

* do loop unrolling via C++ template
2025-05-13 13:05:32 +03:00
f8c75dc43e CUDA: fix crash on large batch size for MoE models (llama/13384) 2025-05-13 13:05:32 +03:00
00c8056715 rpc : add rpc_msg_set_tensor_hash_req (llama/13353)
* rpc : add rpc_msg_set_tensor_hash_req

Use a dedicated struct for the request of RPC_CMD_SET_TENSOR_HASH which
makes the code cleaner.

* fix
2025-05-13 13:05:32 +03:00
19d8d9a928 vulkan: Allow up to 4096 elements for mul_mat_id row_ids (llama/13326)
This assert fired running Qwen_Qwen3-30B-A3B-Q2_K.gguf:

GGML_ASSERT(nei0 * nei1 <= 3072);

The tensor is 8 x 512. Increase this array size to accommodate.
2025-05-13 13:05:32 +03:00
0c4a229154 sycl: addressing non-contiguous src1 mul_mats (nc and batched) (llama/13343)
* sycl: fixed non-contiguous src1 mul_mats (nc and batched)

* Fixed wrong static_cast inside kernel
2025-05-13 13:05:31 +03:00
434 changed files with 38999 additions and 71259 deletions

View File

@ -16,7 +16,6 @@ ENV CUDA_DOCKER_ARCH=${CUDA_DOCKER_ARCH}
RUN apt-get update && \
apt-get install -y build-essential libsdl2-dev wget cmake git \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
# Ref: https://stackoverflow.com/a/53464012
@ -27,12 +26,6 @@ COPY .. .
# Enable cuBLAS
RUN make base.en CMAKE_ARGS="-DGGML_CUDA=1"
RUN find /app/build -name "*.o" -delete && \
find /app/build -name "*.a" -delete && \
rm -rf /app/build/CMakeFiles && \
rm -rf /app/build/cmake_install.cmake && \
rm -rf /app/build/_deps
FROM ${BASE_CUDA_RUN_CONTAINER} AS runtime
ENV CUDA_MAIN_VERSION=12.3
ENV LD_LIBRARY_PATH /usr/local/cuda-${CUDA_MAIN_VERSION}/compat:$LD_LIBRARY_PATH
@ -40,11 +33,8 @@ WORKDIR /app
RUN apt-get update && \
apt-get install -y curl ffmpeg wget cmake git \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
COPY --from=build /app /app
RUN du -sh /app/*
RUN find /app -type f -size +100M
ENV PATH=/app/build/bin:$PATH
ENTRYPOINT [ "bash", "-c" ]

View File

@ -1,28 +0,0 @@
ARG ONEAPI_VERSION=2025.1.1-0-devel-ubuntu24.04
FROM intel/oneapi-basekit:$ONEAPI_VERSION AS build
WORKDIR /app
RUN apt-get update && \
apt-get install -y build-essential libsdl2-dev wget cmake git \
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
COPY .. .
# Enable SYCL
ARG GGML_SYCL_F16=OFF
RUN if [ "${GGML_SYCL_F16}" = "ON" ]; then \
echo "GGML_SYCL_F16 is set" \
&& export OPT_SYCL_F16="-DGGML_SYCL_F16=ON"; \
fi && \
make base.en CMAKE_ARGS="-DGGML_SYCL=1 -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx ${OPT_SYCL_F16}"
FROM intel/oneapi-basekit:$ONEAPI_VERSION AS runtime
WORKDIR /app
RUN apt-get update && \
apt-get install -y curl ffmpeg libsdl2-dev wget cmake git \
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
COPY --from=build /app /app
ENV PATH=/app/build/bin:$PATH
ENTRYPOINT [ "bash", "-c" ]

View File

@ -1,40 +1,29 @@
ARG UBUNTU_VERSION=22.04
# This needs to generally match the container host's environment.
ARG MUSA_VERSION=rc4.2.0
ARG MUSA_VERSION=rc3.1.1
# Target the MUSA build image
ARG BASE_MUSA_DEV_CONTAINER=mthreads/musa:${MUSA_VERSION}-devel-ubuntu${UBUNTU_VERSION}-amd64
ARG BASE_MUSA_DEV_CONTAINER=mthreads/musa:${MUSA_VERSION}-devel-ubuntu${UBUNTU_VERSION}
# Target the MUSA runtime image
ARG BASE_MUSA_RUN_CONTAINER=mthreads/musa:${MUSA_VERSION}-runtime-ubuntu${UBUNTU_VERSION}-amd64
ARG BASE_MUSA_RUN_CONTAINER=mthreads/musa:${MUSA_VERSION}-runtime-ubuntu${UBUNTU_VERSION}
FROM ${BASE_MUSA_DEV_CONTAINER} AS build
WORKDIR /app
RUN apt-get update && \
apt-get install -y build-essential libsdl2-dev wget cmake git && \
apt-get clean && \
rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* /tmp/* /var/tmp/*
apt-get install -y build-essential libsdl2-dev wget cmake git \
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
COPY .. .
# Enable muBLAS
RUN make base.en CMAKE_ARGS="-DGGML_MUSA=1"
RUN find /app/build -name "*.o" -delete && \
find /app/build -name "*.a" -delete && \
rm -rf /app/build/CMakeFiles && \
rm -rf /app/build/cmake_install.cmake && \
rm -rf /app/build/_deps
FROM ${BASE_MUSA_RUN_CONTAINER} AS runtime
WORKDIR /app
RUN apt-get update && \
apt-get install -y curl ffmpeg wget cmake git && \
apt-get clean && \
rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* /tmp/* /var/tmp/*
COPY --from=build /app/build/bin /app/build/bin
COPY --from=build /app/samples /app/samples
COPY --from=build /app/models /app/models
apt-get install -y curl ffmpeg wget cmake git \
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
COPY --from=build /app /app
ENV PATH=/app/build/bin:$PATH
ENTRYPOINT [ "bash", "-c" ]

View File

@ -4,27 +4,6 @@ on:
push:
branches:
- master
tags:
- 'v*'
paths: ['.github/workflows/build.yml',
'**/CMakeLists.txt',
'**/Makefile',
'**/*.mk',
'**/*.cmake',
'**/*.in',
'**/*.h',
'**/*.hpp',
'**/*.c',
'**/*.cpp',
'**/*.cu',
'**/*.cuh',
'**/*.cl',
'**/*.swift',
'**/*.m',
'**/*.mm',
'**/*.metal',
'**/*.comp',
'**/*.java']
pull_request:
types: [opened, synchronize, reopened]
workflow_dispatch:
@ -62,7 +41,6 @@ jobs:
runs-on: ubuntu-latest
outputs:
tag_name: ${{ steps.tag.outputs.name }}
should_release: ${{ steps.tag.outputs.should_release }}
steps:
- name: Checkout with full history
@ -77,7 +55,6 @@ jobs:
BUILD_NUMBER=$(git rev-list --count HEAD)
SHORT_HASH=$(git rev-parse --short=7 HEAD)
CUSTOM_TAG="${{ github.event.inputs.pre_release_tag }}"
SHOULD_RELEASE="false"
echo "Raw values:"
echo "BUILD_NUMBER: $BUILD_NUMBER"
@ -85,34 +62,21 @@ jobs:
echo "BRANCH_NAME: ${{ env.BRANCH_NAME }}"
echo "CUSTOM_TAG: $CUSTOM_TAG"
if [[ "${{ github.ref_type }}" == "tag" ]]; then
echo "Using pushed tag name"
TAG_NAME="${{ github.ref_name }}"
SHOULD_RELEASE="true"
elif [[ -n "$CUSTOM_TAG" ]]; then
# Use custom tag if provided
if [[ -n "$CUSTOM_TAG" ]]; then
echo "Using custom tag"
TAG_NAME="${CUSTOM_TAG}"
SHOULD_RELEASE="true"
elif [[ "${{ github.event.inputs.create_release }}" == "true" ]]; then
echo "Manual release requested"
SHOULD_RELEASE="true"
TAG_NAME="b${BUILD_NUMBER}"
elif [[ "${{ env.BRANCH_NAME }}" == "master" ]]; then
echo "Using master branch format"
TAG_NAME="b${BUILD_NUMBER}"
SHOULD_RELEASE="false"
else
echo "Using non-master branch format"
SAFE_NAME=$(echo "${{ env.BRANCH_NAME }}" | tr '/' '-')
TAG_NAME="${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}"
SHOULD_RELEASE="false"
fi
echo "Final tag name: $TAG_NAME"
echo "Should release: $SHOULD_RELEASE"
echo "name=$TAG_NAME" >> $GITHUB_OUTPUT
echo "should_release=$SHOULD_RELEASE" >> $GITHUB_OUTPUT
ubuntu-22:
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
@ -137,10 +101,6 @@ jobs:
-v ${{ github.workspace }}:/workspace \
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
set -e
export DEBIAN_FRONTEND=noninteractive
sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
apt update
apt install -y build-essential libsdl2-dev cmake git
cmake -B build
@ -169,14 +129,6 @@ jobs:
-v ${{ github.workspace }}:/workspace \
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
set -e
export DEBIAN_FRONTEND=noninteractive
sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
apt-get update
apt-get install -y ca-certificates
sed -i "s|http://ports.ubuntu.com|https://mirror.kumi.systems|g" /etc/apt/sources.list
apt update
apt install -y build-essential libsdl2-dev cmake git
cmake -B build -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8-a
@ -205,14 +157,6 @@ jobs:
-v ${{ github.workspace }}:/workspace \
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
set -e
export DEBIAN_FRONTEND=noninteractive
sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
apt-get update
apt-get install -y ca-certificates
sed -i "s|http://ports.ubuntu.com|https://mirror.kumi.systems|g" /etc/apt/sources.list
apt update
apt install -y build-essential libsdl2-dev cmake git
cmake -B build -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv7-a+fp
@ -298,10 +242,6 @@ jobs:
-v ${{ github.workspace }}:/workspace \
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
set -e
export DEBIAN_FRONTEND=noninteractive
sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
apt update
apt install -y build-essential cmake libsdl2-dev git
cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }}
@ -332,14 +272,6 @@ jobs:
-v ${{ github.workspace }}:/workspace \
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
set -e
export DEBIAN_FRONTEND=noninteractive
sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
apt-get update
apt-get install -y ca-certificates
sed -i "s|http://ports.ubuntu.com|https://mirror.kumi.systems|g" /etc/apt/sources.list
apt update
apt install -y build-essential cmake libsdl2-dev git
cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8-a
@ -370,14 +302,6 @@ jobs:
-v ${{ github.workspace }}:/workspace \
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
set -e
export DEBIAN_FRONTEND=noninteractive
sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
apt-get update
apt-get install -y ca-certificates
sed -i "s|http://ports.ubuntu.com|https://mirror.kumi.systems|g" /etc/apt/sources.list
apt update
apt install -y build-essential cmake libsdl2-dev git
cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv7-a+fp
@ -411,14 +335,6 @@ jobs:
-v ${{ github.workspace }}:/workspace \
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
set -e
export DEBIAN_FRONTEND=noninteractive
sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
apt-get update
apt-get install -y ca-certificates
sed -i "s|http://ports.ubuntu.com|https://mirror.kumi.systems|g" /etc/apt/sources.list
apt update
apt install -y clang build-essential cmake libsdl2-dev git
cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_COMPILER=clang
@ -449,10 +365,6 @@ jobs:
-v ${{ github.workspace }}:/workspace \
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
set -e
export DEBIAN_FRONTEND=noninteractive
sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
apt update
apt install -y build-essential cmake git
cmake . -DCMAKE_BUILD_TYPE=Debug \
@ -615,7 +527,6 @@ jobs:
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
github.event.inputs.run_type == 'full-ci' }}
runs-on: windows-latest
needs: determine-tag
strategy:
matrix:
@ -699,7 +610,9 @@ jobs:
Compress-Archive -Path "build/bin/${{ matrix.build }}" -DestinationPath "whisper-bin-${{ matrix.arch }}.zip"
- name: Upload binaries
if: matrix.sdl2 == 'ON' && ${{ needs.determine-tag.outputs.should_release }}
if: matrix.sdl2 == 'ON' && ${{ (github.event_name == 'push' && github.ref == 'refs/heads/master') ||
github.event.inputs.create_release == 'true' ||
github.event.inputs.pre_release_tag != '' }}
uses: actions/upload-artifact@v4
with:
name: whisper-bin-${{ matrix.arch }}.zip
@ -716,14 +629,11 @@ jobs:
arch: [Win32, x64]
blas: [ON]
sdl2: [ON]
blasver: [0.3.29]
include:
- arch: Win32
s2arc: x86
blasfile: x86
- arch: x64
s2arc: x64
blasfile: x64_64
- sdl2: ON
s2ver: 2.28.5
@ -744,8 +654,7 @@ jobs:
- name: Install OpenBLAS and pkgconfiglite
if: matrix.blas == 'ON'
run: |
Invoke-WebRequest "https://github.com/OpenMathLib/OpenBLAS/releases/download/v${{matrix.blasver}}/OpenBLAS-${{matrix.blasver}}_${{matrix.blasfile}}.zip" -OutFile "OpenBLAS-${{matrix.blasver}}.zip"
Expand-Archive "OpenBLAS-${{matrix.blasver}}.zip" -DestinationPath "OpenBLAS-${{matrix.blasver}}"
vcpkg install --triplet=${{ matrix.s2arc }}-windows openblas
choco install pkgconfiglite
- name: Fetch SDL2 and set SDL2_DIR
@ -762,8 +671,6 @@ jobs:
-DCMAKE_BUILD_TYPE=${{ matrix.build }}
-DGGML_BLAS=${{ matrix.blas }}
-DGGML_BLAS_VENDOR=OpenBLAS
-DBLAS_LIBRARIES="$env:GITHUB_WORKSPACE/OpenBLAS-${{matrix.blasver}}/lib/libopenblas.lib"
-DBLAS_INCLUDE_DIRS="$env:GITHUB_WORKSPACE/OpenBLAS-${{matrix.blasver}}/include"
-DWHISPER_SDL2=${{ matrix.sdl2 }}
- name: Build
@ -773,7 +680,7 @@ jobs:
- name: Copy openblas.dll
if: matrix.blas == 'ON'
run: copy "$env:GITHUB_WORKSPACE/OpenBLAS-${{matrix.blasver}}/bin/libopenblas.dll" build/bin/${{ matrix.build }}
run: copy "C:/vcpkg/packages/openblas_${{ matrix.s2arc }}-windows/bin/openblas.dll" build/bin/${{ matrix.build }}
- name: Copy SDL2.dll
if: matrix.sdl2 == 'ON'
@ -785,7 +692,9 @@ jobs:
Compress-Archive -Path "build/bin/${{ matrix.build }}" -DestinationPath "whisper-blas-bin-${{ matrix.arch }}.zip"
- name: Upload binaries
if: matrix.blas == 'ON' && matrix.sdl2 == 'ON' && ${{ needs.determine-tag.outputs.should_release }}
if: matrix.blas == 'ON' && matrix.sdl2 == 'ON' && ${{ (github.event_name == 'push' && github.ref == 'refs/heads/master') ||
github.event.inputs.create_release == 'true' ||
github.event.inputs.pre_release_tag != '' }}
uses: actions/upload-artifact@v4
with:
name: whisper-blas-bin-${{ matrix.arch }}.zip
@ -794,16 +703,14 @@ jobs:
windows-cublas:
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
github.event.inputs.run_type == 'full-ci' }}
runs-on: windows-2022
needs: determine-tag
runs-on: windows-2019
strategy:
fail-fast: false
matrix:
build: [Release]
arch: [x64]
cublas: [ON]
sdl2: [ON]
cuda-toolkit: [12.4.0, 11.8.0]
cuda-toolkit: [12.2.0, 11.8.0]
include:
- arch: x64
sdl2: ON
@ -871,7 +778,7 @@ jobs:
xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y
# Visual Studio integration
xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\visual_studio_integration\MSBuildExtensions\*" "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\MSBuild\Microsoft\VC\v170\BuildCustomizations" /E /I /H /Y
xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\visual_studio_integration\MSBuildExtensions\*" "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\MSBuild\Microsoft\VC\v160\BuildCustomizations" /E /I /H /Y
# Set environment variables
echo "$CUDA_TOOLKIT_DIR\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
@ -879,23 +786,23 @@ jobs:
echo "CUDA_PATH=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8
echo "CUDA_PATH_V11_8=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8
- name: Install Cuda Toolkit 12.4.0
if: ${{ matrix.cuda-toolkit == '12.4.0' }}
- name: Install Cuda Toolkit 12.2.0
if: ${{ matrix.cuda-toolkit == '12.2.0' }}
run: |
$CUDA_VERSION = ${{ matrix.cuda-toolkit }}
$CUDA_TOOLKIT_DIR = "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v$CUDA_VERSION"
$CUDA_DOWNLOAD = "https://developer.download.nvidia.com/compute/cuda/redist"
# Components versions
$CUDART_VER = "12.4.127"
$NVCC_VER = "12.4.131"
$NVRTC_VER = "12.4.127"
$CUBLAS_VER = "12.4.5.8"
$NVTX_VER = "12.4.127"
$PROFILER_VER = "12.4.127"
$VS_VER = "12.4.127"
$NVPROF_VER = "12.4.128"
$CCCL_VER = "12.4.127"
$CUDART_VER = "12.2.140"
$NVCC_VER = "12.2.140"
$NVRTC_VER = "12.2.140"
$CUBLAS_VER = "12.2.5.6"
$NVTX_VER = "12.2.140"
$PROFILER_VER = "12.2.140"
$VS_VER = "12.2.140"
$NVPROF_VER = "12.2.142"
$CCCL_VER = "12.2.140"
# Create the directory where the CUDA Toolkit will be installed
mkdir -p $CUDA_TOOLKIT_DIR
@ -929,7 +836,7 @@ jobs:
xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y
# Visual Studio integration
xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\visual_studio_integration\MSBuildExtensions\*" "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\MSBuild\Microsoft\VC\v170\BuildCustomizations" /E /I /H /Y
xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\visual_studio_integration\MSBuildExtensions\*" "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\MSBuild\Microsoft\VC\v160\BuildCustomizations" /E /I /H /Y
# Set environment variables
echo "$CUDA_TOOLKIT_DIR\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
@ -957,21 +864,14 @@ jobs:
- name: Build Project
shell: cmd
run: |
call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvars64.bat"
call "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Auxiliary\Build\vcvars64.bat"
cmake --version
where cmake
if "${{ matrix.cuda-toolkit }}" == "11.8.0" (
set CUDA_FLAGS=-allow-unsupported-compiler -D_ALLOW_COMPILER_AND_STL_VERSION_MISMATCH -D_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR
) else (
set CUDA_FLAGS=
)
cmake -S . -B build -G "Ninja Multi-Config" ^
-DCMAKE_BUILD_TYPE=${{ matrix.build }} ^
-DGGML_CUDA=${{ matrix.cublas }} ^
-DWHISPER_SDL2=${{ matrix.sdl2 }} ^
-DSDL2_DIR="%SDL2_DIR%" ^
-DCMAKE_POLICY_VERSION_MINIMUM=3.5 ^
-DCMAKE_CUDA_FLAGS="%CUDA_FLAGS%"
-DSDL2_DIR="%SDL2_DIR%"
set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1
cmake --build build --config ${{ matrix.build }} -j %NUMBER_OF_PROCESSORS%
@ -994,7 +894,9 @@ jobs:
Compress-Archive -Path "build/bin/${{ matrix.build }}" -DestinationPath "whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}.zip"
- name: Upload binaries
if: ${{ needs.determine-tag.outputs.should_release }}
if: ${{ (github.event_name == 'push' && github.ref == 'refs/heads/master') ||
github.event.inputs.create_release == 'true' ||
github.event.inputs.pre_release_tag != '' }}
uses: actions/upload-artifact@v4
with:
name: whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}.zip
@ -1071,11 +973,16 @@ jobs:
- name: Pack artifacts
id: pack_artifacts
if: ${{ (github.event_name == 'push' && github.ref == 'refs/heads/master') ||
github.event.inputs.create_release == 'true' ||
github.event.inputs.pre_release_tag != '' }}
run: |
zip --symlinks -r whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework.zip build-apple/whisper.xcframework
- name: Upload artifacts
if: ${{ needs.determine-tag.outputs.should_release }}
if: ${{ (github.event_name == 'push' && github.ref == 'refs/heads/master') ||
github.event.inputs.create_release == 'true' ||
github.event.inputs.pre_release_tag != '' }}
uses: actions/upload-artifact@v4
with:
path: whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework.zip
@ -1253,7 +1160,7 @@ jobs:
./build/bin/quantize models/ggml-tiny.en.bin models/ggml-tiny.en-q4_0.bin q4_0
release:
if: ${{ github.event.inputs.create_release == 'true' || github.event.inputs.pre_release_tag != '' || startsWith(github.ref, 'refs/tags/v') }}
if: ${{ github.event.inputs.create_release == 'true' || github.event.inputs.pre_release_tag != '' }}
runs-on: ubuntu-latest
@ -1296,7 +1203,6 @@ jobs:
with:
tag_name: ${{ needs.determine-tag.outputs.tag_name }}
prerelease: ${{ github.event.inputs.pre_release_tag != '' }}
draft: true
- name: Upload release
id: upload_release
@ -1323,8 +1229,7 @@ jobs:
coreml-base-en:
if: ${{ (github.event_name == 'push' && github.ref == 'refs/heads/master') ||
github.event.inputs.create_release == 'true' ||
github.event.inputs.pre_release_tag != '' ||
startsWith(github.ref, 'refs/tags/v') }}
github.event.inputs.pre_release_tag != '' }}
runs-on: macos-latest
needs: determine-tag

View File

@ -15,13 +15,13 @@ jobs:
env:
COMMIT_SHA: ${{ github.sha }}
strategy:
fail-fast: false
matrix:
config:
- { tag: "main", dockerfile: ".devops/main.Dockerfile", platform: "linux/amd64" }
- { tag: "main-musa", dockerfile: ".devops/main-musa.Dockerfile", platform: "linux/amd64" }
- { tag: "main-intel", dockerfile: ".devops/main-intel.Dockerfile", platform: "linux/amd64" }
- { tag: "main-cuda", dockerfile: ".devops/main-cuda.Dockerfile", platform: "linux/amd64" }
#TODO: the cuda image keeps failing - disable for now
# https://github.com/ggerganov/whisper.cpp/actions/runs/11019444428/job/30602020339
#- { tag: "main-cuda", dockerfile: ".devops/main-cuda.Dockerfile", platform: "linux/amd64" }
steps:
- name: Check out the repo
@ -42,35 +42,21 @@ jobs:
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Free up disk space
run: |
sudo apt-get remove -y '^dotnet-.*' '^llvm-.*' '^mysql-.*' '^postgresql-.*'
sudo apt-get autoremove -y
sudo apt-get autoclean
sudo rm -rf /usr/share/dotnet
sudo rm -rf /usr/local/lib/android
sudo rm -rf /opt/ghc
sudo rm -rf /opt/hostedtoolcache/CodeQL
docker system prune -af
df -h
- name: Generate tags
id: tags
run: |
TAGS="ghcr.io/${{ github.repository }}:${{ matrix.config.tag }}"
if [ "${{ github.event_name }}" == "push" ]; then
TAGS="$TAGS,ghcr.io/${{ github.repository }}:${{ matrix.config.tag }}-${{ env.COMMIT_SHA }}"
fi
echo "tags=$TAGS" >> $GITHUB_OUTPUT
- name: Build and push Docker image (versioned)
if: github.event_name == 'push'
uses: docker/build-push-action@v5
with:
context: .
push: true
platforms: ${{ matrix.config.platform }}
tags: "ghcr.io/${{ github.repository }}:${{ matrix.config.tag }}-${{ env.COMMIT_SHA }}"
file: ${{ matrix.config.dockerfile }}
- name: Build and push Docker image (tagged)
uses: docker/build-push-action@v5
uses: docker/build-push-action@v4
with:
context: .
push: ${{ github.event_name == 'push' }}
platforms: ${{ matrix.config.platform }}
tags: ${{ steps.tags.outputs.tags }}
tags: "ghcr.io/${{ github.repository }}:${{ matrix.config.tag }}"
file: ${{ matrix.config.dockerfile }}

3
.gitignore vendored
View File

@ -14,7 +14,6 @@
build/
build-*/
build_*/
# SPM
.build/
@ -50,8 +49,6 @@ extra/bench-gg.txt
models/*.mlmodel
models/*.mlmodelc
models/*.mlpackage
models/*-encoder-openvino.xml
models/*-encoder-openvino-cache/
bindings/java/.gradle/
bindings/java/.idea/
.idea/

View File

@ -1,6 +1,6 @@
cmake_minimum_required(VERSION 3.5) # for add_link_options and implicit target directories.
project("whisper.cpp" C CXX)
project("whisper.cpp" VERSION 1.7.6)
project("whisper.cpp" VERSION 1.7.5)
include(CheckIncludeFileCXX)
set(SOVERSION 1)
@ -119,11 +119,6 @@ whisper_option_depr(WARNING WHISPER_SYCL GGML_SYCL)
whisper_option_depr(WARNING WHISPER_SYCL_F16 GGML_SYCL_F16)
whisper_option_depr(WARNING WHISPER_CCACHE GGML_CCACHE)
if (GGML_CUDA AND NOT MSVC)
#GGML_CUDA enabled, add the necessary compile options -Wno-deprecated-gpu-targets
add_compile_options(-Wno-deprecated-gpu-targets)
endif()
#
# build the library
#
@ -178,10 +173,6 @@ get_directory_property(WHISPER_TRANSIENT_DEFINES COMPILE_DEFINITIONS)
set_target_properties(whisper PROPERTIES PUBLIC_HEADER ${CMAKE_CURRENT_SOURCE_DIR}/include/whisper.h)
install(TARGETS whisper LIBRARY PUBLIC_HEADER)
target_compile_definitions(whisper PRIVATE
WHISPER_VERSION="${PROJECT_VERSION}"
)
configure_package_config_file(
${CMAKE_CURRENT_SOURCE_DIR}/cmake/whisper-config.cmake.in
${CMAKE_CURRENT_BINARY_DIR}/whisper-config.cmake
@ -250,6 +241,5 @@ if (MSVC)
disable_msvc_warnings(whisper-talk-llama)
disable_msvc_warnings(whisper-bench)
disable_msvc_warnings(quantize)
disable_msvc_warnings(vad-speech-segments)
endif()
endif()

View File

@ -7,7 +7,7 @@
[![Conan Center](https://shields.io/conan/v/whisper-cpp)](https://conan.io/center/whisper-cpp)
[![npm](https://img.shields.io/npm/v/whisper.cpp.svg)](https://www.npmjs.com/package/whisper.cpp/)
Stable: [v1.7.6](https://github.com/ggml-org/whisper.cpp/releases/tag/v1.7.6) / [Roadmap](https://github.com/orgs/ggml-org/projects/4/)
Stable: [v1.7.5](https://github.com/ggml-org/whisper.cpp/releases/tag/v1.7.5) / [Roadmap](https://github.com/orgs/ggml-org/projects/4/)
High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:
@ -35,7 +35,7 @@ Supported platforms:
- [x] [Java](bindings/java/README.md)
- [x] Linux / [FreeBSD](https://github.com/ggml-org/whisper.cpp/issues/56#issuecomment-1350920264)
- [x] [WebAssembly](examples/whisper.wasm)
- [x] Windows ([MSVC](https://github.com/ggml-org/whisper.cpp/blob/master/.github/workflows/build.yml#L117-L144) and [MinGW](https://github.com/ggml-org/whisper.cpp/issues/168))
- [x] Windows ([MSVC](https://github.com/ggml-org/whisper.cpp/blob/master/.github/workflows/build.yml#L117-L144) and [MinGW](https://github.com/ggml-org/whisper.cpp/issues/168)]
- [x] [Raspberry Pi](https://github.com/ggml-org/whisper.cpp/discussions/166)
- [x] [Docker](https://github.com/ggml-org/whisper.cpp/pkgs/container/whisper.cpp)
@ -80,7 +80,7 @@ Now build the [whisper-cli](examples/cli) example and transcribe an audio file l
```bash
# build the project
cmake -B build
cmake --build build -j --config Release
cmake --build build --config Release
# transcribe an audio file
./build/bin/whisper-cli -f samples/jfk.wav
@ -149,7 +149,7 @@ standard cmake setup with:
```bash
# build with GGML_BLAS defined
cmake -B build -DGGML_BLAS=1
cmake --build build -j --config Release
cmake --build build --config Release
./build/bin/whisper-cli [ .. etc .. ]
```
@ -163,7 +163,7 @@ Here are the steps for creating and using a quantized model:
```bash
# quantize a model with Q5_0 method
cmake -B build
cmake --build build -j --config Release
cmake --build build --config Release
./build/bin/quantize models/ggml-base.en.bin models/ggml-base.en-q5_0.bin q5_0
# run the examples as usual, specifying the quantized model file
@ -386,7 +386,7 @@ Run the inference examples as usual, for example:
## Moore Threads GPU support
With Moore Threads cards the processing of the models is done efficiently on the GPU via muBLAS and custom MUSA kernels.
First, make sure you have installed `MUSA SDK rc4.2.0`: https://developer.mthreads.com/sdk/download/musa?equipment=&os=&driverVersion=&version=4.2.0
First, make sure you have installed `MUSA SDK rc3.1.1`: https://developer.mthreads.com/sdk/download/musa?equipment=&os=&driverVersion=&version=rc3.1.1
Now build `whisper.cpp` with MUSA support:
@ -489,7 +489,7 @@ You will need to have [sdl2](https://wiki.libsdl.org/SDL2/Installation) installe
```bash
cmake -B build -DWHISPER_SDL2=ON
cmake --build build -j --config Release
cmake --build build --config Release
./build/bin/whisper-stream -m ./models/ggml-base.en.bin -t 8 --step 500 --length 5000
```
@ -709,9 +709,7 @@ For more details, see the conversion script [models/convert-pt-to-ggml.py](model
## XCFramework
The XCFramework is a precompiled version of the library for iOS, visionOS, tvOS,
and macOS. It can be used in Swift projects without the need to compile the
library from source. For example, the v1.7.5 version of the XCFramework can be
used as follows:
library from source. For examples:
```swift
// swift-tools-version: 5.10
// The swift-tools-version declares the minimum version of Swift required to build this package.
@ -735,7 +733,7 @@ let package = Package(
)
```
## Voice Activity Detection (VAD)
### Voice Activity Detection (VAD)
Support for Voice Activity Detection (VAD) can be enabled using the `--vad`
argument to `whisper-cli`. In addition to this option a VAD model is also
required.
@ -749,36 +747,11 @@ transcription process.
The following VAD models are currently supported:
### Silero-VAD
#### Silero-VAD
[Silero-vad](https://github.com/snakers4/silero-vad) is a lightweight VAD model
written in Python that is fast and accurate.
Models can be downloaded by running the following command on Linux or MacOS:
```console
$ ./models/download-vad-model.sh silero-v5.1.2
Downloading ggml model silero-v5.1.2 from 'https://huggingface.co/ggml-org/whisper-vad' ...
ggml-silero-v5.1.2.bin 100%[==============================================>] 864.35K --.-KB/s in 0.04s
Done! Model 'silero-v5.1.2' saved in '/path/models/ggml-silero-v5.1.2.bin'
You can now use it like this:
$ ./build/bin/whisper-cli -vm /path/models/ggml-silero-v5.1.2.bin --vad -f samples/jfk.wav -m models/ggml-base.en.bin
```
And the following command on Windows:
```console
> .\models\download-vad-model.cmd silero-v5.1.2
Downloading vad model silero-v5.1.2...
Done! Model silero-v5.1.2 saved in C:\Users\danie\work\ai\whisper.cpp\ggml-silero-v5.1.2.bin
You can now use it like this:
C:\path\build\bin\Release\whisper-cli.exe -vm C:\path\ggml-silero-v5.1.2.bin --vad -m models/ggml-base.en.bin -f samples\jfk.wav
```
To see a list of all available models, run the above commands without any
arguments.
This model can be also be converted manually to ggml using the following command:
This model can be converted to ggml using the following command:
```console
$ python3 -m venv venv && source venv/bin/activate
$ (venv) pip install silero-vad
@ -794,7 +767,7 @@ $ ./build/bin/whisper-cli \
--vad-model ./models/silero-v5.1.2-ggml.bin
```
### VAD Options
#### VAD Options
* --vad-threshold: Threshold probability for speech detection. A probability
for a speech segment/frame above this threshold will be considered as speech.

View File

@ -1,249 +1,249 @@
# whisper.cpp for SYCL
[Background](#background)
[OS](#os)
[Intel GPU](#intel-gpu)
[Linux](#linux)
[Environment Variable](#environment-variable)
[Known Issue](#known-issue)
[Todo](#todo)
## Background
SYCL is a higher-level programming model to improve programming productivity on various hardware acceleratorssuch as CPUs, GPUs, and FPGAs. It is a single-source embedded domain-specific language based on pure C++17.
oneAPI is a specification that is open and standards-based, supporting multiple architecture types including but not limited to GPU, CPU, and FPGA. The spec has both direct programming and API-based programming paradigms.
Intel uses the SYCL as direct programming language to support CPU, GPUs and FPGAs.
To avoid re-inventing the wheel, this code refers other code paths in llama.cpp (like OpenBLAS, cuBLAS, CLBlast). We use a open-source tool [SYCLomatic](https://github.com/oneapi-src/SYCLomatic) (Commercial release [Intel® DPC++ Compatibility Tool](https://www.intel.com/content/www/us/en/developer/tools/oneapi/dpc-compatibility-tool.html)) migrate to SYCL.
The whisper.cpp for SYCL is used to support Intel GPUs.
For Intel CPU, recommend to use whisper.cpp for X86 (Intel MKL build).
## OS
|OS|Status|Verified|
|-|-|-|
|Linux|Support|Ubuntu 22.04|
|Windows|Ongoing| |
## Intel GPU
|Intel GPU| Status | Verified Model|
|-|-|-|
|Intel Data Center Max Series| Support| Max 1550|
|Intel Data Center Flex Series| Support| Flex 170|
|Intel Arc Series| Support| Arc 770|
|Intel built-in Arc GPU| Support| built-in Arc GPU in Meteor Lake|
|Intel iGPU| Support| iGPU in i5-1250P, i7-1165G7|
## Linux
### Setup Environment
1. Install Intel GPU driver.
a. Please install Intel GPU driver by official guide: [Install GPU Drivers](https://dgpu-docs.intel.com/driver/installation.html).
Note: for iGPU, please install the client GPU driver.
b. Add user to group: video, render.
```
sudo usermod -aG render username
sudo usermod -aG video username
```
Note: re-login to enable it.
c. Check
```
sudo apt install clinfo
sudo clinfo -l
```
Output (example):
```
Platform #0: Intel(R) OpenCL Graphics
`-- Device #0: Intel(R) Arc(TM) A770 Graphics
Platform #0: Intel(R) OpenCL HD Graphics
`-- Device #0: Intel(R) Iris(R) Xe Graphics [0x9a49]
```
2. Install Intel® oneAPI Base toolkit.
a. Please follow the procedure in [Get the Intel® oneAPI Base Toolkit ](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html).
Recommend to install to default folder: **/opt/intel/oneapi**.
Following guide use the default folder as example. If you use other folder, please modify the following guide info with your folder.
b. Check
```
source /opt/intel/oneapi/setvars.sh
sycl-ls
```
There should be one or more level-zero devices. Like **[ext_oneapi_level_zero:gpu:0]**.
Output (example):
```
[opencl:acc:0] Intel(R) FPGA Emulation Platform for OpenCL(TM), Intel(R) FPGA Emulation Device OpenCL 1.2 [2023.16.10.0.17_160000]
[opencl:cpu:1] Intel(R) OpenCL, 13th Gen Intel(R) Core(TM) i7-13700K OpenCL 3.0 (Build 0) [2023.16.10.0.17_160000]
[opencl:gpu:2] Intel(R) OpenCL Graphics, Intel(R) Arc(TM) A770 Graphics OpenCL 3.0 NEO [23.30.26918.50]
[ext_oneapi_level_zero:gpu:0] Intel(R) Level-Zero, Intel(R) Arc(TM) A770 Graphics 1.3 [1.3.26918]
```
2. Build locally:
```
mkdir -p build
cd build
source /opt/intel/oneapi/setvars.sh
#for FP16
#cmake .. -DWHISPER_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DWHISPER_SYCL_F16=ON
#for FP32
cmake .. -DWHISPER_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx
#build example/main only
#cmake --build . --config Release --target main
#build all binary
cmake --build . --config Release -v
```
or
```
./examples/sycl/build.sh
```
Note:
- By default, it will build for all binary files. It will take more time. To reduce the time, we recommend to build for **example/main** only.
### Run
1. Put model file to folder **models**
2. Enable oneAPI running environment
```
source /opt/intel/oneapi/setvars.sh
```
3. List device ID
Run without parameter:
```
./build/bin/ls-sycl-device
or
./build/bin/main
```
Check the ID in startup log, like:
```
found 4 SYCL devices:
Device 0: Intel(R) Arc(TM) A770 Graphics, compute capability 1.3,
max compute_units 512, max work group size 1024, max sub group size 32, global mem size 16225243136
Device 1: Intel(R) FPGA Emulation Device, compute capability 1.2,
max compute_units 24, max work group size 67108864, max sub group size 64, global mem size 67065057280
Device 2: 13th Gen Intel(R) Core(TM) i7-13700K, compute capability 3.0,
max compute_units 24, max work group size 8192, max sub group size 64, global mem size 67065057280
Device 3: Intel(R) Arc(TM) A770 Graphics, compute capability 3.0,
max compute_units 512, max work group size 1024, max sub group size 32, global mem size 16225243136
```
|Attribute|Note|
|-|-|
|compute capability 1.3|Level-zero running time, recommended |
|compute capability 3.0|OpenCL running time, slower than level-zero in most cases|
4. Set device ID and execute whisper.cpp
Set device ID = 0 by **GGML_SYCL_DEVICE=0**
```
GGML_SYCL_DEVICE=0 ./build/bin/main -m models/ggml-base.en.bin -f samples/jfk.wav
```
or run by script:
```
./examples/sycl/run_whisper.sh
```
5. Check the device ID in output
Like:
```
Using device **0** (Intel(R) Arc(TM) A770 Graphics) as main device
```
## Environment Variable
#### Build
|Name|Value|Function|
|-|-|-|
|WHISPER_SYCL|ON (mandatory)|Enable build with SYCL code path. <br>For FP32/FP16, WHISPER_SYCL=ON is mandatory.|
|WHISPER_SYCL_F16|ON (optional)|Enable FP16 build with SYCL code path.For FP32, do not set it.|
|CMAKE_C_COMPILER|icx|Use icx compiler for SYCL code path|
|CMAKE_CXX_COMPILER|icpx|use icpx for SYCL code path|
#### Running
|Name|Value|Function|
|-|-|-|
|GGML_SYCL_DEVICE|0 (default) or 1|Set the device id used. Check the device ids by default running output|
|GGML_SYCL_DEBUG|0 (default) or 1|Enable log function by macro: GGML_SYCL_DEBUG|
## Known Issue
- Error: `error while loading shared libraries: libsycl.so.7: cannot open shared object file: No such file or directory`.
Miss to enable oneAPI running environment.
Install oneAPI base toolkit and enable it by: `source /opt/intel/oneapi/setvars.sh`.
- Hang during startup
llama.cpp use mmap as default way to read model file and copy to GPU. In some system, memcpy will be abnormal and block.
Solution: add **--no-mmap**.
## Todo
- Support to build in Windows.
- Support multiple cards.
# whisper.cpp for SYCL
[Background](#background)
[OS](#os)
[Intel GPU](#intel-gpu)
[Linux](#linux)
[Environment Variable](#environment-variable)
[Known Issue](#known-issue)
[Todo](#todo)
## Background
SYCL is a higher-level programming model to improve programming productivity on various hardware accelerators<EFBFBD>such as CPUs, GPUs, and FPGAs. It is a single-source embedded domain-specific language based on pure C++17.
oneAPI is a specification that is open and standards-based, supporting multiple architecture types including but not limited to GPU, CPU, and FPGA. The spec has both direct programming and API-based programming paradigms.
Intel uses the SYCL as direct programming language to support CPU, GPUs and FPGAs.
To avoid re-inventing the wheel, this code refers other code paths in llama.cpp (like OpenBLAS, cuBLAS, CLBlast). We use a open-source tool [SYCLomatic](https://github.com/oneapi-src/SYCLomatic) (Commercial release [Intel<EFBFBD> DPC++ Compatibility Tool](https://www.intel.com/content/www/us/en/developer/tools/oneapi/dpc-compatibility-tool.html)) migrate to SYCL.
The whisper.cpp for SYCL is used to support Intel GPUs.
For Intel CPU, recommend to use whisper.cpp for X86 (Intel MKL build).
## OS
|OS|Status|Verified|
|-|-|-|
|Linux|Support|Ubuntu 22.04|
|Windows|Ongoing| |
## Intel GPU
|Intel GPU| Status | Verified Model|
|-|-|-|
|Intel Data Center Max Series| Support| Max 1550|
|Intel Data Center Flex Series| Support| Flex 170|
|Intel Arc Series| Support| Arc 770|
|Intel built-in Arc GPU| Support| built-in Arc GPU in Meteor Lake|
|Intel iGPU| Support| iGPU in i5-1250P, i7-1165G7|
## Linux
### Setup Environment
1. Install Intel GPU driver.
a. Please install Intel GPU driver by official guide: [Install GPU Drivers](https://dgpu-docs.intel.com/driver/installation.html).
Note: for iGPU, please install the client GPU driver.
b. Add user to group: video, render.
```
sudo usermod -aG render username
sudo usermod -aG video username
```
Note: re-login to enable it.
c. Check
```
sudo apt install clinfo
sudo clinfo -l
```
Output (example):
```
Platform #0: Intel(R) OpenCL Graphics
`-- Device #0: Intel(R) Arc(TM) A770 Graphics
Platform #0: Intel(R) OpenCL HD Graphics
`-- Device #0: Intel(R) Iris(R) Xe Graphics [0x9a49]
```
2. Install Intel<EFBFBD> oneAPI Base toolkit.
a. Please follow the procedure in [Get the Intel<EFBFBD> oneAPI Base Toolkit ](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html).
Recommend to install to default folder: **/opt/intel/oneapi**.
Following guide use the default folder as example. If you use other folder, please modify the following guide info with your folder.
b. Check
```
source /opt/intel/oneapi/setvars.sh
sycl-ls
```
There should be one or more level-zero devices. Like **[ext_oneapi_level_zero:gpu:0]**.
Output (example):
```
[opencl:acc:0] Intel(R) FPGA Emulation Platform for OpenCL(TM), Intel(R) FPGA Emulation Device OpenCL 1.2 [2023.16.10.0.17_160000]
[opencl:cpu:1] Intel(R) OpenCL, 13th Gen Intel(R) Core(TM) i7-13700K OpenCL 3.0 (Build 0) [2023.16.10.0.17_160000]
[opencl:gpu:2] Intel(R) OpenCL Graphics, Intel(R) Arc(TM) A770 Graphics OpenCL 3.0 NEO [23.30.26918.50]
[ext_oneapi_level_zero:gpu:0] Intel(R) Level-Zero, Intel(R) Arc(TM) A770 Graphics 1.3 [1.3.26918]
```
2. Build locally:
```
mkdir -p build
cd build
source /opt/intel/oneapi/setvars.sh
#for FP16
#cmake .. -DWHISPER_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DWHISPER_SYCL_F16=ON
#for FP32
cmake .. -DWHISPER_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx
#build example/main only
#cmake --build . --config Release --target main
#build all binary
cmake --build . --config Release -v
```
or
```
./examples/sycl/build.sh
```
Note:
- By default, it will build for all binary files. It will take more time. To reduce the time, we recommend to build for **example/main** only.
### Run
1. Put model file to folder **models**
2. Enable oneAPI running environment
```
source /opt/intel/oneapi/setvars.sh
```
3. List device ID
Run without parameter:
```
./build/bin/ls-sycl-device
or
./build/bin/main
```
Check the ID in startup log, like:
```
found 4 SYCL devices:
Device 0: Intel(R) Arc(TM) A770 Graphics, compute capability 1.3,
max compute_units 512, max work group size 1024, max sub group size 32, global mem size 16225243136
Device 1: Intel(R) FPGA Emulation Device, compute capability 1.2,
max compute_units 24, max work group size 67108864, max sub group size 64, global mem size 67065057280
Device 2: 13th Gen Intel(R) Core(TM) i7-13700K, compute capability 3.0,
max compute_units 24, max work group size 8192, max sub group size 64, global mem size 67065057280
Device 3: Intel(R) Arc(TM) A770 Graphics, compute capability 3.0,
max compute_units 512, max work group size 1024, max sub group size 32, global mem size 16225243136
```
|Attribute|Note|
|-|-|
|compute capability 1.3|Level-zero running time, recommended |
|compute capability 3.0|OpenCL running time, slower than level-zero in most cases|
4. Set device ID and execute whisper.cpp
Set device ID = 0 by **GGML_SYCL_DEVICE=0**
```
GGML_SYCL_DEVICE=0 ./build/bin/main -m models/ggml-base.en.bin -f samples/jfk.wav
```
or run by script:
```
./examples/sycl/run_whisper.sh
```
5. Check the device ID in output
Like:
```
Using device **0** (Intel(R) Arc(TM) A770 Graphics) as main device
```
## Environment Variable
#### Build
|Name|Value|Function|
|-|-|-|
|WHISPER_SYCL|ON (mandatory)|Enable build with SYCL code path. <br>For FP32/FP16, WHISPER_SYCL=ON is mandatory.|
|WHISPER_SYCL_F16|ON (optional)|Enable FP16 build with SYCL code path.For FP32, do not set it.|
|CMAKE_C_COMPILER|icx|Use icx compiler for SYCL code path|
|CMAKE_CXX_COMPILER|icpx|use icpx for SYCL code path|
#### Running
|Name|Value|Function|
|-|-|-|
|GGML_SYCL_DEVICE|0 (default) or 1|Set the device id used. Check the device ids by default running output|
|GGML_SYCL_DEBUG|0 (default) or 1|Enable log function by macro: GGML_SYCL_DEBUG|
## Known Issue
- Error: `error while loading shared libraries: libsycl.so.7: cannot open shared object file: No such file or directory`.
Miss to enable oneAPI running environment.
Install oneAPI base toolkit and enable it by: `source /opt/intel/oneapi/setvars.sh`.
- Hang during startup
llama.cpp use mmap as default way to read model file and copy to GPU. In some system, memcpy will be abnormal and block.
Solution: add **--no-mmap**.
## Todo
- Support to build in Windows.
- Support multiple cards.

View File

@ -15,7 +15,7 @@ BUILD_DIR := build_go
MODELS_DIR := models
EXAMPLES_DIR := $(wildcard examples/*)
INCLUDE_PATH := $(abspath ../../include):$(abspath ../../ggml/include)
LIBRARY_PATH := $(abspath ../../${BUILD_DIR}/src):$(abspath ../../${BUILD_DIR}/ggml/src)
LIBRARY_PATH := $(abspath ../../${BUILD_DIR}/src:$(abspath ../../${BUILD_DIR}/ggml/src))
ifeq ($(GGML_CUDA),1)
LIBRARY_PATH := $(LIBRARY_PATH):$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib/
@ -23,8 +23,7 @@ ifeq ($(GGML_CUDA),1)
endif
ifeq ($(UNAME_S),Darwin)
LIBRARY_PATH := $(LIBRARY_PATH):$(abspath ../../${BUILD_DIR}/ggml/src/ggml-blas):$(abspath ../../${BUILD_DIR}/ggml/src/ggml-metal)
EXT_LDFLAGS := -framework Foundation -framework Metal -framework MetalKit -lggml-metal -lggml-blas
EXT_LDFLAGS := -framework Foundation -framework Metal -framework MetalKit
endif
all: clean whisper examples

View File

@ -9,9 +9,7 @@ import (
// CGO
/*
#cgo LDFLAGS: -lwhisper -lggml -lggml-base -lggml-cpu -lm -lstdc++
#cgo linux LDFLAGS: -fopenmp
#cgo darwin LDFLAGS: -lggml-metal -lggml-blas
#cgo LDFLAGS: -lwhisper -lggml -lggml-base -lggml-cpu -lm -lstdc++ -fopenmp
#cgo darwin LDFLAGS: -framework Accelerate -framework Metal -framework Foundation -framework CoreGraphics
#include <whisper.h>
#include <stdlib.h>

View File

@ -23,42 +23,26 @@ import io.github.ggerganov.whispercpp.WhisperCpp;
public class Example {
public static void main(String[] args) {
WhisperCpp whisper = new WhisperCpp();
// By default, models are loaded from ~/.cache/whisper/ and are usually named "ggml-${name}.bin"
// or you can provide the absolute path to the model file.
long context = whisper.initContext("base.en");
try {
// By default, models are loaded from ~/.cache/whisper/ and are usually named "ggml-${name}.bin"
// or you can provide the absolute path to the model file.
whisper.initContext("../ggml-base.en.bin");
WhisperFullParams.ByValue whisperParams = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
// custom configuration if required
//whisperParams.n_threads = 8;
whisperParams.temperature = 0.0f;
whisperParams.temperature_inc = 0.2f;
//whisperParams.language = "en";
float[] samples = readAudio(); // divide each value by 32767.0f
List<WhisperSegment> whisperSegmentList = whisper.fullTranscribeWithTime(whisperParams, samples);
for (WhisperSegment whisperSegment : whisperSegmentList) {
var whisperParams = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
// custom configuration if required
whisperParams.temperature_inc = 0f;
long start = whisperSegment.getStart();
long end = whisperSegment.getEnd();
var samples = readAudio(); // divide each value by 32767.0f
whisper.fullTranscribe(whisperParams, samples);
String text = whisperSegment.getSentence();
System.out.println("start: "+start);
System.out.println("end: "+end);
System.out.println("text: "+text);
int segmentCount = whisper.getTextSegmentCount(context);
for (int i = 0; i < segmentCount; i++) {
String text = whisper.getTextSegment(context, i);
System.out.println(segment.getText());
}
} catch (IOException e) {
e.printStackTrace();
} finally {
whisper.close();
whisper.freeContext(context);
}
}
}
```

View File

@ -168,26 +168,23 @@ public class WhisperCpp implements AutoCloseable {
return str.toString().trim();
}
/**
* Full transcribe with time list.
*
* @param whisperParams the whisper params
* @param audioData the audio data
* @return the list
* @throws IOException the io exception
*/
public List<WhisperSegment> fullTranscribeWithTime(WhisperFullParams.ByValue whisperParams, float[] audioData) throws IOException {
public List<WhisperSegment> fullTranscribeWithTime(WhisperFullParams whisperParams, float[] audioData) throws IOException {
if (ctx == null) {
throw new IllegalStateException("Model not initialised");
}
if (lib.whisper_full(ctx, whisperParams, audioData, audioData.length) != 0) {
WhisperFullParams.ByValue valueParams = new WhisperFullParams.ByValue(
lib.whisper_full_default_params_by_ref(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH.ordinal()));
valueParams.read();
if (lib.whisper_full(ctx, valueParams, audioData, audioData.length) != 0) {
throw new IOException("Failed to process audio");
}
int nSegments = lib.whisper_full_n_segments(ctx);
List<WhisperSegment> segments= new ArrayList<>(nSegments);
for (int i = 0; i < nSegments; i++) {
long t0 = lib.whisper_full_get_segment_t0(ctx, i);
String text = lib.whisper_full_get_segment_text(ctx, i);

View File

@ -118,7 +118,7 @@ class WhisperCppTest {
float[] floats = new float[b.length / 2];
//WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
WhisperFullParams.ByValue params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
params.setProgressCallback((ctx, state, progress, user_data) -> System.out.println("progress: " + progress));
params.print_progress = CBool.FALSE;
//params.initial_prompt = "and so my fellow Americans um, like";

View File

@ -1,6 +1,6 @@
{
"name": "whisper.cpp",
"version": "1.7.6",
"version": "1.7.5",
"description": "Whisper speech recognition",
"main": "whisper.js",
"scripts": {

View File

@ -1,9 +1,6 @@
LICENSE
pkg/
lib/whisper.*
ext/examples/
ext/ggml/
ext/include/
ext/scripts/
ext/src/
test/fixtures/
ext/sources/*
!ext/sources/CMakeGraphVizOptions.cmake
ext/mkmf.log

View File

@ -24,21 +24,7 @@ or,
$ gem install whispercpp -- --enable-ggml-cuda
See whisper.cpp's [README](https://github.com/ggml-org/whisper.cpp/blob/master/README.md) for available options. You need convert options present the README to Ruby-style options, for example:
Boolean options:
* `-DGGML_BLAS=1` -> `--enable-ggml-blas`
* `-DWHISER_COREML=OFF` -> `--disable-whisper-coreml`
Argument options:
* `-DGGML_CUDA_COMPRESSION_MODE=size` -> `--ggml-cuda-compression-mode=size`
Combination:
* `-DGGML_CUDA=1 -DCMAKE_CUDA_ARCHITECTURES="86"` -> `--enable-ggml-cuda --cmake_cuda-architectures="86"`
See whisper.cpp's [README](https://github.com/ggml-org/whisper.cpp/blob/master/README.md) for available options. You need convert options present the README to Ruby-style options.
For boolean options like `GGML_CUDA`, the README says `-DGGML_CUDA=1`. You need strip `-D`, prepend `--enable-` for `1` or `ON` (`--disable-` for `0` or `OFF`) and make it kebab-case: `--enable-ggml-cuda`.
For options which require arguments like `CMAKE_CUDA_ARCHITECTURES`, the README says `-DCMAKE_CUDA_ARCHITECTURES="86"`. You need strip `-D`, prepend `--`, make it kebab-case, append `=` and append argument: `--cmake-cuda-architectures="86"`.
@ -70,6 +56,17 @@ end
Some models are prepared up-front:
```ruby
base_en = Whisper::Model.pre_converted_models["base.en"]
whisper = Whisper::Context.new(base_en)
```
At first time you use a model, it is downloaded automatically. After that, downloaded cached file is used. To clear cache, call `#clear_cache`:
```ruby
Whisper::Model.pre_converted_models["base"].clear_cache
```
You also can use shorthand for pre-converted models:
```ruby
@ -94,19 +91,6 @@ puts Whisper::Model.pre_converted_models.keys
# :
```
You can also retrieve each model:
```ruby
base_en = Whisper::Model.pre_converted_models["base.en"]
whisper = Whisper::Context.new(base_en)
```
At first time you use a model, it is downloaded automatically. After that, downloaded cached file is used. To clear cache, call `#clear_cache`:
```ruby
Whisper::Model.pre_converted_models["base"].clear_cache
```
You can also use local model files you prepared:
```ruby
@ -127,80 +111,9 @@ See [models][] page for details.
Currently, whisper.cpp accepts only 16-bit WAV files.
### Voice Activity Detection (VAD) ###
Support for Voice Activity Detection (VAD) can be enabled by setting `Whisper::Params`'s `vad` argument to `true` and specifying VAD model:
```ruby
Whisper::Params.new(
vad: true,
vad_model_path: "silero-v5.1.2",
# other arguments...
)
```
When you pass the model name (`"silero-v5.1.2"`) or URI (`https://huggingface.co/ggml-org/whisper-vad/resolve/main/ggml-silero-v5.1.2.bin`), it will be downloaded automatically.
Currently, "silero-v5.1.2" is registered as pre-converted model like ASR models. You also specify file path or URI of model.
If you need configure VAD behavior, pass params for that:
```ruby
Whisper::Params.new(
vad: true,
vad_model_path: "silero-v5.1.2",
vad_params: Whisper::VAD::Params.new(
threshold: 1.0, # defaults to 0.5
min_speech_duration_ms: 500, # defaults to 250
min_silence_duration_ms: 200, # defaults to 100
max_speech_duration_s: 30000, # default is FLT_MAX,
speech_pad_ms: 50, # defaults to 30
samples_overlap: 0.5 # defaults to 0.1
),
# other arguments...
)
```
For details on VAD, see [whisper.cpp's README](https://github.com/ggml-org/whisper.cpp?tab=readme-ov-file#voice-activity-detection-vad).
### Output ###
whispercpp supports SRT and WebVTT output:
```ruby
puts whisper.transcribe("path/to/audio.wav", Whisper::Params.new).to_webvtt
# =>
WEBVTT
1
00:00:00.000 --> 00:00:03.860
My thought I have nobody by a beauty and will as you poured.
2
00:00:03.860 --> 00:00:09.840
Mr. Rochester is sub in that so-don't find simplest, and devoted about, to let might in
3
00:00:09.840 --> 00:00:09.940
a
```
You may call `#to_srt`, too
API
---
### Transcription ###
By default, `Whisper::Context#transcribe` works in a single thread. You can make it work in parallel by passing `n_processors` option:
```ruby
whisper.transcribe("path/to/audio.wav", params, n_processors: Etc.nprocessors)
```
Note that transcription occasionally might be low accuracy when it works in parallel.
### Segments ###
Once `Whisper::Context#transcribe` called, you can retrieve segments by `#each_segment`:
@ -222,7 +135,7 @@ whisper
ed: format_time(segment.end_time),
text: segment.text
}
line << " (speaker turned)" if segment.speaker_turn_next?
line << " (speaker turned)" if segment.speaker_next_turn?
puts line
end
@ -238,7 +151,7 @@ params.on_new_segment do |segment|
ed: format_time(segment.end_time),
text: segment.text
}
line << " (speaker turned)" if segment.speaker_turn_next?
line << " (speaker turned)" if segment.speaker_next_turn?
puts line
end
@ -335,11 +248,6 @@ First call of `rake test` builds an extension and downloads a model for testing.
If something seems wrong on build, running `rake clean` solves some cases.
### Need help ###
* Windows support
* Refinement of C/C++ code, especially memory management
License
-------

View File

@ -67,30 +67,17 @@ file LIB_FILE => [SO_FILE, "lib"] do |t|
end
CLEAN.include LIB_FILE
Rake::TestTask.new
TEST_FIXTURE_AUDIO = "test/fixtures/jfk.wav"
TEST_FIXTURE_AUDIO_SRC = File.expand_path(File.join(__dir__, "..", "..", "samples", "jfk.wav"))
TEST_FIXTURE_AUDIO_DIR = TEST_FIXTURE_AUDIO.pathmap("%d")
directory TEST_FIXTURE_AUDIO_DIR
if File.exist? TEST_FIXTURE_AUDIO_SRC
file TEST_FIXTURE_AUDIO => [TEST_FIXTURE_AUDIO_SRC, TEST_FIXTURE_AUDIO_DIR] do |t|
symlink t.source, t.name
end
else
require "open-uri"
file TEST_FIXTURE_AUDIO => TEST_FIXTURE_AUDIO_DIR do |t|
File.write t.name, URI("https://github.com/ggml-org/whisper.cpp/raw/refs/heads/master/samples/jfk.wav").read
end
Rake::TestTask.new do |t|
t.test_files = FileList["tests/test_*.rb"]
end
TEST_MEMORY_VIEW = "test/jfk_reader/jfk_reader.#{RbConfig::CONFIG['DLEXT']}"
file TEST_MEMORY_VIEW => "test/jfk_reader/jfk_reader.c" do |t|
chdir "test/jfk_reader" do
TEST_MEMORY_VIEW = "tests/jfk_reader/jfk_reader.#{RbConfig::CONFIG['DLEXT']}"
file TEST_MEMORY_VIEW => "tests/jfk_reader/jfk_reader.c" do |t|
chdir "tests/jfk_reader" do
ruby "extconf.rb"
sh "make"
end
end
CLEAN.include TEST_MEMORY_VIEW
CLEAN.include "tests/jfk_reader/jfk_reader.{o,#{RbConfig::CONFIG['DLEXT']}}"
task test: [LIB_FILE, TEST_MEMORY_VIEW, TEST_FIXTURE_AUDIO]
task test: [LIB_FILE, TEST_MEMORY_VIEW]

View File

@ -2,8 +2,10 @@ Makefile
whisper.so
whisper.bundle
whisper.dll
scripts/get-flags.mk
*.o
*.a
sources/*
!sources/CMakeGraphVizOptions.cmake
mkmf.log
/*/**/*.c
/*/**/*.cpp
/*/**/*.h
/*/**/*.m
/*/**/*.metal

View File

@ -1,32 +1,16 @@
require "tsort"
class Dependencies
include TSort
def initialize(cmake, options)
@cmake = cmake
@options = options
@static_lib_shape = nil
@nodes = {}
@graph = Hash.new {|h, k| h[k] = []}
generate_dot
parse_dot
end
def libs
tsort.filter_map {|node|
label, shape = @nodes[node]
if shape == @static_lib_shape
label.gsub(/\\n\([^)]+\)/, '')
else
nil
end
}.reverse.collect {|lib| "lib#{lib}.a"}
@libs = parse_dot
end
def to_s
libs.join(" ")
@libs.join(" ")
end
private
@ -36,38 +20,42 @@ class Dependencies
end
def generate_dot
args = ["-S", "sources", "-B", "build", "--graphviz", dot_path, "-D", "BUILD_SHARED_LIBS=OFF"]
args << @options.to_s unless @options.to_s.empty?
system @cmake, *args, exception: true
system @cmake, "-S", "sources", "-B", "build", "--graphviz", dot_path, "-D", "BUILD_SHARED_LIBS=OFF", @options.to_s, exception: true
end
def parse_dot
static_lib_shape = nil
nodes = {}
depends = Hash.new {|h, k| h[k] = []}
class << depends
include TSort
alias tsort_each_node each_key
def tsort_each_child(node, &block)
fetch(node, []).each(&block)
end
end
File.open(dot_path).each_line do |line|
case line
when /\[\s*label\s*=\s*"Static Library"\s*,\s*shape\s*=\s*(?<shape>\w+)\s*\]/
@static_lib_shape = $~[:shape]
static_lib_shape = $~[:shape]
when /\A\s*"(?<node>\w+)"\s*\[\s*label\s*=\s*"(?<label>\S+)"\s*,\s*shape\s*=\s*(?<shape>\w+)\s*\]\s*;\s*\z/
node = $~[:node]
label = $~[:label]
shape = $~[:shape]
@nodes[node] = [label, shape]
nodes[node] = [label, shape]
when /\A\s*"(?<depender>\w+)"\s*->\s*"(?<dependee>\w+)"/
depender = $~[:depender]
dependee = $~[:dependee]
@graph[depender] << dependee
depends[depender] ||= []
depends[depender] << dependee
end
end
end
def tsort_each_node
@nodes.each_key do |node|
yield node
end
end
def tsort_each_child(node)
@graph[node].each do |child|
yield child
end
depends.tsort.filter_map {|node|
label, shape = nodes[node]
shape == static_lib_shape ? label : nil
}.collect {|lib| "lib#{lib}.a"}
.reverse
end
end

View File

@ -3,7 +3,7 @@ require_relative "options"
require_relative "dependencies"
cmake = find_executable("cmake") || abort
options = Options.new(cmake)
options = Options.new
have_library("gomp") rescue nil
libs = Dependencies.new(cmake, options)

View File

@ -1,11 +1,25 @@
class Options
def initialize(cmake="cmake")
@cmake = cmake
def initialize
@options = {}
@pending_options = []
@ignored_options = []
configure
end
def help
@options
.collect_concat {|name, (type, value)|
option = option_name(name)
if type == :bool
["--enable-#{option}", "--disable-#{option}"]
else
"--#{option}=#{type.upcase}"
end
}
.join($/)
end
def to_s
@options
.reject {|name, (type, value)| value.nil?}
@ -18,68 +32,188 @@ class Options
output = nil
Dir.chdir __dir__ do
output = `#{@cmake.shellescape} -S sources -B build -L`
output = `cmake -S sources -B build -L`
end
@cmake_options = output.lines.drop_while {|line| line.chomp != "-- Cache values"}.drop(1)
.filter_map {|line|
option, value = line.chomp.split("=", 2)
name, type = option.split(":", 2)
[
name,
[
type,
type == "BOOL" ? value == "ON" : value
]
]
}.to_h
started = false
@cmake_options = output.lines.filter_map {|line|
if line.chomp == "-- Cache values"
started = true
next
end
next unless started
option, value = line.chomp.split("=", 2)
name, type = option.split(":", 2)
[name, type, value]
}
end
def missing_options
cmake_options.collect {|name, type, value| name} -
@options.keys - @pending_options - @ignored_options
end
def extra_options
@options.keys + @pending_options + @ignored_options -
cmake_options.collect {|name, type, value| name}
end
private
def configure
cmake_options.each_pair do |name, (type, default_value)|
option = option_name(name)
value = type == "BOOL" ? enable_config(option) : arg_config("--#{option}")
@options[name] = [type, value]
end
configure_accelerate
configure_metal
configure_coreml
end
# See ggml/src/ggml-cpu/CMakeLists.txt
def configure_accelerate
if RUBY_PLATFORM.match?(/darwin/) && enabled?("GGML_ACCELERATE")
$LDFLAGS << " -framework Accelerate"
end
end
# See ggml/src/ggml-metal/CMakeLists.txt
def configure_metal
$LDFLAGS << " -framework Foundation -framework Metal -framework MetalKit" if enabled?("GGML_METAL")
end
# See src/CmakeLists.txt
def configure_coreml
if enabled?("WHISPER_COREML")
$LDFLAGS << " -framework Foundation -framework CoreML"
$defs << "-DRUBY_WHISPER_USE_COREML"
end
filepath "ACCELERATE_FRAMEWORK"
ignored "BUILD_SHARED_LIBS"
ignored "BUILD_TESTING"
ignored "CMAKE_BUILD_TYPE"
ignored "CMAKE_INSTALL_PREFIX"
string "CMAKE_OSX_ARCHITECTURES"
ignored "CMAKE_OSX_DEPLOYMENT_TARGET"
string "CMAKE_OSX_SYSROOT"
filepath "FOUNDATION_LIBRARY"
bool "GGML_ACCELERATE"
bool "GGML_ALL_WARNINGS_3RD_PARTY"
bool "GGML_AMX_BF16"
bool "GGML_AMX_INT8"
bool "GGML_AMX_TILE"
bool "GGML_AVX"
bool "GGML_AVX2"
bool "GGML_AVX512"
bool "GGML_AVX512_BF16"
bool "GGML_AVX512_VBMI"
bool "GGML_AVX512_VNNI"
bool "GGML_AVX_VNNI"
ignored "GGML_BACKEND_DL"
ignored "GGML_BIN_INSTALL_DIR"
bool "GGML_BLAS"
string "GGML_BLAS_VENDOR"
bool "GGML_BMI2"
ignored "GGML_BUILD_EXAMPLES"
ignored "GGML_BUILD_TESTS"
bool "GGML_CCACHE"
filepath "GGML_CCACHE_FOUND"
bool "GGML_CPU"
bool "GGML_CPU_AARCH64"
ignored "GGML_CPU_ALL_VARIANTS"
string "GGML_CPU_ARM_ARCH"
bool "GGML_CPU_HBM"
bool "GGML_CPU_KLEIDIAI"
string "GGML_CPU_POWERPC_CPUTYPE"
bool "GGML_CUDA"
string "GGML_CUDA_COMPRESSION_MODE"
bool "GGML_CUDA_F16"
bool "GGML_CUDA_FA"
bool "GGML_CUDA_FA_ALL_QUANTS"
bool "GGML_CUDA_FORCE_CUBLAS"
bool "GGML_CUDA_FORCE_MMQ"
ignored "GGML_CUDA_GRAPHS"
bool "GGML_CUDA_NO_PEER_COPY"
bool "GGML_CUDA_NO_VMM"
string "GGML_CUDA_PEER_MAX_BATCH_SIZE"
bool "GGML_F16C"
bool "GGML_FMA"
bool "GGML_GPROF"
bool "GGML_HIP"
bool "GGML_HIP_GRAPHS"
bool "GGML_HIP_NO_VMM"
bool "GGML_HIP_ROCWMMA_FATTN"
ignored "GGML_INCLUDE_INSTALL_DIR"
bool "GGML_KOMPUTE"
bool "GGML_LASX"
ignored "GGML_LIB_INSTALL_DIR"
ignored "GGML_LLAMAFILE"
bool "GGML_LSX"
bool "GGML_LTO"
bool "GGML_METAL"
bool "GGML_METAL_EMBED_LIBRARY"
string "GGML_METAL_MACOSX_VERSION_MIN"
bool "GGML_METAL_NDEBUG"
bool "GGML_METAL_SHADER_DEBUG"
string "GGML_METAL_STD"
bool "GGML_METAL_USE_BF16"
bool "GGML_MUSA"
bool "GGML_NATIVE"
bool "GGML_OPENCL"
bool "GGML_OPENCL_EMBED_KERNELS"
bool "GGML_OPENCL_PROFILING"
string "GGML_OPENCL_TARGET_VERSION"
bool "GGML_OPENCL_USE_ADRENO_KERNELS"
bool "GGML_OPENMP"
bool "GGML_RPC"
bool "GGML_RVV"
bool "GGML_RV_ZFH"
pending "GGML_SCCACHE_FOUND"
string "GGML_SCHED_MAX_COPIES"
bool "GGML_SSE42"
ignored "GGML_STATIC"
bool "GGML_SYCL"
string "GGML_SYCL_DEVICE_ARCH"
bool "GGML_SYCL_F16"
bool "GGML_SYCL_GRAPH"
string "GGML_SYCL_TARGET"
bool "GGML_VULKAN"
bool "GGML_VULKAN_CHECK_RESULTS"
bool "GGML_VULKAN_DEBUG"
bool "GGML_VULKAN_MEMORY_DEBUG"
bool "GGML_VULKAN_PERF"
ignored "GGML_VULKAN_RUN_TESTS"
filepath "GGML_VULKAN_SHADERS_GEN_TOOLCHAIN"
bool "GGML_VULKAN_SHADER_DEBUG_INFO"
pending "GGML_VULKAN_VALIDATE"
bool "GGML_VXE"
filepath "GIT_EXE"
filepath "MATH_LIBRARY"
filepath "METALKIT_FRAMEWORK"
filepath "METAL_FRAMEWORK"
bool "WHISPER_ALL_WARNINGS"
bool "WHISPER_ALL_WARNINGS_3RD_PARTY"
ignored "WHISPER_BIN_INSTALL_DIR"
ignored "WHISPER_BUILD_EXAMPLES"
ignored "WHISPER_BUILD_SERVER"
ignored"WHISPER_BUILD_TESTS"
bool "WHISPER_COREML"
bool "WHISPER_COREML_ALLOW_FALLBACK"
ignored "WHISPER_CURL"
bool "WHISPER_FATAL_WARNINGS"
ignored "WHISPER_FFMPEG"
ignored "WHISPER_INCLUDE_INSTALL_DIR"
ignored "WHISPER_LIB_INSTALL_DIR"
bool "WHISPER_OPENVINO"
bool "WHISPER_SANITIZE_ADDRESS"
bool "WHISPER_SANITIZE_THREAD"
bool "WHISPER_SANITIZE_UNDEFINED"
ignored "WHISPER_SDL2"
pending "WHISPER_USE_SYSTEM_GGML"
end
def option_name(name)
name.downcase.gsub("_", "-")
end
def enabled?(option)
op = @options[option]
raise "Option not exist: #{option}" unless op
raise "Option not boolean: #{option}(#{op[0]})" unless op[0] == "BOOL"
if op[1].nil?
cmake_options[option][1]
else
op[1]
end
def bool(name)
option = option_name(name)
value = enable_config(option)
@options[name] = [:bool, value]
end
def string(name, type=:string)
option = "--#{option_name(name)}"
value = arg_config(option)
raise "String expected for #{option}" if value == true || value&.empty?
@options[name] = [type, value]
end
def path(name)
string(name, :path)
end
def filepath(name)
string(name, :filepath)
end
def pending(name)
@pending_options << name
end
def ignored(name)
@ignored_options << name
end
end

View File

@ -3,10 +3,8 @@
#include "ruby_whisper.h"
VALUE mWhisper;
VALUE mVAD;
VALUE cContext;
VALUE cParams;
VALUE cVADParams;
VALUE eError;
VALUE cSegment;
@ -22,9 +20,6 @@ ID id_new;
ID id_to_path;
ID id_URI;
ID id_pre_converted_models;
ID id_coreml_compiled_models;
ID id_cache;
ID id_n_processors;
static bool is_log_callback_finalized = false;
@ -36,7 +31,6 @@ extern void init_ruby_whisper_params(VALUE *mWhisper);
extern void init_ruby_whisper_error(VALUE *mWhisper);
extern void init_ruby_whisper_segment(VALUE *mWhisper, VALUE *cSegment);
extern void init_ruby_whisper_model(VALUE *mWhisper);
extern void init_ruby_whisper_vad_params(VALUE *mVAD);
extern void register_callbacks(ruby_whisper_params *rwp, VALUE *context);
/*
@ -86,14 +80,6 @@ static VALUE ruby_whisper_s_lang_str_full(VALUE self, VALUE id) {
return rb_str_new2(str_full);
}
/*
* call-seq:
* system_info_str -> String
*/
static VALUE ruby_whisper_s_system_info_str(VALUE self) {
return rb_str_new2(whisper_print_system_info());
}
static VALUE ruby_whisper_s_finalize_log_callback(VALUE self, VALUE id) {
is_log_callback_finalized = true;
return Qnil;
@ -130,6 +116,16 @@ static VALUE ruby_whisper_s_log_set(VALUE self, VALUE log_callback, VALUE user_d
return Qnil;
}
static void rb_whisper_model_mark(ruby_whisper_model *rwm) {
rb_gc_mark(rwm->context);
}
static VALUE ruby_whisper_model_allocate(VALUE klass) {
ruby_whisper_model *rwm;
rwm = ALLOC(ruby_whisper_model);
return Data_Wrap_Struct(klass, rb_whisper_model_mark, RUBY_DEFAULT_FREE, rwm);
}
void Init_whisper() {
id_to_s = rb_intern("to_s");
id_call = rb_intern("call");
@ -141,14 +137,9 @@ void Init_whisper() {
id_to_path = rb_intern("to_path");
id_URI = rb_intern("URI");
id_pre_converted_models = rb_intern("pre_converted_models");
id_coreml_compiled_models = rb_intern("coreml_compiled_models");
id_cache = rb_intern("cache");
id_n_processors = rb_intern("n_processors");
mWhisper = rb_define_module("Whisper");
mVAD = rb_define_module_under(mWhisper, "VAD");
rb_define_const(mWhisper, "VERSION", rb_str_new2(whisper_version()));
rb_define_const(mWhisper, "LOG_LEVEL_NONE", INT2NUM(GGML_LOG_LEVEL_NONE));
rb_define_const(mWhisper, "LOG_LEVEL_INFO", INT2NUM(GGML_LOG_LEVEL_INFO));
rb_define_const(mWhisper, "LOG_LEVEL_WARN", INT2NUM(GGML_LOG_LEVEL_WARN));
@ -160,7 +151,6 @@ void Init_whisper() {
rb_define_singleton_method(mWhisper, "lang_id", ruby_whisper_s_lang_id, 1);
rb_define_singleton_method(mWhisper, "lang_str", ruby_whisper_s_lang_str, 1);
rb_define_singleton_method(mWhisper, "lang_str_full", ruby_whisper_s_lang_str_full, 1);
rb_define_singleton_method(mWhisper, "system_info_str", ruby_whisper_s_system_info_str, 0);
rb_define_singleton_method(mWhisper, "log_set", ruby_whisper_s_log_set, 2);
rb_define_private_method(rb_singleton_class(mWhisper), "finalize_log_callback", ruby_whisper_s_finalize_log_callback, 1);
@ -169,9 +159,6 @@ void Init_whisper() {
init_ruby_whisper_error(&mWhisper);
init_ruby_whisper_segment(&mWhisper, &cContext);
init_ruby_whisper_model(&mWhisper);
init_ruby_whisper_vad_params(&mVAD);
rb_require("whisper/context");
rb_require("whisper/segment");
rb_require("whisper/model/uri");
}

View File

@ -21,13 +21,8 @@ typedef struct {
ruby_whisper_callback_container *progress_callback_container;
ruby_whisper_callback_container *encoder_begin_callback_container;
ruby_whisper_callback_container *abort_callback_container;
VALUE vad_params;
} ruby_whisper_params;
typedef struct {
struct whisper_vad_params params;
} ruby_whisper_vad_params;
typedef struct {
VALUE context;
int index;

View File

@ -11,21 +11,15 @@ extern ID id_new;
extern ID id_to_path;
extern ID id_URI;
extern ID id_pre_converted_models;
extern ID id_coreml_compiled_models;
extern ID id_cache;
extern ID id_n_processors;
extern VALUE cContext;
extern VALUE eError;
extern VALUE cModel;
extern const rb_data_type_t ruby_whisper_params_type;
extern VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self);
extern VALUE rb_whisper_model_s_new(VALUE context);
extern VALUE rb_whisper_segment_s_new(VALUE context, int index);
extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context);
ID transcribe_option_names[1];
extern VALUE rb_whisper_model_initialize(VALUE context);
extern VALUE rb_whisper_segment_initialize(VALUE context, int index);
extern void register_callbacks(ruby_whisper_params *rwp, VALUE *context);
static void
ruby_whisper_free(ruby_whisper *rw)
@ -43,74 +37,19 @@ rb_whisper_mark(ruby_whisper *rw)
}
void
rb_whisper_free(void *p)
rb_whisper_free(ruby_whisper *rw)
{
ruby_whisper *rw = (ruby_whisper *)p;
ruby_whisper_free(rw);
free(rw);
}
static size_t
ruby_whisper_memsize(const void *p)
{
const ruby_whisper *rw = (const ruby_whisper *)p;
size_t size = sizeof(rw);
if (!rw) {
return 0;
}
if (rw->context) {
size += sizeof(rw->context);
}
return size;
}
const rb_data_type_t ruby_whisper_type = {
"ruby_whisper",
{0, rb_whisper_free, ruby_whisper_memsize,},
0, 0,
0
};
static VALUE
ruby_whisper_allocate(VALUE klass)
{
ruby_whisper *rw;
VALUE obj = TypedData_Make_Struct(klass, ruby_whisper, &ruby_whisper_type, rw);
rw = ALLOC(ruby_whisper);
rw->context = NULL;
return obj;
}
VALUE
ruby_whisper_normalize_model_path(VALUE model_path)
{
VALUE pre_converted_models = rb_funcall(cModel, id_pre_converted_models, 0);
VALUE pre_converted_model = rb_hash_aref(pre_converted_models, model_path);
if (!NIL_P(pre_converted_model)) {
model_path = pre_converted_model;
#ifdef RUBY_WHISPER_USE_COREML
VALUE coreml_converted_models = rb_funcall(cModel, id_coreml_compiled_models, 0);
VALUE coreml_converted_model = rb_hash_aref(coreml_converted_models, pre_converted_model);
if (!NIL_P(coreml_converted_model)) {
rb_funcall(coreml_converted_model, id_cache, 0);
}
#endif
}
else if (TYPE(model_path) == T_STRING) {
const char * model_path_str = StringValueCStr(model_path);
if (strncmp("http://", model_path_str, 7) == 0 || strncmp("https://", model_path_str, 8) == 0) {
VALUE uri_class = rb_const_get(cModel, id_URI);
model_path = rb_class_new_instance(1, &model_path, uri_class);
}
}
else if (rb_obj_is_kind_of(model_path, rb_path2class("URI::HTTP"))) {
VALUE uri_class = rb_const_get(cModel, id_URI);
model_path = rb_class_new_instance(1, &model_path, uri_class);
}
if (rb_respond_to(model_path, id_to_path)) {
model_path = rb_funcall(model_path, id_to_path, 0);
}
return model_path;
return Data_Wrap_Struct(klass, rb_whisper_mark, rb_whisper_free, rw);
}
/*
@ -127,9 +66,27 @@ ruby_whisper_initialize(int argc, VALUE *argv, VALUE self)
// TODO: we can support init from buffer here too maybe another ruby object to expose
rb_scan_args(argc, argv, "01", &whisper_model_file_path);
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
Data_Get_Struct(self, ruby_whisper, rw);
whisper_model_file_path = ruby_whisper_normalize_model_path(whisper_model_file_path);
VALUE pre_converted_models = rb_funcall(cModel, id_pre_converted_models, 0);
VALUE pre_converted_model = rb_hash_aref(pre_converted_models, whisper_model_file_path);
if (!NIL_P(pre_converted_model)) {
whisper_model_file_path = pre_converted_model;
}
if (TYPE(whisper_model_file_path) == T_STRING) {
const char * whisper_model_file_path_str = StringValueCStr(whisper_model_file_path);
if (strncmp("http://", whisper_model_file_path_str, 7) == 0 || strncmp("https://", whisper_model_file_path_str, 8) == 0) {
VALUE uri_class = rb_const_get(cModel, id_URI);
whisper_model_file_path = rb_class_new_instance(1, &whisper_model_file_path, uri_class);
}
}
if (rb_obj_is_kind_of(whisper_model_file_path, rb_path2class("URI::HTTP"))) {
VALUE uri_class = rb_const_get(cModel, id_URI);
whisper_model_file_path = rb_class_new_instance(1, &whisper_model_file_path, uri_class);
}
if (rb_respond_to(whisper_model_file_path, id_to_path)) {
whisper_model_file_path = rb_funcall(whisper_model_file_path, id_to_path, 0);
}
if (!rb_respond_to(whisper_model_file_path, id_to_s)) {
rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context");
}
@ -147,7 +104,7 @@ ruby_whisper_initialize(int argc, VALUE *argv, VALUE self)
VALUE ruby_whisper_model_n_vocab(VALUE self)
{
ruby_whisper *rw;
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_model_n_vocab(rw->context));
}
@ -158,7 +115,7 @@ VALUE ruby_whisper_model_n_vocab(VALUE self)
VALUE ruby_whisper_model_n_audio_ctx(VALUE self)
{
ruby_whisper *rw;
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_model_n_audio_ctx(rw->context));
}
@ -169,7 +126,7 @@ VALUE ruby_whisper_model_n_audio_ctx(VALUE self)
VALUE ruby_whisper_model_n_audio_state(VALUE self)
{
ruby_whisper *rw;
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_model_n_audio_state(rw->context));
}
@ -180,7 +137,7 @@ VALUE ruby_whisper_model_n_audio_state(VALUE self)
VALUE ruby_whisper_model_n_audio_head(VALUE self)
{
ruby_whisper *rw;
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_model_n_audio_head(rw->context));
}
@ -191,7 +148,7 @@ VALUE ruby_whisper_model_n_audio_head(VALUE self)
VALUE ruby_whisper_model_n_audio_layer(VALUE self)
{
ruby_whisper *rw;
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_model_n_audio_layer(rw->context));
}
@ -202,7 +159,7 @@ VALUE ruby_whisper_model_n_audio_layer(VALUE self)
VALUE ruby_whisper_model_n_text_ctx(VALUE self)
{
ruby_whisper *rw;
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_model_n_text_ctx(rw->context));
}
@ -213,7 +170,7 @@ VALUE ruby_whisper_model_n_text_ctx(VALUE self)
VALUE ruby_whisper_model_n_text_state(VALUE self)
{
ruby_whisper *rw;
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_model_n_text_state(rw->context));
}
@ -224,7 +181,7 @@ VALUE ruby_whisper_model_n_text_state(VALUE self)
VALUE ruby_whisper_model_n_text_head(VALUE self)
{
ruby_whisper *rw;
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_model_n_text_head(rw->context));
}
@ -235,7 +192,7 @@ VALUE ruby_whisper_model_n_text_head(VALUE self)
VALUE ruby_whisper_model_n_text_layer(VALUE self)
{
ruby_whisper *rw;
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_model_n_text_layer(rw->context));
}
@ -246,7 +203,7 @@ VALUE ruby_whisper_model_n_text_layer(VALUE self)
VALUE ruby_whisper_model_n_mels(VALUE self)
{
ruby_whisper *rw;
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_model_n_mels(rw->context));
}
@ -257,7 +214,7 @@ VALUE ruby_whisper_model_n_mels(VALUE self)
VALUE ruby_whisper_model_ftype(VALUE self)
{
ruby_whisper *rw;
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_model_ftype(rw->context));
}
@ -268,7 +225,7 @@ VALUE ruby_whisper_model_ftype(VALUE self)
VALUE ruby_whisper_model_type(VALUE self)
{
ruby_whisper *rw;
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
Data_Get_Struct(self, ruby_whisper, rw);
return rb_str_new2(whisper_model_type_readable(rw->context));
}
@ -291,9 +248,9 @@ VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self)
ruby_whisper *rw;
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
Data_Get_Struct(self, ruby_whisper, rw);
VALUE params = argv[0];
TypedData_Get_Struct(params, ruby_whisper_params, &ruby_whisper_params_type, rwp);
Data_Get_Struct(params, ruby_whisper_params, rwp);
VALUE samples = argv[1];
int n_samples;
rb_memory_view_t view;
@ -308,20 +265,13 @@ VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self)
// Should check when samples.respond_to?(:length)?
} else {
if (TYPE(samples) == T_ARRAY) {
if (RARRAY_LEN(samples) > INT_MAX) {
rb_raise(rb_eArgError, "samples are too long");
}
n_samples = (int)RARRAY_LEN(samples);
n_samples = RARRAY_LEN(samples);
} else if (memory_view_available_p) {
if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) {
view.obj = Qnil;
rb_raise(rb_eArgError, "unable to get a memory view");
}
ssize_t n_samples_size = view.byte_size / view.item_size;
if (n_samples_size > INT_MAX) {
rb_raise(rb_eArgError, "samples are too long");
}
n_samples = (int)n_samples_size;
n_samples = view.byte_size / view.item_size;
} else if (rb_respond_to(samples, id_length)) {
n_samples = NUM2INT(rb_funcall(samples, id_length, 0));
} else {
@ -346,7 +296,7 @@ VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self)
}
}
}
prepare_transcription(rwp, &self);
register_callbacks(rwp, &self);
const int result = whisper_full(rw->context, rwp->params, c_samples, n_samples);
if (0 == result) {
return self;
@ -377,9 +327,9 @@ ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self)
ruby_whisper *rw;
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
Data_Get_Struct(self, ruby_whisper, rw);
VALUE params = argv[0];
TypedData_Get_Struct(params, ruby_whisper_params, &ruby_whisper_params_type, rwp);
Data_Get_Struct(params, ruby_whisper_params, rwp);
VALUE samples = argv[1];
int n_samples;
int n_processors;
@ -409,17 +359,10 @@ ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self)
view.obj = Qnil;
rb_raise(rb_eArgError, "unable to get a memory view");
}
ssize_t n_samples_size = view.byte_size / view.item_size;
if (n_samples_size > INT_MAX) {
rb_raise(rb_eArgError, "samples are too long");
}
n_samples = (int)n_samples_size;
n_samples = view.byte_size / view.item_size;
} else {
if (TYPE(samples) == T_ARRAY) {
if (RARRAY_LEN(samples) > INT_MAX) {
rb_raise(rb_eArgError, "samples are too long");
}
n_samples = (int)RARRAY_LEN(samples);
n_samples = RARRAY_LEN(samples);
} else if (rb_respond_to(samples, id_length)) {
n_samples = NUM2INT(rb_funcall(samples, id_length, 0));
} else {
@ -444,7 +387,7 @@ ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self)
}
}
}
prepare_transcription(rwp, &self);
register_callbacks(rwp, &self);
const int result = whisper_full_parallel(rw->context, rwp->params, c_samples, n_samples, n_processors);
if (0 == result) {
return self;
@ -463,7 +406,7 @@ static VALUE
ruby_whisper_full_n_segments(VALUE self)
{
ruby_whisper *rw;
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_full_n_segments(rw->context));
}
@ -477,7 +420,7 @@ static VALUE
ruby_whisper_full_lang_id(VALUE self)
{
ruby_whisper *rw;
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_full_lang_id(rw->context));
}
@ -502,10 +445,10 @@ static VALUE
ruby_whisper_full_get_segment_t0(VALUE self, VALUE i_segment)
{
ruby_whisper *rw;
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
Data_Get_Struct(self, ruby_whisper, rw);
const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
const int64_t t0 = whisper_full_get_segment_t0(rw->context, c_i_segment);
return LONG2NUM(t0);
return INT2NUM(t0);
}
/*
@ -520,10 +463,10 @@ static VALUE
ruby_whisper_full_get_segment_t1(VALUE self, VALUE i_segment)
{
ruby_whisper *rw;
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
Data_Get_Struct(self, ruby_whisper, rw);
const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
const int64_t t1 = whisper_full_get_segment_t1(rw->context, c_i_segment);
return LONG2NUM(t1);
return INT2NUM(t1);
}
/*
@ -538,7 +481,7 @@ static VALUE
ruby_whisper_full_get_segment_speaker_turn_next(VALUE self, VALUE i_segment)
{
ruby_whisper *rw;
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
Data_Get_Struct(self, ruby_whisper, rw);
const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
const bool speaker_turn_next = whisper_full_get_segment_speaker_turn_next(rw->context, c_i_segment);
return speaker_turn_next ? Qtrue : Qfalse;
@ -556,7 +499,7 @@ static VALUE
ruby_whisper_full_get_segment_text(VALUE self, VALUE i_segment)
{
ruby_whisper *rw;
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
Data_Get_Struct(self, ruby_whisper, rw);
const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
const char * text = whisper_full_get_segment_text(rw->context, c_i_segment);
return rb_str_new2(text);
@ -570,7 +513,7 @@ static VALUE
ruby_whisper_full_get_segment_no_speech_prob(VALUE self, VALUE i_segment)
{
ruby_whisper *rw;
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
Data_Get_Struct(self, ruby_whisper, rw);
const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
const float no_speech_prob = whisper_full_get_segment_no_speech_prob(rw->context, c_i_segment);
return DBL2NUM(no_speech_prob);
@ -581,7 +524,7 @@ ruby_whisper_full_get_segment_no_speech_prob(VALUE self, VALUE i_segment)
static VALUE
ruby_whisper_full_get_segment(VALUE self, VALUE i_segment)
{
return rb_whisper_segment_s_new(self, NUM2INT(i_segment));
return rb_whisper_segment_initialize(self, NUM2INT(i_segment));
}
/*
@ -611,11 +554,11 @@ ruby_whisper_each_segment(VALUE self)
}
ruby_whisper *rw;
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
Data_Get_Struct(self, ruby_whisper, rw);
const int n_segments = whisper_full_n_segments(rw->context);
for (int i = 0; i < n_segments; ++i) {
rb_yield(rb_whisper_segment_s_new(self, i));
rb_yield(rb_whisper_segment_initialize(self, i));
}
return self;
@ -628,7 +571,7 @@ ruby_whisper_each_segment(VALUE self)
static VALUE
ruby_whisper_get_model(VALUE self)
{
return rb_whisper_model_s_new(self);
return rb_whisper_model_initialize(self);
}
void
@ -636,8 +579,6 @@ init_ruby_whisper_context(VALUE *mWhisper)
{
cContext = rb_define_class_under(*mWhisper, "Context", rb_cObject);
transcribe_option_names[0] = id_n_processors;
rb_define_alloc_func(cContext, ruby_whisper_allocate);
rb_define_method(cContext, "initialize", ruby_whisper_initialize, -1);
@ -664,7 +605,7 @@ init_ruby_whisper_context(VALUE *mWhisper)
rb_define_method(cContext, "full", ruby_whisper_full, -1);
rb_define_method(cContext, "full_parallel", ruby_whisper_full_parallel, -1);
// High level
// High leve
rb_define_method(cContext, "full_get_segment", ruby_whisper_full_get_segment, 1);
rb_define_method(cContext, "each_segment", ruby_whisper_each_segment, 0);

View File

@ -1,44 +1,22 @@
#include <ruby.h>
#include "ruby_whisper.h"
extern const rb_data_type_t ruby_whisper_type;
extern VALUE cModel;
static void rb_whisper_model_mark(void *p) {
ruby_whisper_model *rwm = (ruby_whisper_model *)p;
if (rwm->context) {
rb_gc_mark(rwm->context);
}
static void rb_whisper_model_mark(ruby_whisper_model *rwm) {
rb_gc_mark(rwm->context);
}
static size_t
ruby_whisper_model_memsize(const void *p)
{
const ruby_whisper_model *rwm = (const ruby_whisper_model *)p;
size_t size = sizeof(rwm);
if (!rwm) {
return 0;
}
return size;
}
static const rb_data_type_t rb_whisper_model_type = {
"ruby_whisper_model",
{rb_whisper_model_mark, RUBY_DEFAULT_FREE, ruby_whisper_model_memsize,},
0, 0,
0
};
static VALUE ruby_whisper_model_allocate(VALUE klass) {
ruby_whisper_model *rwm;
return TypedData_Make_Struct(klass, ruby_whisper_model, &rb_whisper_model_type, rwm);
rwm = ALLOC(ruby_whisper_model);
return Data_Wrap_Struct(klass, rb_whisper_model_mark, RUBY_DEFAULT_FREE, rwm);
}
VALUE rb_whisper_model_s_new(VALUE context) {
VALUE rb_whisper_model_initialize(VALUE context) {
ruby_whisper_model *rwm;
const VALUE model = ruby_whisper_model_allocate(cModel);
TypedData_Get_Struct(model, ruby_whisper_model, &rb_whisper_model_type, rwm);
Data_Get_Struct(model, ruby_whisper_model, rwm);
rwm->context = context;
return model;
};
@ -51,9 +29,9 @@ static VALUE
ruby_whisper_model_n_vocab(VALUE self)
{
ruby_whisper_model *rwm;
TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
Data_Get_Struct(self, ruby_whisper_model, rwm);
ruby_whisper *rw;
TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
Data_Get_Struct(rwm->context, ruby_whisper, rw);
return INT2NUM(whisper_model_n_vocab(rw->context));
}
@ -65,9 +43,9 @@ static VALUE
ruby_whisper_model_n_audio_ctx(VALUE self)
{
ruby_whisper_model *rwm;
TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
Data_Get_Struct(self, ruby_whisper_model, rwm);
ruby_whisper *rw;
TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
Data_Get_Struct(rwm->context, ruby_whisper, rw);
return INT2NUM(whisper_model_n_audio_ctx(rw->context));
}
@ -79,9 +57,9 @@ static VALUE
ruby_whisper_model_n_audio_state(VALUE self)
{
ruby_whisper_model *rwm;
TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
Data_Get_Struct(self, ruby_whisper_model, rwm);
ruby_whisper *rw;
TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
Data_Get_Struct(rwm->context, ruby_whisper, rw);
return INT2NUM(whisper_model_n_audio_state(rw->context));
}
@ -93,9 +71,9 @@ static VALUE
ruby_whisper_model_n_audio_head(VALUE self)
{
ruby_whisper_model *rwm;
TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
Data_Get_Struct(self, ruby_whisper_model, rwm);
ruby_whisper *rw;
TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
Data_Get_Struct(rwm->context, ruby_whisper, rw);
return INT2NUM(whisper_model_n_audio_head(rw->context));
}
@ -107,9 +85,9 @@ static VALUE
ruby_whisper_model_n_audio_layer(VALUE self)
{
ruby_whisper_model *rwm;
TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
Data_Get_Struct(self, ruby_whisper_model, rwm);
ruby_whisper *rw;
TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
Data_Get_Struct(rwm->context, ruby_whisper, rw);
return INT2NUM(whisper_model_n_audio_layer(rw->context));
}
@ -121,9 +99,9 @@ static VALUE
ruby_whisper_model_n_text_ctx(VALUE self)
{
ruby_whisper_model *rwm;
TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
Data_Get_Struct(self, ruby_whisper_model, rwm);
ruby_whisper *rw;
TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
Data_Get_Struct(rwm->context, ruby_whisper, rw);
return INT2NUM(whisper_model_n_text_ctx(rw->context));
}
@ -135,9 +113,9 @@ static VALUE
ruby_whisper_model_n_text_state(VALUE self)
{
ruby_whisper_model *rwm;
TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
Data_Get_Struct(self, ruby_whisper_model, rwm);
ruby_whisper *rw;
TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
Data_Get_Struct(rwm->context, ruby_whisper, rw);
return INT2NUM(whisper_model_n_text_state(rw->context));
}
@ -149,9 +127,9 @@ static VALUE
ruby_whisper_model_n_text_head(VALUE self)
{
ruby_whisper_model *rwm;
TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
Data_Get_Struct(self, ruby_whisper_model, rwm);
ruby_whisper *rw;
TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
Data_Get_Struct(rwm->context, ruby_whisper, rw);
return INT2NUM(whisper_model_n_text_head(rw->context));
}
@ -163,9 +141,9 @@ static VALUE
ruby_whisper_model_n_text_layer(VALUE self)
{
ruby_whisper_model *rwm;
TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
Data_Get_Struct(self, ruby_whisper_model, rwm);
ruby_whisper *rw;
TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
Data_Get_Struct(rwm->context, ruby_whisper, rw);
return INT2NUM(whisper_model_n_text_layer(rw->context));
}
@ -177,9 +155,9 @@ static VALUE
ruby_whisper_model_n_mels(VALUE self)
{
ruby_whisper_model *rwm;
TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
Data_Get_Struct(self, ruby_whisper_model, rwm);
ruby_whisper *rw;
TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
Data_Get_Struct(rwm->context, ruby_whisper, rw);
return INT2NUM(whisper_model_n_mels(rw->context));
}
@ -191,9 +169,9 @@ static VALUE
ruby_whisper_model_ftype(VALUE self)
{
ruby_whisper_model *rwm;
TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
Data_Get_Struct(self, ruby_whisper_model, rwm);
ruby_whisper *rw;
TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
Data_Get_Struct(rwm->context, ruby_whisper, rw);
return INT2NUM(whisper_model_ftype(rw->context));
}
@ -205,9 +183,9 @@ static VALUE
ruby_whisper_model_type(VALUE self)
{
ruby_whisper_model *rwm;
TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
Data_Get_Struct(self, ruby_whisper_model, rwm);
ruby_whisper *rw;
TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
Data_Get_Struct(rwm->context, ruby_whisper, rw);
return rb_str_new2(whisper_model_type_readable(rw->context));
}

View File

@ -3,7 +3,7 @@
#define BOOL_PARAMS_SETTER(self, prop, value) \
ruby_whisper_params *rwp; \
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); \
Data_Get_Struct(self, ruby_whisper_params, rwp); \
if (value == Qfalse || value == Qnil) { \
rwp->params.prop = false; \
} else { \
@ -13,7 +13,7 @@
#define BOOL_PARAMS_GETTER(self, prop) \
ruby_whisper_params *rwp; \
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); \
Data_Get_Struct(self, ruby_whisper_params, rwp); \
if (rwp->params.prop) { \
return Qtrue; \
} else { \
@ -26,16 +26,13 @@
rb_define_method(cParams, #param_name, ruby_whisper_params_get_ ## param_name, 0); \
rb_define_method(cParams, #param_name "=", ruby_whisper_params_set_ ## param_name, 1);
#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 36
#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 32
extern VALUE cParams;
extern VALUE cVADParams;
extern ID id_call;
extern VALUE ruby_whisper_normalize_model_path(VALUE model_path);
extern VALUE rb_whisper_segment_s_new(VALUE context, int index);
extern const rb_data_type_t ruby_whisper_vad_params_type;
extern VALUE rb_whisper_segment_initialize(VALUE context, int index);
static ID param_names[RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT];
static ID id_language;
@ -49,7 +46,6 @@ static ID id_print_timestamps;
static ID id_suppress_blank;
static ID id_suppress_nst;
static ID id_token_timestamps;
static ID id_max_len;
static ID id_split_on_word;
static ID id_initial_prompt;
static ID id_diarize;
@ -71,15 +67,10 @@ static ID id_encoder_begin_callback;
static ID id_encoder_begin_callback_user_data;
static ID id_abort_callback;
static ID id_abort_callback_user_data;
static ID id_vad;
static ID id_vad_model_path;
static ID id_vad_params;
static void
rb_whisper_callbcack_container_mark(ruby_whisper_callback_container *rwc)
{
if (rwc == NULL) return;
rb_gc_mark(rwc->user_data);
rb_gc_mark(rwc->callback);
rb_gc_mark(rwc->callbacks);
@ -111,7 +102,7 @@ static void new_segment_callback(struct whisper_context *ctx, struct whisper_sta
const int n_segments = whisper_full_n_segments_from_state(state);
for (int i = n_new; i > 0; i--) {
int i_segment = n_segments - i;
VALUE segment = rb_whisper_segment_s_new(*container->context, i_segment);
VALUE segment = rb_whisper_segment_initialize(*container->context, i_segment);
for (int j = 0; j < callbacks_len; j++) {
VALUE cb = rb_ary_entry(container->callbacks, j);
rb_funcall(cb, id_call, 1, segment);
@ -186,7 +177,7 @@ static bool abort_callback(void * user_data) {
return false;
}
static void register_callbacks(ruby_whisper_params * rwp, VALUE * context) {
void register_callbacks(ruby_whisper_params * rwp, VALUE * context) {
if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) {
rwp->new_segment_callback_container->context = context;
rwp->params.new_segment_callback = new_segment_callback;
@ -212,29 +203,13 @@ static void register_callbacks(ruby_whisper_params * rwp, VALUE * context) {
}
}
static void set_vad_params(ruby_whisper_params *rwp)
{
ruby_whisper_vad_params * rwvp;
TypedData_Get_Struct(rwp->vad_params, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
rwp->params.vad_params = rwvp->params;
}
void
prepare_transcription(ruby_whisper_params *rwp, VALUE *context)
rb_whisper_params_mark(ruby_whisper_params *rwp)
{
register_callbacks(rwp, context);
set_vad_params(rwp);
}
void
rb_whisper_params_mark(void *p)
{
ruby_whisper_params *rwp = (ruby_whisper_params *)p;
rb_whisper_callbcack_container_mark(rwp->new_segment_callback_container);
rb_whisper_callbcack_container_mark(rwp->progress_callback_container);
rb_whisper_callbcack_container_mark(rwp->encoder_begin_callback_container);
rb_whisper_callbcack_container_mark(rwp->abort_callback_container);
rb_gc_mark(rwp->vad_params);
}
void
@ -243,46 +218,25 @@ ruby_whisper_params_free(ruby_whisper_params *rwp)
}
void
rb_whisper_params_free(void *p)
rb_whisper_params_free(ruby_whisper_params *rwp)
{
ruby_whisper_params *rwp = (ruby_whisper_params *)p;
// How to free user_data and callback only when not referred to by others?
ruby_whisper_params_free(rwp);
free(rwp);
}
static size_t
ruby_whisper_params_memsize(const void *p)
{
const ruby_whisper_params *rwp = (const ruby_whisper_params *)p;
return sizeof(ruby_whisper_params) + sizeof(rwp->params) + sizeof(rwp->vad_params);
}
const rb_data_type_t ruby_whisper_params_type = {
"ruby_whisper_params",
{
rb_whisper_params_mark,
rb_whisper_params_free,
ruby_whisper_params_memsize,
},
0, 0,
0
};
static VALUE
ruby_whisper_params_allocate(VALUE klass)
{
ruby_whisper_params *rwp;
VALUE obj = TypedData_Make_Struct(klass, ruby_whisper_params, &ruby_whisper_params_type, rwp);
rwp = ALLOC(ruby_whisper_params);
rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
rwp->diarize = false;
rwp->vad_params = TypedData_Wrap_Struct(cVADParams, &ruby_whisper_vad_params_type, (void *)&rwp->params.vad_params);
rwp->new_segment_callback_container = rb_whisper_callback_container_allocate();
rwp->progress_callback_container = rb_whisper_callback_container_allocate();
rwp->encoder_begin_callback_container = rb_whisper_callback_container_allocate();
rwp->abort_callback_container = rb_whisper_callback_container_allocate();
return obj;
return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp);
}
/*
@ -295,7 +249,7 @@ static VALUE
ruby_whisper_params_set_language(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
Data_Get_Struct(self, ruby_whisper_params, rwp);
if (value == Qfalse || value == Qnil) {
rwp->params.language = "auto";
} else {
@ -311,7 +265,7 @@ static VALUE
ruby_whisper_params_get_language(VALUE self)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
Data_Get_Struct(self, ruby_whisper_params, rwp);
if (rwp->params.language) {
return rb_str_new2(rwp->params.language);
} else {
@ -515,33 +469,6 @@ ruby_whisper_params_set_token_timestamps(VALUE self, VALUE value)
{
BOOL_PARAMS_SETTER(self, token_timestamps, value)
}
/*
* max segment length in characters.
*
* call-seq:
* max_len -> Integer
*/
static VALUE
ruby_whisper_params_get_max_len(VALUE self)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
return INT2NUM(rwp->params.max_len);
}
/*
* call-seq:
* max_len = length -> length
*/
static VALUE
ruby_whisper_params_set_max_len(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
rwp->params.max_len = NUM2INT(value);
return value;
}
/*
* If true, split on word rather than on token (when used with max_len).
*
@ -575,7 +502,7 @@ static VALUE
ruby_whisper_params_get_initial_prompt(VALUE self)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
Data_Get_Struct(self, ruby_whisper_params, rwp);
return rwp->params.initial_prompt == NULL ? Qnil : rb_str_new2(rwp->params.initial_prompt);
}
/*
@ -586,7 +513,7 @@ static VALUE
ruby_whisper_params_set_initial_prompt(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.initial_prompt = StringValueCStr(value);
return value;
}
@ -600,7 +527,7 @@ static VALUE
ruby_whisper_params_get_diarize(VALUE self)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
Data_Get_Struct(self, ruby_whisper_params, rwp);
if (rwp->diarize) {
return Qtrue;
} else {
@ -615,7 +542,7 @@ static VALUE
ruby_whisper_params_set_diarize(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
Data_Get_Struct(self, ruby_whisper_params, rwp);
if (value == Qfalse || value == Qnil) {
rwp->diarize = false;
} else {
@ -634,7 +561,7 @@ static VALUE
ruby_whisper_params_get_offset(VALUE self)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
Data_Get_Struct(self, ruby_whisper_params, rwp);
return INT2NUM(rwp->params.offset_ms);
}
/*
@ -645,7 +572,7 @@ static VALUE
ruby_whisper_params_set_offset(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.offset_ms = NUM2INT(value);
return value;
}
@ -659,7 +586,7 @@ static VALUE
ruby_whisper_params_get_duration(VALUE self)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
Data_Get_Struct(self, ruby_whisper_params, rwp);
return INT2NUM(rwp->params.duration_ms);
}
/*
@ -670,7 +597,7 @@ static VALUE
ruby_whisper_params_set_duration(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.duration_ms = NUM2INT(value);
return value;
}
@ -685,7 +612,7 @@ static VALUE
ruby_whisper_params_get_max_text_tokens(VALUE self)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
Data_Get_Struct(self, ruby_whisper_params, rwp);
return INT2NUM(rwp->params.n_max_text_ctx);
}
/*
@ -696,7 +623,7 @@ static VALUE
ruby_whisper_params_set_max_text_tokens(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.n_max_text_ctx = NUM2INT(value);
return value;
}
@ -708,7 +635,7 @@ static VALUE
ruby_whisper_params_get_temperature(VALUE self)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
Data_Get_Struct(self, ruby_whisper_params, rwp);
return DBL2NUM(rwp->params.temperature);
}
/*
@ -719,7 +646,7 @@ static VALUE
ruby_whisper_params_set_temperature(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.temperature = RFLOAT_VALUE(value);
return value;
}
@ -733,7 +660,7 @@ static VALUE
ruby_whisper_params_get_max_initial_ts(VALUE self)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
Data_Get_Struct(self, ruby_whisper_params, rwp);
return DBL2NUM(rwp->params.max_initial_ts);
}
/*
@ -744,7 +671,7 @@ static VALUE
ruby_whisper_params_set_max_initial_ts(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.max_initial_ts = RFLOAT_VALUE(value);
return value;
}
@ -756,7 +683,7 @@ static VALUE
ruby_whisper_params_get_length_penalty(VALUE self)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
Data_Get_Struct(self, ruby_whisper_params, rwp);
return DBL2NUM(rwp->params.length_penalty);
}
/*
@ -767,7 +694,7 @@ static VALUE
ruby_whisper_params_set_length_penalty(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.length_penalty = RFLOAT_VALUE(value);
return value;
}
@ -779,7 +706,7 @@ static VALUE
ruby_whisper_params_get_temperature_inc(VALUE self)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
Data_Get_Struct(self, ruby_whisper_params, rwp);
return DBL2NUM(rwp->params.temperature_inc);
}
/*
@ -790,7 +717,7 @@ static VALUE
ruby_whisper_params_set_temperature_inc(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.temperature_inc = RFLOAT_VALUE(value);
return value;
}
@ -804,7 +731,7 @@ static VALUE
ruby_whisper_params_get_entropy_thold(VALUE self)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
Data_Get_Struct(self, ruby_whisper_params, rwp);
return DBL2NUM(rwp->params.entropy_thold);
}
/*
@ -815,7 +742,7 @@ static VALUE
ruby_whisper_params_set_entropy_thold(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.entropy_thold = RFLOAT_VALUE(value);
return value;
}
@ -827,7 +754,7 @@ static VALUE
ruby_whisper_params_get_logprob_thold(VALUE self)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
Data_Get_Struct(self, ruby_whisper_params, rwp);
return DBL2NUM(rwp->params.logprob_thold);
}
/*
@ -838,7 +765,7 @@ static VALUE
ruby_whisper_params_set_logprob_thold(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.logprob_thold = RFLOAT_VALUE(value);
return value;
}
@ -850,7 +777,7 @@ static VALUE
ruby_whisper_params_get_no_speech_thold(VALUE self)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
Data_Get_Struct(self, ruby_whisper_params, rwp);
return DBL2NUM(rwp->params.no_speech_thold);
}
/*
@ -861,7 +788,7 @@ static VALUE
ruby_whisper_params_set_no_speech_thold(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.no_speech_thold = RFLOAT_VALUE(value);
return value;
}
@ -869,7 +796,7 @@ static VALUE
ruby_whisper_params_get_new_segment_callback(VALUE self)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
Data_Get_Struct(self, ruby_whisper_params, rwp);
return rwp->new_segment_callback_container->callback;
}
/*
@ -886,7 +813,7 @@ static VALUE
ruby_whisper_params_set_new_segment_callback(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->new_segment_callback_container->callback = value;
return value;
}
@ -894,7 +821,7 @@ static VALUE
ruby_whisper_params_get_new_segment_callback_user_data(VALUE self)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
Data_Get_Struct(self, ruby_whisper_params, rwp);
return rwp->new_segment_callback_container->user_data;
}
/*
@ -907,7 +834,7 @@ static VALUE
ruby_whisper_params_set_new_segment_callback_user_data(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->new_segment_callback_container->user_data = value;
return value;
}
@ -915,7 +842,7 @@ static VALUE
ruby_whisper_params_get_progress_callback(VALUE self)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
Data_Get_Struct(self, ruby_whisper_params, rwp);
return rwp->progress_callback_container->callback;
}
/*
@ -934,7 +861,7 @@ static VALUE
ruby_whisper_params_set_progress_callback(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->progress_callback_container->callback = value;
return value;
}
@ -942,7 +869,7 @@ static VALUE
ruby_whisper_params_get_progress_callback_user_data(VALUE self)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
Data_Get_Struct(self, ruby_whisper_params, rwp);
return rwp->progress_callback_container->user_data;
}
/*
@ -955,7 +882,7 @@ static VALUE
ruby_whisper_params_set_progress_callback_user_data(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->progress_callback_container->user_data = value;
return value;
}
@ -964,7 +891,7 @@ static VALUE
ruby_whisper_params_get_encoder_begin_callback(VALUE self)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
Data_Get_Struct(self, ruby_whisper_params, rwp);
return rwp->encoder_begin_callback_container->callback;
}
@ -982,7 +909,7 @@ static VALUE
ruby_whisper_params_set_encoder_begin_callback(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->encoder_begin_callback_container->callback = value;
return value;
}
@ -991,7 +918,7 @@ static VALUE
ruby_whisper_params_get_encoder_begin_callback_user_data(VALUE self)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
Data_Get_Struct(self, ruby_whisper_params, rwp);
return rwp->encoder_begin_callback_container->user_data;
}
@ -1005,7 +932,7 @@ static VALUE
ruby_whisper_params_set_encoder_begin_callback_user_data(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->encoder_begin_callback_container->user_data = value;
return value;
}
@ -1014,7 +941,7 @@ static VALUE
ruby_whisper_params_get_abort_callback(VALUE self)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
Data_Get_Struct(self, ruby_whisper_params, rwp);
return rwp->abort_callback_container->callback;
}
/*
@ -1031,7 +958,7 @@ static VALUE
ruby_whisper_params_set_abort_callback(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->abort_callback_container->callback = value;
return value;
}
@ -1039,7 +966,7 @@ static VALUE
ruby_whisper_params_get_abort_callback_user_data(VALUE self)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
Data_Get_Struct(self, ruby_whisper_params, rwp);
return rwp->abort_callback_container->user_data;
}
/*
@ -1052,74 +979,11 @@ static VALUE
ruby_whisper_params_set_abort_callback_user_data(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->abort_callback_container->user_data = value;
return value;
}
/*
* call-seq:
* vad = use_vad -> use_vad
*/
static VALUE
ruby_whisper_params_get_vad(VALUE self)
{
BOOL_PARAMS_GETTER(self, vad)
}
static VALUE
ruby_whisper_params_set_vad(VALUE self, VALUE value)
{
BOOL_PARAMS_SETTER(self, vad, value)
}
/*
* call-seq:
* vad_model_path = model_path -> model_path
*/
static VALUE
ruby_whisper_params_set_vad_model_path(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
if (NIL_P(value)) {
rwp->params.vad_model_path = NULL;
return value;
}
VALUE path = ruby_whisper_normalize_model_path(value);
rwp->params.vad_model_path = StringValueCStr(path);
return value;
}
static VALUE
ruby_whisper_params_get_vad_model_path(VALUE self)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
return rwp->params.vad_model_path == NULL ? Qnil : rb_str_new2(rwp->params.vad_model_path);
}
/*
* call-seq:
* vad_params = params -> params
*/
static VALUE
ruby_whisper_params_set_vad_params(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
rwp->vad_params = value;
return value;
}
static VALUE
ruby_whisper_params_get_vad_params(VALUE self)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
return rwp->vad_params;
}
#define SET_PARAM_IF_SAME(param_name) \
if (id == id_ ## param_name) { \
ruby_whisper_params_set_ ## param_name(self, value); \
@ -1129,6 +993,7 @@ ruby_whisper_params_get_vad_params(VALUE self)
static VALUE
ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self)
{
VALUE kw_hash;
VALUE values[RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT] = {Qundef};
VALUE value;
@ -1142,7 +1007,7 @@ ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self)
}
rb_get_kwargs(kw_hash, param_names, 0, RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT, values);
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
Data_Get_Struct(self, ruby_whisper_params, rwp);
for (i = 0; i < RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT; i++) {
id = param_names[i];
@ -1165,7 +1030,6 @@ ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self)
SET_PARAM_IF_SAME(suppress_blank)
SET_PARAM_IF_SAME(suppress_nst)
SET_PARAM_IF_SAME(token_timestamps)
SET_PARAM_IF_SAME(max_len)
SET_PARAM_IF_SAME(split_on_word)
SET_PARAM_IF_SAME(initial_prompt)
SET_PARAM_IF_SAME(offset)
@ -1186,9 +1050,6 @@ ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self)
SET_PARAM_IF_SAME(encoder_begin_callback_user_data)
SET_PARAM_IF_SAME(abort_callback)
SET_PARAM_IF_SAME(abort_callback_user_data)
SET_PARAM_IF_SAME(vad)
SET_PARAM_IF_SAME(vad_model_path)
SET_PARAM_IF_SAME(vad_params)
}
}
@ -1210,10 +1071,10 @@ ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self)
static VALUE
ruby_whisper_params_on_new_segment(VALUE self)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
ruby_whisper_params *rws;
Data_Get_Struct(self, ruby_whisper_params, rws);
const VALUE blk = rb_block_proc();
rb_ary_push(rwp->new_segment_callback_container->callbacks, blk);
rb_ary_push(rws->new_segment_callback_container->callbacks, blk);
return Qnil;
}
@ -1230,10 +1091,10 @@ ruby_whisper_params_on_new_segment(VALUE self)
static VALUE
ruby_whisper_params_on_progress(VALUE self)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
ruby_whisper_params *rws;
Data_Get_Struct(self, ruby_whisper_params, rws);
const VALUE blk = rb_block_proc();
rb_ary_push(rwp->progress_callback_container->callbacks, blk);
rb_ary_push(rws->progress_callback_container->callbacks, blk);
return Qnil;
}
@ -1250,10 +1111,10 @@ ruby_whisper_params_on_progress(VALUE self)
static VALUE
ruby_whisper_params_on_encoder_begin(VALUE self)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
ruby_whisper_params *rws;
Data_Get_Struct(self, ruby_whisper_params, rws);
const VALUE blk = rb_block_proc();
rb_ary_push(rwp->encoder_begin_callback_container->callbacks, blk);
rb_ary_push(rws->encoder_begin_callback_container->callbacks, blk);
return Qnil;
}
@ -1274,10 +1135,10 @@ ruby_whisper_params_on_encoder_begin(VALUE self)
static VALUE
ruby_whisper_params_abort_on(VALUE self)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
ruby_whisper_params *rws;
Data_Get_Struct(self, ruby_whisper_params, rws);
const VALUE blk = rb_block_proc();
rb_ary_push(rwp->abort_callback_container->callbacks, blk);
rb_ary_push(rws->abort_callback_container->callbacks, blk);
return Qnil;
}
@ -1300,31 +1161,27 @@ init_ruby_whisper_params(VALUE *mWhisper)
DEFINE_PARAM(suppress_blank, 8)
DEFINE_PARAM(suppress_nst, 9)
DEFINE_PARAM(token_timestamps, 10)
DEFINE_PARAM(max_len, 11)
DEFINE_PARAM(split_on_word, 12)
DEFINE_PARAM(initial_prompt, 13)
DEFINE_PARAM(diarize, 14)
DEFINE_PARAM(offset, 15)
DEFINE_PARAM(duration, 16)
DEFINE_PARAM(max_text_tokens, 17)
DEFINE_PARAM(temperature, 18)
DEFINE_PARAM(max_initial_ts, 19)
DEFINE_PARAM(length_penalty, 20)
DEFINE_PARAM(temperature_inc, 21)
DEFINE_PARAM(entropy_thold, 22)
DEFINE_PARAM(logprob_thold, 23)
DEFINE_PARAM(no_speech_thold, 24)
DEFINE_PARAM(new_segment_callback, 25)
DEFINE_PARAM(new_segment_callback_user_data, 26)
DEFINE_PARAM(progress_callback, 27)
DEFINE_PARAM(progress_callback_user_data, 28)
DEFINE_PARAM(encoder_begin_callback, 29)
DEFINE_PARAM(encoder_begin_callback_user_data, 30)
DEFINE_PARAM(abort_callback, 31)
DEFINE_PARAM(abort_callback_user_data, 32)
DEFINE_PARAM(vad, 33)
DEFINE_PARAM(vad_model_path, 34)
DEFINE_PARAM(vad_params, 35)
DEFINE_PARAM(split_on_word, 11)
DEFINE_PARAM(initial_prompt, 12)
DEFINE_PARAM(diarize, 13)
DEFINE_PARAM(offset, 14)
DEFINE_PARAM(duration, 15)
DEFINE_PARAM(max_text_tokens, 16)
DEFINE_PARAM(temperature, 17)
DEFINE_PARAM(max_initial_ts, 18)
DEFINE_PARAM(length_penalty, 19)
DEFINE_PARAM(temperature_inc, 20)
DEFINE_PARAM(entropy_thold, 21)
DEFINE_PARAM(logprob_thold, 22)
DEFINE_PARAM(no_speech_thold, 23)
DEFINE_PARAM(new_segment_callback, 24)
DEFINE_PARAM(new_segment_callback_user_data, 25)
DEFINE_PARAM(progress_callback, 26)
DEFINE_PARAM(progress_callback_user_data, 27)
DEFINE_PARAM(encoder_begin_callback, 28)
DEFINE_PARAM(encoder_begin_callback_user_data, 29)
DEFINE_PARAM(abort_callback, 30)
DEFINE_PARAM(abort_callback_user_data, 31)
rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0);
rb_define_method(cParams, "on_progress", ruby_whisper_params_on_progress, 0);

View File

@ -1,57 +1,28 @@
#include <ruby.h>
#include "ruby_whisper.h"
#define N_KEY_NAMES 5
static VALUE sym_start_time;
static VALUE sym_end_time;
static VALUE sym_text;
static VALUE sym_no_speech_prob;
static VALUE sym_speaker_turn_next;
static VALUE key_names;
extern const rb_data_type_t ruby_whisper_type;
extern VALUE cSegment;
static void
rb_whisper_segment_mark(void *p)
rb_whisper_segment_mark(ruby_whisper_segment *rws)
{
ruby_whisper_segment *rws = (ruby_whisper_segment *)p;
rb_gc_mark(rws->context);
}
static size_t
ruby_whisper_segment_memsize(const void *p)
{
const ruby_whisper_segment *rws = (const ruby_whisper_segment *)p;
size_t size = sizeof(rws);
if (!rws) {
return 0;
}
return size;
}
static const rb_data_type_t ruby_whisper_segment_type = {
"ruby_whisper_segment",
{rb_whisper_segment_mark, RUBY_DEFAULT_FREE, ruby_whisper_segment_memsize,},
0, 0,
0
};
VALUE
ruby_whisper_segment_allocate(VALUE klass)
{
ruby_whisper_segment *rws;
return TypedData_Make_Struct(klass, ruby_whisper_segment, &ruby_whisper_segment_type, rws);
rws = ALLOC(ruby_whisper_segment);
return Data_Wrap_Struct(klass, rb_whisper_segment_mark, RUBY_DEFAULT_FREE, rws);
}
VALUE
rb_whisper_segment_s_new(VALUE context, int index)
rb_whisper_segment_initialize(VALUE context, int index)
{
ruby_whisper_segment *rws;
const VALUE segment = ruby_whisper_segment_allocate(cSegment);
TypedData_Get_Struct(segment, ruby_whisper_segment, &ruby_whisper_segment_type, rws);
Data_Get_Struct(segment, ruby_whisper_segment, rws);
rws->context = context;
rws->index = index;
return segment;
@ -67,12 +38,12 @@ static VALUE
ruby_whisper_segment_get_start_time(VALUE self)
{
ruby_whisper_segment *rws;
TypedData_Get_Struct(self, ruby_whisper_segment, &ruby_whisper_segment_type, rws);
Data_Get_Struct(self, ruby_whisper_segment, rws);
ruby_whisper *rw;
TypedData_Get_Struct(rws->context, ruby_whisper, &ruby_whisper_type, rw);
Data_Get_Struct(rws->context, ruby_whisper, rw);
const int64_t t0 = whisper_full_get_segment_t0(rw->context, rws->index);
// able to multiply 10 without overflow because to_timestamp() in whisper.cpp does it
return LONG2NUM(t0 * 10);
return INT2NUM(t0 * 10);
}
/*
@ -85,12 +56,12 @@ static VALUE
ruby_whisper_segment_get_end_time(VALUE self)
{
ruby_whisper_segment *rws;
TypedData_Get_Struct(self, ruby_whisper_segment, &ruby_whisper_segment_type, rws);
Data_Get_Struct(self, ruby_whisper_segment, rws);
ruby_whisper *rw;
TypedData_Get_Struct(rws->context, ruby_whisper, &ruby_whisper_type, rw);
Data_Get_Struct(rws->context, ruby_whisper, rw);
const int64_t t1 = whisper_full_get_segment_t1(rw->context, rws->index);
// able to multiply 10 without overflow because to_timestamp() in whisper.cpp does it
return LONG2NUM(t1 * 10);
return INT2NUM(t1 * 10);
}
/*
@ -103,9 +74,9 @@ static VALUE
ruby_whisper_segment_get_speaker_turn_next(VALUE self)
{
ruby_whisper_segment *rws;
TypedData_Get_Struct(self, ruby_whisper_segment, &ruby_whisper_segment_type, rws);
Data_Get_Struct(self, ruby_whisper_segment, rws);
ruby_whisper *rw;
TypedData_Get_Struct(rws->context, ruby_whisper, &ruby_whisper_type, rw);
Data_Get_Struct(rws->context, ruby_whisper, rw);
return whisper_full_get_segment_speaker_turn_next(rw->context, rws->index) ? Qtrue : Qfalse;
}
@ -117,9 +88,9 @@ static VALUE
ruby_whisper_segment_get_text(VALUE self)
{
ruby_whisper_segment *rws;
TypedData_Get_Struct(self, ruby_whisper_segment, &ruby_whisper_segment_type, rws);
Data_Get_Struct(self, ruby_whisper_segment, rws);
ruby_whisper *rw;
TypedData_Get_Struct(rws->context, ruby_whisper, &ruby_whisper_type, rw);
Data_Get_Struct(rws->context, ruby_whisper, rw);
const char * text = whisper_full_get_segment_text(rw->context, rws->index);
return rb_str_new2(text);
}
@ -132,89 +103,21 @@ static VALUE
ruby_whisper_segment_get_no_speech_prob(VALUE self)
{
ruby_whisper_segment *rws;
TypedData_Get_Struct(self, ruby_whisper_segment, &ruby_whisper_segment_type, rws);
Data_Get_Struct(self, ruby_whisper_segment, rws);
ruby_whisper *rw;
TypedData_Get_Struct(rws->context, ruby_whisper, &ruby_whisper_type, rw);
Data_Get_Struct(rws->context, ruby_whisper, rw);
return DBL2NUM(whisper_full_get_segment_no_speech_prob(rw->context, rws->index));
}
/*
* call-seq:
* deconstruct_keys(keys) -> hash
*
* Possible keys: :start_time, :end_time, :text, :no_speech_prob, :speaker_turn_next
*
* whisper.each_segment do |segment|
* segment => {start_time:, end_time:, text:, no_speech_prob:, speaker_turn_next:}
*
* puts "[#{start_time} --> #{end_time}] #{text} (no speech prob: #{no_speech_prob}#{speaker_turn_next ? ', speaker turns next' : ''})"
* end
*/
static VALUE
ruby_whisper_segment_deconstruct_keys(VALUE self, VALUE keys)
{
ruby_whisper_segment *rws;
TypedData_Get_Struct(self, ruby_whisper_segment, &ruby_whisper_segment_type, rws);
ruby_whisper *rw;
TypedData_Get_Struct(rws->context, ruby_whisper, &ruby_whisper_type, rw);
VALUE hash = rb_hash_new();
long n_keys;
if (NIL_P(keys)) {
keys = key_names;
n_keys = N_KEY_NAMES;
} else {
n_keys = RARRAY_LEN(keys);
if (n_keys > N_KEY_NAMES) {
return hash;
}
}
for (int i = 0; i < n_keys; i++) {
VALUE key = rb_ary_entry(keys, i);
if (key == sym_start_time) {
rb_hash_aset(hash, key, ruby_whisper_segment_get_start_time(self));
}
if (key == sym_end_time) {
rb_hash_aset(hash, key, ruby_whisper_segment_get_end_time(self));
}
if (key == sym_text) {
rb_hash_aset(hash, key, ruby_whisper_segment_get_text(self));
}
if (key == sym_no_speech_prob) {
rb_hash_aset(hash, key, ruby_whisper_segment_get_no_speech_prob(self));
}
if (key == sym_speaker_turn_next) {
rb_hash_aset(hash, key, ruby_whisper_segment_get_speaker_turn_next(self));
}
}
return hash;
}
void
init_ruby_whisper_segment(VALUE *mWhisper, VALUE *cContext)
{
cSegment = rb_define_class_under(*mWhisper, "Segment", rb_cObject);
sym_start_time = ID2SYM(rb_intern("start_time"));
sym_end_time = ID2SYM(rb_intern("end_time"));
sym_text = ID2SYM(rb_intern("text"));
sym_no_speech_prob = ID2SYM(rb_intern("no_speech_prob"));
sym_speaker_turn_next = ID2SYM(rb_intern("speaker_turn_next"));
key_names = rb_ary_new3(
N_KEY_NAMES,
sym_start_time,
sym_end_time,
sym_text,
sym_no_speech_prob,
sym_speaker_turn_next
);
rb_define_alloc_func(cSegment, ruby_whisper_segment_allocate);
rb_define_method(cSegment, "start_time", ruby_whisper_segment_get_start_time, 0);
rb_define_method(cSegment, "end_time", ruby_whisper_segment_get_end_time, 0);
rb_define_method(cSegment, "speaker_turn_next?", ruby_whisper_segment_get_speaker_turn_next, 0);
rb_define_method(cSegment, "speaker_next_turn?", ruby_whisper_segment_get_speaker_turn_next, 0);
rb_define_method(cSegment, "text", ruby_whisper_segment_get_text, 0);
rb_define_method(cSegment, "no_speech_prob", ruby_whisper_segment_get_no_speech_prob, 0);
rb_define_method(cSegment, "deconstruct_keys", ruby_whisper_segment_deconstruct_keys, 1);
}

View File

@ -8,15 +8,11 @@
extern "C" {
#endif
extern const rb_data_type_t ruby_whisper_type;
extern const rb_data_type_t ruby_whisper_params_type;
extern ID id_to_s;
extern ID id_call;
extern ID transcribe_option_names[1];
extern void
prepare_transcription(ruby_whisper_params * rwp, VALUE * self);
register_callbacks(ruby_whisper_params * rwp, VALUE * self);
/*
* transcribe a single file
@ -35,16 +31,11 @@ VALUE
ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
ruby_whisper *rw;
ruby_whisper_params *rwp;
VALUE wave_file_path, blk, params, kws;
VALUE opts[1];
VALUE wave_file_path, blk, params;
rb_scan_args_kw(RB_SCAN_ARGS_LAST_HASH_KEYWORDS, argc, argv, "2:&", &wave_file_path, &params, &kws, &blk);
rb_get_kwargs(kws, transcribe_option_names, 0, 1, opts);
int n_processors = opts[0] == Qundef ? 1 : NUM2INT(opts[0]);
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
TypedData_Get_Struct(params, ruby_whisper_params, &ruby_whisper_params_type, rwp);
rb_scan_args(argc, argv, "02&", &wave_file_path, &params, &blk);
Data_Get_Struct(self, ruby_whisper, rw);
Data_Get_Struct(params, ruby_whisper_params, rwp);
if (!rb_respond_to(wave_file_path, id_to_s)) {
rb_raise(rb_eRuntimeError, "Expected file path to wave file");
@ -70,22 +61,22 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
// rwp->params.encoder_begin_callback_user_data = &is_aborted;
// }
prepare_transcription(rwp, &self);
register_callbacks(rwp, &self);
if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), n_processors) != 0) {
if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) {
fprintf(stderr, "failed to process audio\n");
return self;
}
if (NIL_P(blk)) {
return self;
}
const int n_segments = whisper_full_n_segments(rw->context);
VALUE output = rb_str_new2("");
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(rw->context, i);
output = rb_str_concat(output, rb_str_new2(text));
}
rb_funcall(blk, id_call, 1, output);
VALUE idCall = id_call;
if (blk != Qnil) {
rb_funcall(blk, idCall, 1, output);
}
return self;
}
#ifdef __cplusplus

View File

@ -1,288 +0,0 @@
#include <ruby.h>
#include "ruby_whisper.h"
#define DEFINE_PARAM(param_name, nth) \
id_ ## param_name = rb_intern(#param_name); \
param_names[nth] = id_ ## param_name; \
rb_define_method(cVADParams, #param_name, ruby_whisper_vad_params_get_ ## param_name, 0); \
rb_define_method(cVADParams, #param_name "=", ruby_whisper_vad_params_set_ ## param_name, 1);
#define NUM_PARAMS 6
extern VALUE cVADParams;
static size_t
ruby_whisper_vad_params_memsize(const void *p)
{
const struct ruby_whisper_vad_params *params = p;
size_t size = sizeof(params);
if (!params) {
return 0;
}
return size;
}
static ID param_names[NUM_PARAMS];
static ID id_threshold;
static ID id_min_speech_duration_ms;
static ID id_min_silence_duration_ms;
static ID id_max_speech_duration_s;
static ID id_speech_pad_ms;
static ID id_samples_overlap;
const rb_data_type_t ruby_whisper_vad_params_type = {
"ruby_whisper_vad_params",
{0, 0, ruby_whisper_vad_params_memsize,},
0, 0,
0
};
static VALUE
ruby_whisper_vad_params_s_allocate(VALUE klass)
{
ruby_whisper_vad_params *rwvp;
VALUE obj = TypedData_Make_Struct(klass, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
rwvp->params = whisper_vad_default_params();
return obj;
}
/*
* Probability threshold to consider as speech.
*
* call-seq:
* threshold = th -> th
*/
static VALUE
ruby_whisper_vad_params_set_threshold(VALUE self, VALUE value)
{
ruby_whisper_vad_params *rwvp;
TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
rwvp->params.threshold = RFLOAT_VALUE(value);
return value;
}
static VALUE
ruby_whisper_vad_params_get_threshold(VALUE self)
{
ruby_whisper_vad_params *rwvp;
TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
return DBL2NUM(rwvp->params.threshold);
}
/*
* Min duration for a valid speech segment.
*
* call-seq:
* min_speech_duration_ms = duration_ms -> duration_ms
*/
static VALUE
ruby_whisper_vad_params_set_min_speech_duration_ms(VALUE self, VALUE value)
{
ruby_whisper_vad_params *rwvp;
TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
rwvp->params.min_speech_duration_ms = NUM2INT(value);
return value;
}
static VALUE
ruby_whisper_vad_params_get_min_speech_duration_ms(VALUE self)
{
ruby_whisper_vad_params *rwvp;
TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
return INT2NUM(rwvp->params.min_speech_duration_ms);
}
/*
* Min silence duration to consider speech as ended.
*
* call-seq:
* min_silence_duration_ms = duration_ms -> duration_ms
*/
static VALUE
ruby_whisper_vad_params_set_min_silence_duration_ms(VALUE self, VALUE value)
{
ruby_whisper_vad_params *rwvp;
TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
rwvp->params.min_silence_duration_ms = NUM2INT(value);
return value;
}
static VALUE
ruby_whisper_vad_params_get_min_silence_duration_ms(VALUE self)
{
ruby_whisper_vad_params *rwvp;
TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
return INT2NUM(rwvp->params.min_silence_duration_ms);
}
/*
* Max duration of a speech segment before forcing a new segment.
*
* call-seq:
* max_speech_duration_s = duration_s -> duration_s
*/
static VALUE
ruby_whisper_vad_params_set_max_speech_duration_s(VALUE self, VALUE value)
{
ruby_whisper_vad_params *rwvp;
TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
rwvp->params.max_speech_duration_s = RFLOAT_VALUE(value);
return value;
}
static VALUE
ruby_whisper_vad_params_get_max_speech_duration_s(VALUE self)
{
ruby_whisper_vad_params *rwvp;
TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
return DBL2NUM(rwvp->params.max_speech_duration_s);
}
/*
* Padding added before and after speech segments.
*
* call-seq:
* speech_pad_ms = pad_ms -> pad_ms
*/
static VALUE
ruby_whisper_vad_params_set_speech_pad_ms(VALUE self, VALUE value)
{
ruby_whisper_vad_params *rwvp;
TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
rwvp->params.speech_pad_ms = NUM2INT(value);
return value;
}
static VALUE
ruby_whisper_vad_params_get_speech_pad_ms(VALUE self)
{
ruby_whisper_vad_params *rwvp;
TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
return INT2NUM(rwvp->params.speech_pad_ms);
}
/*
* Overlap in seconds when copying audio samples from speech segment.
*
* call-seq:
* samples_overlap = overlap -> overlap
*/
static VALUE
ruby_whisper_vad_params_set_samples_overlap(VALUE self, VALUE value)
{
ruby_whisper_vad_params *rwvp;
TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
rwvp->params.samples_overlap = RFLOAT_VALUE(value);
return value;
}
static VALUE
ruby_whisper_vad_params_get_samples_overlap(VALUE self)
{
ruby_whisper_vad_params *rwvp;
TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
return DBL2NUM(rwvp->params.samples_overlap);
}
static VALUE
ruby_whisper_vad_params_equal(VALUE self, VALUE other)
{
ruby_whisper_vad_params *rwvp1;
ruby_whisper_vad_params *rwvp2;
if (self == other) {
return Qtrue;
}
if (!rb_obj_is_kind_of(other, cVADParams)) {
return Qfalse;
}
TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp1);
TypedData_Get_Struct(other, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp2);
if (rwvp1->params.threshold != rwvp2->params.threshold) {
return Qfalse;
}
if (rwvp1->params.min_speech_duration_ms != rwvp2->params.min_speech_duration_ms) {
return Qfalse;
}
if (rwvp1->params.min_silence_duration_ms != rwvp2->params.min_silence_duration_ms) {
return Qfalse;
}
if (rwvp1->params.max_speech_duration_s != rwvp2->params.max_speech_duration_s) {
return Qfalse;
}
if (rwvp1->params.speech_pad_ms != rwvp2->params.speech_pad_ms) {
return Qfalse;
}
if (rwvp1->params.samples_overlap != rwvp2->params.samples_overlap) {
return Qfalse;
}
return Qtrue;
}
#define SET_PARAM_IF_SAME(param_name) \
if (id == id_ ## param_name) { \
ruby_whisper_vad_params_set_ ## param_name(self, value); \
continue; \
}
VALUE
ruby_whisper_vad_params_initialize(int argc, VALUE *argv, VALUE self)
{
VALUE kw_hash;
VALUE values[NUM_PARAMS] = {Qundef};
VALUE value;
ruby_whisper_vad_params *rwvp;
ID id;
int i;
TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
rb_scan_args_kw(RB_SCAN_ARGS_KEYWORDS, argc, argv, ":", &kw_hash);
if (NIL_P(kw_hash)) {
return self;
}
rb_get_kwargs(kw_hash, param_names, 0, NUM_PARAMS, values);
for (i = 0; i < NUM_PARAMS; i++) {
id = param_names[i];
value = values[i];
if (value == Qundef) {
continue;
}
SET_PARAM_IF_SAME(threshold)
SET_PARAM_IF_SAME(min_speech_duration_ms)
SET_PARAM_IF_SAME(min_silence_duration_ms)
SET_PARAM_IF_SAME(max_speech_duration_s)
SET_PARAM_IF_SAME(speech_pad_ms)
SET_PARAM_IF_SAME(samples_overlap)
}
return self;
}
#undef SET_PARAM_IF_SAME
void
init_ruby_whisper_vad_params(VALUE *mVAD)
{
cVADParams = rb_define_class_under(*mVAD, "Params", rb_cObject);
rb_define_alloc_func(cVADParams, ruby_whisper_vad_params_s_allocate);
rb_define_method(cVADParams, "initialize", ruby_whisper_vad_params_initialize, -1);
DEFINE_PARAM(threshold, 0)
DEFINE_PARAM(min_speech_duration_ms, 1)
DEFINE_PARAM(min_silence_duration_ms, 2)
DEFINE_PARAM(max_speech_duration_s, 3)
DEFINE_PARAM(speech_pad_ms, 4)
DEFINE_PARAM(samples_overlap, 5)
rb_define_method(cVADParams, "==", ruby_whisper_vad_params_equal, 1);
}
#undef DEFINE_PARAM
#undef NUM_PARAMS

View File

@ -1,10 +1,5 @@
require "pathname"
root = Pathname("..")/".."
ignored_dirs = %w[
.devops
.github
ci
examples/wchess/wchess.wasm
examples/whisper.android
examples/whisper.android.java
@ -14,7 +9,7 @@ ignored_dirs = %w[
models
samples
scripts
].collect {|dir| root/dir}
]
ignored_files = %w[
AUTHORS
Makefile
@ -22,19 +17,18 @@ ignored_files = %w[
README_sycl.md
.gitignore
.gitmodules
.dockerignore
whisper.nvim
twitch.sh
yt-wsp.sh
close-issue.yml
]
EXTSOURCES =
`git ls-files -z #{root}`.split("\x0")
.collect {|file| Pathname(file)}
.reject {|file|
ignored_dirs.any? {|dir| file.descend.any? {|desc| desc == dir}} ||
ignored_files.include?(file.basename.to_path) ||
(file.descend.to_a[1] != root && file.descend.to_a[1] != Pathname("..")/"javascript")
`git ls-files -z ../..`.split("\x0")
.select {|file|
basename = File.basename(file)
ignored_dirs.all? {|dir| !file.start_with?("../../#{dir}")} &&
!ignored_files.include?(basename) &&
(file.start_with?("../..") || file.start_with?("../javascript")) &&
(!file.start_with?("../../.github/") || basename == "bindings-ruby.yml")
}
.collect(&:to_path)

View File

@ -1,15 +0,0 @@
module Whisper
class Context
def to_srt
each_segment.with_index.reduce("") {|srt, (segment, index)|
srt << "#{index + 1}\n#{segment.to_srt_cue}\n"
}
end
def to_webvtt
each_segment.with_index.reduce("WEBVTT\n\n") {|webvtt, (segment, index)|
webvtt << "#{index + 1}\n#{segment.to_webvtt_cue}\n"
}
end
end
end

View File

@ -130,44 +130,6 @@ module Whisper
end
end
class ZipURI < URI
def cache
zip_path = super
dest = unzipped_path
return if dest.exist? && dest.mtime >= zip_path.mtime
escaping dest do
system "unzip", "-q", "-d", zip_path.dirname.to_path, zip_path.to_path, exception: true
end
zip_path
end
def clear_cache
super
unzipped_path.rmtree if unzipped_path.exist?
end
private
def unzipped_path
cache_path.sub_ext("")
end
def escaping(path)
escaped = Pathname("#{path}.removing")
if path.exist?
escaped.rmtree if escaped.exist?
path.rename escaped
end
yield
ensure
if path.exist?
escaped.rmtree if escaped.exist?
else
escaped.rename path if escaped.exist?
end
end
end
@pre_converted_models = %w[
tiny
tiny.en
@ -203,31 +165,8 @@ module Whisper
models[name] = URI.new("https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-#{name}.bin")
}
%w[
silero-v5.1.2
].each do |name|
@pre_converted_models[name] = URI.new("https://huggingface.co/ggml-org/whisper-vad/resolve/main/ggml-#{name}.bin")
end
@coreml_compiled_models = %w[
tiny
tiny.en
base
base.en
small
small.en
medium
medium.en
large-v1
large-v2
large-v3
large-v3-turbo
].each_with_object({}) do |name, models|
models[@pre_converted_models[name]] = ZipURI.new("https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-#{name}-encoder.mlmodelc.zip")
end
class << self
attr_reader :pre_converted_models, :coreml_compiled_models
attr_reader :pre_converted_models
end
end
end

View File

@ -1,58 +0,0 @@
module Whisper
class Segment
SRT_ESCAPES = {
"&" => "&amp;",
"<" => "&lt;",
">" => "&gt;",
}
SRT_ESCAPES_RE = Regexp.union(SRT_ESCAPES.keys)
private_constant :SRT_ESCAPES, :SRT_ESCAPES_RE
def to_srt_cue
"#{srt_start_time} --> #{srt_end_time}\n#{srt_text}\n"
end
def to_webvtt_cue
"#{webvtt_start_time} --> #{webvtt_end_time}\n#{webvtt_text}\n"
end
private
def time_to_a(time)
sec, decimal_part = time.divmod(1000)
min, sec = sec.divmod(60)
hour, min = min.divmod(60)
[hour, min, sec, decimal_part]
end
def srt_time(time)
"%02d:%02d:%02d,%03d" % time_to_a(time)
end
def srt_start_time
srt_time(start_time)
end
def srt_end_time
srt_time(end_time)
end
def srt_text
text.gsub(SRT_ESCAPES_RE, SRT_ESCAPES)
end
def webvtt_time(time)
"%02d:%02d:%02d.%03d" % time_to_a(time)
end
def webvtt_start_time
webvtt_time(start_time)
end
def webvtt_end_time
webvtt_time(end_time)
end
alias webvtt_text srt_text
end
end

View File

@ -10,7 +10,6 @@ module Whisper
type encoder_begin_callback = ^(Whisper::Context, void, Object user_data) -> void
type abort_callback = ^(Whisper::Context, void, Object user_data) -> boolish
VERSION: String
LOG_LEVEL_NONE: Integer
LOG_LEVEL_INFO: Integer
LOG_LEVEL_WARN: Integer
@ -23,22 +22,21 @@ module Whisper
def self.lang_str: (Integer id) -> String
def self.lang_str_full: (Integer id) -> String
def self.log_set: (log_callback, Object? user_data) -> log_callback
def self.system_info_str: () -> String
class Context
def self.new: (String | path | ::URI::HTTP) -> instance
def self.new: (path | ::URI::HTTP) -> instance
# transcribe a single file
# can emit to a block results
#
# params = Whisper::Params.new
# params.duration = 60_000
# whisper.transcribe "path/to/audio.wav", params do |text|
# puts text
# end
# params = Whisper::Params.new
# params.duration = 60_000
# whisper.transcribe "path/to/audio.wav", params do |text|
# puts text
# end
#
def transcribe: (string, Params, ?n_processors: Integer) -> self
| (string, Params, ?n_processors: Integer) { (String) -> void } -> self
def transcribe: (string, Params) -> self
| (string, Params) { (String) -> void } -> self
def model_n_vocab: () -> Integer
def model_n_audio_ctx: () -> Integer
@ -51,16 +49,16 @@ module Whisper
# Yields each Whisper::Segment:
#
# whisper.transcribe("path/to/audio.wav", params)
# whisper.each_segment do |segment|
# puts segment.text
# end
# whisper.transcribe("path/to/audio.wav", params)
# whisper.each_segment do |segment|
# puts segment.text
# end
#
# Returns an Enumerator if no block given:
#
# whisper.transcribe("path/to/audio.wav", params)
# enum = whisper.each_segment
# enum.to_a # => [#<Whisper::Segment>, ...]
# whisper.transcribe("path/to/audio.wav", params)
# enum = whisper.each_segment
# enum.to_a # => [#<Whisper::Segment>, ...]
#
def each_segment: { (Segment) -> void } -> void
| () -> Enumerator[Segment]
@ -75,25 +73,25 @@ module Whisper
# Start time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds).
#
# full_get_segment_t0(3) # => 1668 (16680 ms)
# full_get_segment_t0(3) # => 1668 (16680 ms)
#
def full_get_segment_t0: (Integer) -> Integer
# End time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds).
#
# full_get_segment_t1(3) # => 1668 (16680 ms)
# full_get_segment_t1(3) # => 1668 (16680 ms)
#
def full_get_segment_t1: (Integer) -> Integer
# Whether the next segment indexed by +segment_index+ is predicated as a speaker turn.
#
# full_get_segment_speacker_turn_next(3) # => true
# full_get_segment_speacker_turn_next(3) # => true
#
def full_get_segment_speaker_turn_next: (Integer) -> (true | false)
# Text of a segment indexed by +segment_index+.
#
# full_get_segment_text(3) # => "ask not what your country can do for you, ..."
# full_get_segment_text(3) # => "ask not what your country can do for you, ..."
#
def full_get_segment_text: (Integer) -> String
@ -117,9 +115,6 @@ module Whisper
def full_parallel: (Params, Array[Float], ?Integer n_samples) -> self
| (Params, _Samples, ?Integer n_samples) -> self
| (Params, _Samples, ?Integer? n_samples, Integer n_processors) -> self
def to_srt: () -> String
def to_webvtt: () -> String
end
class Params
@ -135,7 +130,6 @@ module Whisper
?suppress_blank: boolish,
?suppress_nst: boolish,
?token_timestamps: boolish,
?max_len: Integer,
?split_on_word: boolish,
?initial_prompt: string | nil,
?diarize: boolish,
@ -156,10 +150,7 @@ module Whisper
?encoder_begin_callback: encoder_begin_callback,
?encoder_begin_callback_user_data: Object,
?abort_callback: abort_callback,
?abort_callback_user_data: Object,
?vad: boolish,
?vad_model_path: path | URI,
?vad_params: Whisper::VAD::Params
?abort_callback_user_data: Object
) -> instance
# params.language = "auto" | "en", etc...
@ -223,12 +214,6 @@ module Whisper
#
def token_timestamps: () -> (true | false)
def max_len=: (Integer) -> Integer
# max segment length in characters.
#
def max_len: () -> Integer
def split_on_word=: (boolish) -> boolish
# If true, split on word rather than on token (when used with max_len).
@ -293,9 +278,9 @@ module Whisper
# Sets new segment callback, called for every newly generated text segment.
#
# params.new_segment_callback = ->(context, _, n_new, user_data) {
# # ...
# }
# params.new_segment_callback = ->(context, _, n_new, user_data) {
# # ...
# }
#
def new_segment_callback=: (new_segment_callback) -> new_segment_callback
def new_segment_callback: () -> (new_segment_callback | nil)
@ -308,9 +293,9 @@ module Whisper
# Sets progress callback, called on each progress update.
#
# params.new_segment_callback = ->(context, _, progress, user_data) {
# # ...
# }
# params.new_segment_callback = ->(context, _, progress, user_data) {
# # ...
# }
#
# +progress+ is an Integer between 0 and 100.
#
@ -338,9 +323,9 @@ module Whisper
# Sets abort callback, called to check if the process should be aborted.
#
# params.abort_callback = ->(user_data) {
# # ...
# }
# params.abort_callback = ->(user_data) {
# # ...
# }
#
#
def abort_callback=: (abort_callback) -> abort_callback
@ -353,25 +338,11 @@ module Whisper
def abort_callback_user_data: () -> Object
# Enable VAD
#
def vad=: (boolish) -> boolish
def vad: () -> (true | false)
# Path to the VAD model
def vad_model_path=: (path | URI | nil) -> (path | URI | nil)
def vad_model_path: () -> (String | nil)
def vad_params=: (Whisper::VAD::Params) -> Whisper::VAD::Params
def vad_params: () -> (Whisper::VAD::Params)
# Hook called on new segment. Yields each Whisper::Segment.
#
# whisper.on_new_segment do |segment|
# # ...
# end
# whisper.on_new_segment do |segment|
# # ...
# end
#
def on_new_segment: { (Segment) -> void } -> void
@ -385,20 +356,19 @@ module Whisper
# Call block to determine whether abort or not. Return +true+ when you want to abort.
#
# params.abort_on do
# if some_condition
# true # abort
# else
# false # continue
# end
# params.abort_on do
# if some_condition
# true # abort
# else
# false # continue
# end
# end
#
def abort_on: { (Object user_data) -> boolish } -> void
end
class Model
def self.pre_converted_models: () -> Hash[String, Model::URI]
def self.coreml_compiled_models: () -> Hash[Model::URI, Model::ZipURI]
def self.new: () -> instance
def n_vocab: () -> Integer
def n_audio_ctx: () -> Integer
@ -418,22 +388,9 @@ module Whisper
def to_path: -> String
def clear_cache: -> void
end
class ZipURI < URI
def cache: () -> Pathname
def clear_cache: () -> void
end
end
class Segment
type deconstructed_keys = {
start_time: (Integer | nil),
end_time: (Integer | nil),
text: (String | nil),
no_speech_prob: (Float | nil),
speaker_turn_next: (true | false | nil)
}
# Start time in milliseconds.
#
def start_time: () -> Integer
@ -443,70 +400,10 @@ module Whisper
def end_time: () -> Integer
# Whether the next segment is predicted as a speaker turn.
def speaker_turn_next?: () -> (true | false)
def speaker_next_turn?: () -> (true | false)
def text: () -> String
def no_speech_prob: () -> Float
def to_srt_cue: () -> String
def to_webvtt_cue: () -> String
# Possible keys: :start_time, :end_time, :text, :no_speech_prob, :speaker_turn_next
#
# whisper.each_segment do |segment|
# segment => {start_time:, end_time:, text:, no_speech_prob:, speaker_turn_next:}
#
# puts "[#{start_time} --> #{end_time}] #{text} (no speech prob: #{no_speech_prob}#{speaker_turn_next ? ', speaker turns next' : ''})"
# end
def deconstruct_keys: (Array[:start_time | :end_time | :text | :no_speech_prob | :speaker_turn_next] | nil) -> deconstructed_keys
end
module VAD
class Params
def self.new: (
?threshold: Float,
?min_speech_duration_ms: Integer,
?min_silence_duration_ms: Integer,
?max_speech_duration_s: Float,
?speech_pad_ms: Integer,
?samples_overlap: Float
) -> instance
# Probability threshold to consider as speech.
#
def threshold=: (Float) -> Float
def threshold: () -> Float
# Min duration for a valid speech segment.
#
def min_speech_duration_ms=: (Integer) -> Integer
def min_speech_duration_ms: () -> Integer
# Min silence duration to consider speech as ended.
#
def min_silence_duration_ms=: (Integer) -> Integer
def min_silence_duration_ms: () -> Integer
# Max duration of a speech segment before forcing a new segment.
def max_speech_duration_s=: (Float) -> Float
def max_speech_duration_s: () -> Float
# Padding added before and after speech segments.
#
def speech_pad_ms=: (Integer) -> Integer
def speech_pad_ms: () -> Integer
# Overlap in seconds when copying audio samples from speech segment.
#
def samples_overlap=: (Float) -> Float
def samples_overlap: () -> Float
def ==: (Params) -> (true | false)
end
end
class Error < StandardError

View File

@ -1,51 +0,0 @@
require_relative "helper"
require 'tempfile'
require 'tmpdir'
require 'shellwords'
class TestPackage < TestBase
def test_build
Tempfile.create do |file|
assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path.shellescape, exception: true)
assert file.size > 0
assert_path_exist file.to_path
end
end
sub_test_case "Building binary on installation" do
def setup
system "rake", "build", exception: true
end
def test_install
gemspec = Gem::Specification.load("whispercpp.gemspec")
Dir.mktmpdir do |dir|
system "gem", "install", "--install-dir", dir.shellescape, "--no-document", "pkg/#{gemspec.file_name.shellescape}", exception: true
assert_installed dir, gemspec.version
end
end
def test_install_with_coreml
omit_unless RUBY_PLATFORM.match?(/darwin/) do
gemspec = Gem::Specification.load("whispercpp.gemspec")
Dir.mktmpdir do |dir|
system "gem", "install", "--install-dir", dir.shellescape, "--no-document", "pkg/#{gemspec.file_name.shellescape}", "--", "--enable-whisper-coreml", exception: true
assert_installed dir, gemspec.version
libdir = File.join(dir, "gems", "#{gemspec.name}-#{gemspec.version}", "lib")
assert_nothing_raised do
system "ruby", "-I", libdir, "-r", "whisper", "-e", "Whisper::Context.new('tiny')", exception: true
end
assert_match(/COREML = 1/, `ruby -I #{libdir.shellescape} -r whisper -e 'puts Whisper.system_info_str'`)
end
end
end
private
def assert_installed(dir, version)
assert_path_exist File.join(dir, "gems/whispercpp-#{version}/lib", "whisper.#{RbConfig::CONFIG["DLEXT"]}")
assert_path_exist File.join(dir, "gems/whispercpp-#{version}/LICENSE")
assert_path_not_exist File.join(dir, "gems/whispercpp-#{version}/ext/build")
end
end
end

View File

@ -1,146 +0,0 @@
require_relative "helper"
class TestSegment < TestBase
def test_iteration
whisper.each_segment do |segment|
assert_instance_of Whisper::Segment, segment
end
end
def test_enumerator
enum = whisper.each_segment
assert_instance_of Enumerator, enum
enum.to_a.each_with_index do |segment, index|
assert_instance_of Whisper::Segment, segment
assert_kind_of Integer, index
end
end
def test_start_time
i = 0
whisper.each_segment do |segment|
assert_equal 0, segment.start_time if i == 0
i += 1
end
end
def test_end_time
i = 0
whisper.each_segment do |segment|
assert_equal whisper.full_get_segment_t1(i) * 10, segment.end_time
i += 1
end
end
def test_no_speech_prob
no_speech_prob = nil
whisper.each_segment do |segment|
no_speech_prob = segment.no_speech_prob
end
assert no_speech_prob > 0.0
end
def test_on_new_segment
params = Whisper::Params.new
seg = nil
index = 0
params.on_new_segment do |segment|
assert_instance_of Whisper::Segment, segment
if index == 0
seg = segment
assert_equal 0, segment.start_time
assert_match(/ask not what your country can do for you, ask what you can do for your country/, segment.text)
end
index += 1
end
whisper.transcribe(AUDIO, params)
assert_equal 0, seg.start_time
assert_match(/ask not what your country can do for you, ask what you can do for your country/, seg.text)
end
def test_on_new_segment_twice
params = Whisper::Params.new
seg = nil
params.on_new_segment do |segment|
seg = segment
return
end
params.on_new_segment do |segment|
assert_same seg, segment
return
end
whisper.transcribe(AUDIO, params)
end
def test_transcription_after_segment_retrieved
params = Whisper::Params.new
segment = whisper.each_segment.first
assert_match(/ask not what your country can do for you, ask what you can do for your country/, segment.text)
whisper.transcribe(AUDIO, Whisper::Params.new(offset: 5000))
assert_not_match(/ask not what your country can do for you, ask what you can do for your country/, segment.text)
assert_match(/what you can do for your country/i, segment.text)
end
def test_pattern_matching
segment = whisper.each_segment.first
segment => {start_time:, end_time:, text:, no_speech_prob:, speaker_turn_next:}
assert_equal segment.start_time, start_time
assert_equal segment.end_time, end_time
assert_equal segment.text, text
assert_equal segment.no_speech_prob, no_speech_prob
assert_equal segment.speaker_turn_next?, speaker_turn_next
end
def test_pattern_matching_partial
segment = whisper.each_segment.first
segment => {start_time:, end_time:, text:}
assert_equal segment.start_time, start_time
assert_equal segment.end_time, end_time
assert_equal segment.text, text
end
def test_deconstruct_keys
segment = whisper.each_segment.first
expected = {
start_time: segment.start_time,
end_time: segment.end_time,
text: segment.text,
no_speech_prob: segment.no_speech_prob,
speaker_turn_next: segment.speaker_turn_next?
}
assert_equal expected, segment.deconstruct_keys([:start_time, :end_time, :text, :no_speech_prob, :speaker_turn_next])
end
def test_deconstruct_keys_non_existent
omit "Undefined behavior"
segment = whisper.each_segment.first
assert_equal({}, segment.deconstruct_keys([:non_existent]))
end
def test_deconstruct_keys_too_many_keys
omit "Undefined behavior"
segment = whisper.each_segment.first
assert_equal({}, segment.deconstruct_keys([:start_time, :end_time, :text, :no_speech_prob, :speaker_turn_next, :extra_key]))
end
def test_deconstruct_keys_includes_non_existent_keys_not_too_many
omit "Undefined behavior"
segment = whisper.each_segment.first
expected = {
start_time: segment.start_time,
end_time: segment.end_time,
text: segment.text,
no_speech_prob: segment.no_speech_prob
}
assert_equal(expected, segment.deconstruct_keys([:start_time, :end_time, :text, :no_speech_prob, :non_existent]))
end
end

View File

@ -1,19 +0,0 @@
require_relative "helper"
class TestVAD < TestBase
def setup
@whisper = Whisper::Context.new("base.en")
vad_params = Whisper::VAD::Params.new
@params = Whisper::Params.new(
vad: true,
vad_model_path: "silero-v5.1.2",
vad_params:
)
end
def test_transcribe
@whisper.transcribe(TestBase::AUDIO, @params) do |text|
assert_match(/ask not what your country can do for you[,.] ask what you can do for your country/i, text)
end
end
end

View File

@ -1,103 +0,0 @@
require_relative "helper"
class TestVADParams < TestBase
PARAM_NAMES = [
:threshold,
:min_speech_duration_ms,
:min_silence_duration_ms,
:max_speech_duration_s,
:speech_pad_ms,
:samples_overlap
]
def setup
@params = Whisper::VAD::Params.new
end
def test_new
params = Whisper::VAD::Params.new
assert_kind_of Whisper::VAD::Params, params
end
def test_threshold
assert_in_delta @params.threshold, 0.5
@params.threshold = 0.7
assert_in_delta @params.threshold, 0.7
end
def test_min_speech_duration
pend
end
def test_min_speech_duration_ms
assert_equal 250, @params.min_speech_duration_ms
@params.min_speech_duration_ms = 500
assert_equal 500, @params.min_speech_duration_ms
end
def test_min_silence_duration_ms
assert_equal 100, @params.min_silence_duration_ms
@params.min_silence_duration_ms = 200
assert_equal 200, @params.min_silence_duration_ms
end
def test_max_speech_duration
pend
end
def test_max_speech_duration_s
assert @params.max_speech_duration_s >= 10e37 # Defaults to FLT_MAX
@params.max_speech_duration_s = 60.0
assert_equal 60.0, @params.max_speech_duration_s
end
def test_speech_pad_ms
assert_equal 30, @params.speech_pad_ms
@params.speech_pad_ms = 50
assert_equal 50, @params.speech_pad_ms
end
def test_samples_overlap
assert_in_delta @params.samples_overlap, 0.1
@params.samples_overlap = 0.5
assert_in_delta @params.samples_overlap, 0.5
end
def test_equal
assert_equal @params, Whisper::VAD::Params.new
end
def test_new_with_kw_args
params = Whisper::VAD::Params.new(threshold: 0.7)
assert_in_delta params.threshold, 0.7
assert_equal 250, params.min_speech_duration_ms
end
def test_new_with_kw_args_non_existent
assert_raise ArgumentError do
Whisper::VAD::Params.new(non_existent: "value")
end
end
data(PARAM_NAMES.collect {|param| [param, param]}.to_h)
def test_new_with_kw_args_default_values(param)
default_value = @params.send(param)
value = default_value + 1
params = Whisper::VAD::Params.new(param => value)
if Float === value
assert_in_delta value, params.send(param)
else
assert_equal value, params.send(param)
end
PARAM_NAMES.reject {|name| name == param}.each do |name|
expected = @params.send(name)
actual = params.send(name)
if Float === expected
assert_in_delta expected, actual
else
assert_equal expected, actual
end
end
end
end

View File

@ -3,7 +3,7 @@ require "whisper"
require_relative "jfk_reader/jfk_reader"
class TestBase < Test::Unit::TestCase
AUDIO = File.join(__dir__, "fixtures", "jfk.wav")
AUDIO = File.join(__dir__, "..", "..", "..", "samples", "jfk.wav")
class << self
def whisper
@ -21,4 +21,15 @@ class TestBase < Test::Unit::TestCase
def whisper
self.class.whisper
end
module BuildOptions
load "ext/options.rb", self
Options.include self
def enable_config(name)
end
def arg_config(name)
end
end
end

View File

@ -106,13 +106,4 @@ class TestModel < TestBase
assert_equal 1, model.ftype
assert_equal "base", model.type
end
def test_coreml_model_auto_download
uri = Whisper::Model.coreml_compiled_models[Whisper::Model.pre_converted_models["tiny"]]
model_path = Pathname(uri.to_path).sub_ext("")
model_path.rmtree if model_path.exist?
uri.cache
assert_path_exist model_path
end
end

View File

@ -0,0 +1,46 @@
require_relative "helper"
require 'tempfile'
require 'tmpdir'
require 'shellwords'
class TestPackage < TestBase
def test_build
Tempfile.create do |file|
assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path.shellescape, exception: true)
assert file.size > 0
assert_path_exist file.to_path
end
end
sub_test_case "Building binary on installation" do
def setup
system "rake", "build", exception: true
end
def test_install
match_data = `rake -Tbuild`.match(/(whispercpp-(.+)\.gem)/)
filename = match_data[1]
version = match_data[2]
Dir.mktmpdir do |dir|
system "gem", "install", "--install-dir", dir.shellescape, "--no-document", "pkg/#{filename.shellescape}", exception: true
assert_installed dir, version
end
end
private
def assert_installed(dir, version)
assert_path_exist File.join(dir, "gems/whispercpp-#{version}/lib", "whisper.#{RbConfig::CONFIG["DLEXT"]}")
assert_path_exist File.join(dir, "gems/whispercpp-#{version}/LICENSE")
assert_path_not_exist File.join(dir, "gems/whispercpp-#{version}/ext/build")
end
end
def test_build_options
options = BuildOptions::Options.new
assert_empty options.missing_options
if ENV["TEST_EXTRA_OPTIONS"] == "1"
assert_empty options.extra_options
end
end
end

View File

@ -13,7 +13,6 @@ class TestParams < TestBase
:suppress_blank,
:suppress_nst,
:token_timestamps,
:max_len,
:split_on_word,
:initial_prompt,
:diarize,
@ -33,9 +32,6 @@ class TestParams < TestBase
:progress_callback_user_data,
:abort_callback,
:abort_callback_user_data,
:vad,
:vad_model_path,
:vad_params,
]
def setup
@ -140,13 +136,6 @@ class TestParams < TestBase
assert !@params.token_timestamps
end
def test_max_len
@params.max_len = 42
assert_equal @params.max_len, 42
@params.max_len = 0
assert_equal @params.max_len, 0
end
def test_split_on_word
@params.split_on_word = true
assert @params.split_on_word
@ -202,50 +191,6 @@ class TestParams < TestBase
assert_in_delta 0.2, @params.no_speech_thold
end
def test_vad
assert_false @params.vad
@params.vad = true
assert_true @params.vad
end
def test_vad_model_path
assert_nil @params.vad_model_path
@params.vad_model_path = "silero-v5.1.2"
assert_equal Whisper::Model.pre_converted_models["silero-v5.1.2"].to_path, @params.vad_model_path
end
def test_vad_model_path_with_nil
@params.vad_model_path = "silero-v5.1.2"
@params.vad_model_path = nil
assert_nil @params.vad_model_path
end
def test_vad_model_path_with_invalid
assert_raise TypeError do
@params.vad_model_path = Object.new
end
end
def test_vad_model_path_with_URI_string
@params.vad_model_path = "https://huggingface.co/ggml-org/whisper-vad/resolve/main/ggml-silero-v5.1.2.bin"
assert_equal @params.vad_model_path, Whisper::Model.pre_converted_models["silero-v5.1.2"].to_path
end
def test_vad_model_path_with_URI
@params.vad_model_path = URI("https://huggingface.co/ggml-org/whisper-vad/resolve/main/ggml-silero-v5.1.2.bin")
assert_equal @params.vad_model_path, Whisper::Model.pre_converted_models["silero-v5.1.2"].to_path
end
def test_vad_params
assert_kind_of Whisper::VAD::Params, @params.vad_params
default_params = @params.vad_params
assert_same default_params, @params.vad_params
assert_equal 0.5, default_params.threshold
new_params = Whisper::VAD::Params.new
@params.vad_params = new_params
assert_same new_params, @params.vad_params
end
def test_new_with_kw_args
params = Whisper::Params.new(language: "es")
assert_equal "es", params.language
@ -280,10 +225,6 @@ class TestParams < TestBase
proc {}
in [/_user_data\Z/, *]
Object.new
in [:vad_model_path, *]
Whisper::Model.pre_converted_models["silero-v5.1.2"].to_path
in [:vad_params, *]
Whisper::VAD::Params.new
end
params = Whisper::Params.new(param => value)
if Float === value

View File

@ -0,0 +1,74 @@
require_relative "helper"
class TestSegment < TestBase
def test_iteration
whisper.each_segment do |segment|
assert_instance_of Whisper::Segment, segment
end
end
def test_enumerator
enum = whisper.each_segment
assert_instance_of Enumerator, enum
enum.to_a.each_with_index do |segment, index|
assert_instance_of Whisper::Segment, segment
assert_kind_of Integer, index
end
end
def test_start_time
i = 0
whisper.each_segment do |segment|
assert_equal 0, segment.start_time if i == 0
i += 1
end
end
def test_end_time
i = 0
whisper.each_segment do |segment|
assert_equal whisper.full_get_segment_t1(i) * 10, segment.end_time
i += 1
end
end
def test_no_speech_prob
no_speech_prob = nil
whisper.each_segment do |segment|
no_speech_prob = segment.no_speech_prob
end
assert no_speech_prob > 0.0
end
def test_on_new_segment
params = Whisper::Params.new
seg = nil
index = 0
params.on_new_segment do |segment|
assert_instance_of Whisper::Segment, segment
if index == 0
seg = segment
assert_equal 0, segment.start_time
assert_match(/ask not what your country can do for you, ask what you can do for your country/, segment.text)
end
index += 1
end
whisper.transcribe(AUDIO, params)
assert_equal 0, seg.start_time
assert_match(/ask not what your country can do for you, ask what you can do for your country/, seg.text)
end
def test_on_new_segment_twice
params = Whisper::Params.new
seg = nil
params.on_new_segment do |segment|
seg = segment
return
end
params.on_new_segment do |segment|
assert_same seg, segment
return
end
whisper.transcribe(AUDIO, params)
end
end

View File

@ -20,24 +20,6 @@ class TestWhisper < TestBase
}
end
def test_transcribe_non_parallel
@whisper = Whisper::Context.new("base.en")
params = Whisper::Params.new
@whisper.transcribe(AUDIO, params, n_processors: 1) {|text|
assert_match(/ask not what your country can do for you, ask what you can do for your country/, text)
}
end
def test_transcribe_n_processors
@whisper = Whisper::Context.new("base.en")
params = Whisper::Params.new
@whisper.transcribe(AUDIO, params, n_processors: 4) {|text|
assert_match(/ask not what your country can do for you[,.] ask what you can do for your country/i, text)
}
end
sub_test_case "After transcription" do
def test_full_n_segments
assert_equal 1, whisper.full_n_segments
@ -112,14 +94,6 @@ class TestWhisper < TestBase
end
end
def test_system_info_str
assert_match(/\AWHISPER : COREML = \d | OPENVINO = \d |/, Whisper.system_info_str)
end
def test_version
assert_kind_of String, Whisper::VERSION
end
def test_log_set
user_data = Object.new
logs = []
@ -249,48 +223,4 @@ class TestWhisper < TestBase
assert_match(/for your country/i, text)
end
end
def test_to_srt
whisper = Whisper::Context.new("base.en")
whisper.transcribe AUDIO, @params
lines = whisper.to_srt.lines
assert_match(/\A\d+\n/, lines[0])
assert_match(/\d{2}:\d{2}:\d{2},\d{3} --> \d{2}:\d{2}:\d{2},\d{3}\n/, lines[1])
assert_match(/ask not what your country can do for you, ask what you can do for your country/, lines[2])
end
def test_to_webvtt
whisper = Whisper::Context.new("base.en")
whisper.transcribe AUDIO, @params
lines = whisper.to_webvtt.lines
assert_equal "WEBVTT\n", lines[0]
assert_equal "\n", lines[1]
assert_match(/\A\d+\n/, lines[2])
assert_match(/\d{2}:\d{2}:\d{2}\.\d{3} --> \d{2}:\d{2}:\d{2}\.\d{3}\n/, lines[3])
assert_match(/ask not what your country can do for you, ask what you can do for your country/, lines[4])
end
sub_test_case "Format needs escape" do
def setup
@whisper = Whisper::Context.new("base.en")
@whisper.transcribe AUDIO, Whisper::Params.new
segment = @whisper.each_segment.first
segment.define_singleton_method :text do
"& so my fellow Americans --> ask not what your country can do for you <-- ask what you can do for your country."
end
@whisper.define_singleton_method :each_segment do
Enumerator.new(3) {|yielder| 3.times {yielder << segment}}
end
end
def test_to_srt_escape
assert_equal "&amp; so my fellow Americans --&gt; ask not what your country can do for you &lt;-- ask what you can do for your country.\n", @whisper.to_srt.lines[2]
end
def test_to_webvtt_escape
assert_equal "&amp; so my fellow Americans --&gt; ask not what your country can do for you &lt;-- ask what you can do for your country.\n", @whisper.to_webvtt.lines[4]
end
end
end

View File

@ -3,7 +3,8 @@ require_relative "extsources"
Gem::Specification.new do |s|
s.name = "whispercpp"
s.authors = ["Georgi Gerganov", "Todd A. Fisher"]
s.version = '1.3.3'
s.version = '1.3.2'
s.date = '2025-05-11'
s.description = %q{High-performance inference of OpenAI's Whisper automatic speech recognition (ASR) model via Ruby}
s.email = 'todd.fisher@gmail.com'
s.extra_rdoc_files = ['LICENSE', 'README.md']
@ -20,7 +21,7 @@ Gem::Specification.new do |s|
}
s.summary = %q{Ruby whisper.cpp bindings}
s.test_files = s.files.select {|file| file.start_with? "test/"}
s.test_files = s.files.select {|file| file.start_with? "tests/"}
s.extensions << 'ext/extconf.rb'
s.required_ruby_version = '>= 3.1.0'

View File

@ -15,7 +15,6 @@ GGML_METAL_EMBED_LIBRARY=ON
GGML_BLAS_DEFAULT=ON
GGML_METAL_USE_BF16=ON
GGML_OPENMP=OFF
BUILD_STATIC_XCFRAMEWORK=${BUILD_STATIC_XCFRAMEWORK:-OFF}
COMMON_C_FLAGS="-Wno-macro-redefined -Wno-shorten-64-to-32 -Wno-unused-command-line-argument -g"
COMMON_CXX_FLAGS="-Wno-macro-redefined -Wno-shorten-64-to-32 -Wno-unused-command-line-argument -g"
@ -328,15 +327,6 @@ combine_static_libraries() {
arch_flags+=" -arch $arch"
done
if [[ "${BUILD_STATIC_XCFRAMEWORK}" == "ON" ]]; then
echo "Packaging static framework for ${platform}."
mkdir -p "$(dirname "${base_dir}/${output_lib}")"
cp "${temp_dir}/combined.a" "${base_dir}/${output_lib}"
rm -rf "${temp_dir}"
return
fi
# Create dynamic library
echo "Creating dynamic library for ${platform}."
xcrun -sdk $sdk clang++ -dynamiclib \
@ -539,20 +529,6 @@ combine_static_libraries "build-tvos-device" "Release-appletvos" "tvos" "false"
# Create XCFramework with correct debug symbols paths
echo "Creating XCFramework..."
if [[ "${BUILD_STATIC_XCFRAMEWORK}" == "ON" ]]; then
xcodebuild -create-xcframework \
-framework $(pwd)/build-ios-sim/framework/whisper.framework \
-framework $(pwd)/build-ios-device/framework/whisper.framework \
-framework $(pwd)/build-macos/framework/whisper.framework \
-framework $(pwd)/build-visionos/framework/whisper.framework \
-framework $(pwd)/build-visionos-sim/framework/whisper.framework \
-framework $(pwd)/build-tvos-device/framework/whisper.framework \
-framework $(pwd)/build-tvos-sim/framework/whisper.framework \
-output $(pwd)/build-apple/whisper.xcframework
exit 0
fi
xcodebuild -create-xcframework \
-framework $(pwd)/build-ios-sim/framework/whisper.framework \
-debug-symbols $(pwd)/build-ios-sim/dSYMs/whisper.dSYM \

View File

@ -105,7 +105,6 @@ else()
add_subdirectory(bench)
add_subdirectory(server)
add_subdirectory(quantize)
add_subdirectory(vad-speech-segments)
if (WHISPER_SDL2)
add_subdirectory(stream)
add_subdirectory(command)

View File

@ -1,10 +1,8 @@
# whisper.cpp Node.js addon
# addon
This is an addon demo that can **perform whisper model reasoning in `node` and `electron` environments**, based on [cmake-js](https://github.com/cmake-js/cmake-js).
It can be used as a reference for using the whisper.cpp project in other node projects.
This addon now supports **Voice Activity Detection (VAD)** for improved transcription performance.
## Install
```shell
@ -28,88 +26,12 @@ For Electron addon and cmake-js options, you can see [cmake-js](https://github.c
## Run
### Basic Usage
```shell
cd examples/addon.node
node index.js --language='language' --model='model-path' --fname_inp='file-path'
```
### VAD (Voice Activity Detection) Usage
Because this is a simple Demo, only the above parameters are set in the node environment.
Run the VAD example with performance comparison:
```shell
node vad-example.js
```
## Voice Activity Detection (VAD) Support
VAD can significantly improve transcription performance by only processing speech segments, which is especially beneficial for audio files with long periods of silence.
### VAD Model Setup
Before using VAD, download a VAD model:
```shell
# From the whisper.cpp root directory
./models/download-vad-model.sh silero-v5.1.2
```
### VAD Parameters
All VAD parameters are optional and have sensible defaults:
- `vad`: Enable VAD (default: false)
- `vad_model`: Path to VAD model file (required when VAD enabled)
- `vad_threshold`: Speech detection threshold 0.0-1.0 (default: 0.5)
- `vad_min_speech_duration_ms`: Min speech duration in ms (default: 250)
- `vad_min_silence_duration_ms`: Min silence duration in ms (default: 100)
- `vad_max_speech_duration_s`: Max speech duration in seconds (default: FLT_MAX)
- `vad_speech_pad_ms`: Speech padding in ms (default: 30)
- `vad_samples_overlap`: Sample overlap 0.0-1.0 (default: 0.1)
### JavaScript API Example
```javascript
const path = require("path");
const { whisper } = require(path.join(__dirname, "../../build/Release/addon.node"));
const { promisify } = require("util");
const whisperAsync = promisify(whisper);
// With VAD enabled
const vadParams = {
language: "en",
model: path.join(__dirname, "../../models/ggml-base.en.bin"),
fname_inp: path.join(__dirname, "../../samples/jfk.wav"),
vad: true,
vad_model: path.join(__dirname, "../../models/ggml-silero-v5.1.2.bin"),
vad_threshold: 0.5,
progress_callback: (progress) => console.log(`Progress: ${progress}%`)
};
whisperAsync(vadParams).then(result => console.log(result));
```
## Supported Parameters
Both traditional whisper.cpp parameters and new VAD parameters are supported:
- `language`: Language code (e.g., "en", "es", "fr")
- `model`: Path to whisper model file
- `fname_inp`: Path to input audio file
- `use_gpu`: Enable GPU acceleration (default: true)
- `flash_attn`: Enable flash attention (default: false)
- `no_prints`: Disable console output (default: false)
- `no_timestamps`: Disable timestamps (default: false)
- `detect_language`: Auto-detect language (default: false)
- `audio_ctx`: Audio context size (default: 0)
- `max_len`: Maximum segment length (default: 0)
- `max_context`: Maximum context size (default: -1)
- `prompt`: Initial prompt for decoder
- `comma_in_time`: Use comma in timestamps (default: true)
- `print_progress`: Print progress info (default: false)
- `progress_callback`: Progress callback function
- VAD parameters (see above section)
Other parameters can also be specified in the node environment.

View File

@ -1,133 +1,37 @@
const { join } = require('path');
const { whisper } = require('../../../build/Release/addon.node');
const { promisify } = require('util');
const path = require("path");
const { whisper } = require(path.join(
__dirname,
"../../../build/Release/addon.node"
));
const { promisify } = require("util");
const whisperAsync = promisify(whisper);
const commonParams = {
language: 'en',
model: join(__dirname, '../../../models/ggml-base.en.bin'),
fname_inp: join(__dirname, '../../../samples/jfk.wav'),
const whisperParamsMock = {
language: "en",
model: path.join(__dirname, "../../../models/ggml-base.en.bin"),
fname_inp: path.join(__dirname, "../../../samples/jfk.wav"),
use_gpu: true,
flash_attn: false,
no_prints: true,
comma_in_time: false,
translate: true,
no_timestamps: false,
detect_language: false,
audio_ctx: 0,
max_len: 0
max_len: 0,
prompt: "",
print_progress: false,
progress_callback: (progress) => {
console.log(`Progress: ${progress}`);
},
max_context: -1
};
describe('Whisper.cpp Node.js addon with VAD support', () => {
test('Basic whisper transcription without VAD', async () => {
const params = {
...commonParams,
vad: false
};
describe("Run whisper.node", () => {
test("it should receive a non-empty value", async () => {
let result = await whisperAsync(whisperParamsMock);
const result = await whisperAsync(params);
expect(typeof result).toBe('object');
expect(Array.isArray(result.transcription)).toBe(true);
expect(result.transcription.length).toBeGreaterThan(0);
// Check that we got some transcription text
const text = result.transcription.map(segment => segment[2]).join(' ');
expect(text.length).toBeGreaterThan(0);
expect(text.toLowerCase()).toContain('ask not');
}, 30000);
test('VAD parameters validation', async () => {
// Test with invalid VAD model - should return empty transcription
const invalidParams = {
...commonParams,
vad: true,
vad_model: 'non-existent-model.bin',
vad_threshold: 0.5
};
// This should handle the error gracefully and return empty transcription
const result = await whisperAsync(invalidParams);
expect(typeof result).toBe('object');
expect(Array.isArray(result.transcription)).toBe(true);
// When VAD model doesn't exist, it should return empty transcription
expect(result.transcription.length).toBe(0);
}, 10000);
test('VAD parameter parsing', async () => {
// Test that VAD parameters are properly parsed (even if VAD model doesn't exist)
const vadParams = {
...commonParams,
vad: false, // Disabled so no model required
vad_threshold: 0.7,
vad_min_speech_duration_ms: 300,
vad_min_silence_duration_ms: 150,
vad_max_speech_duration_s: 45.0,
vad_speech_pad_ms: 50,
vad_samples_overlap: 0.15
};
const result = await whisperAsync(vadParams);
expect(typeof result).toBe('object');
expect(Array.isArray(result.transcription)).toBe(true);
}, 30000);
test('Progress callback with VAD disabled', async () => {
let progressCalled = false;
let lastProgress = 0;
const params = {
...commonParams,
vad: false,
progress_callback: (progress) => {
progressCalled = true;
lastProgress = progress;
expect(progress).toBeGreaterThanOrEqual(0);
expect(progress).toBeLessThanOrEqual(100);
}
};
const result = await whisperAsync(params);
expect(progressCalled).toBe(true);
expect(lastProgress).toBe(100);
expect(typeof result).toBe('object');
}, 30000);
test('Language detection without VAD', async () => {
const params = {
...commonParams,
vad: false,
detect_language: true,
language: 'auto'
};
const result = await whisperAsync(params);
expect(typeof result).toBe('object');
expect(typeof result.language).toBe('string');
expect(result.language.length).toBeGreaterThan(0);
}, 30000);
test('Basic transcription with all VAD parameters set', async () => {
// Test with VAD disabled but all parameters set to ensure no crashes
const params = {
...commonParams,
vad: false, // Disabled so it works without VAD model
vad_model: '', // Empty model path
vad_threshold: 0.6,
vad_min_speech_duration_ms: 200,
vad_min_silence_duration_ms: 80,
vad_max_speech_duration_s: 25.0,
vad_speech_pad_ms: 40,
vad_samples_overlap: 0.08
};
const result = await whisperAsync(params);
expect(typeof result).toBe('object');
expect(Array.isArray(result.transcription)).toBe(true);
expect(result.transcription.length).toBeGreaterThan(0);
}, 30000);
expect(result.length).toBeGreaterThan(0);
}, 10000);
});

View File

@ -9,7 +9,6 @@
#include <vector>
#include <cmath>
#include <cstdint>
#include <cfloat>
struct whisper_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
@ -39,7 +38,6 @@ struct whisper_params {
bool print_progress = false;
bool no_timestamps = false;
bool no_prints = false;
bool detect_language= false;
bool use_gpu = true;
bool flash_attn = false;
bool comma_in_time = true;
@ -52,16 +50,6 @@ struct whisper_params {
std::vector<std::string> fname_out = {};
std::vector<float> pcmf32 = {}; // mono-channel F32 PCM
// Voice Activity Detection (VAD) parameters
bool vad = false;
std::string vad_model = "";
float vad_threshold = 0.5f;
int vad_min_speech_duration_ms = 250;
int vad_min_silence_duration_ms = 100;
float vad_max_speech_duration_s = FLT_MAX;
int vad_speech_pad_ms = 30;
float vad_samples_overlap = 0.1f;
};
struct whisper_print_user_data {
@ -94,7 +82,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
t1 = whisper_full_get_segment_t1(ctx, i);
}
if (!params.no_timestamps && !params.no_prints) {
if (!params.no_timestamps) {
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
}
@ -125,14 +113,12 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
// colorful print bug
//
if (!params.no_prints) {
const char * text = whisper_full_get_segment_text(ctx, i);
printf("%s%s", speaker.c_str(), text);
}
const char * text = whisper_full_get_segment_text(ctx, i);
printf("%s%s", speaker.c_str(), text);
// with timestamps or speakers: each segment on new line
if ((!params.no_timestamps || params.diarize) && !params.no_prints) {
if (!params.no_timestamps || params.diarize) {
printf("\n");
}
@ -142,11 +128,6 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
void cb_log_disable(enum ggml_log_level, const char *, void *) {}
struct whisper_result {
std::vector<std::vector<std::string>> segments;
std::string language;
};
class ProgressWorker : public Napi::AsyncWorker {
public:
ProgressWorker(Napi::Function& callback, whisper_params params, Napi::Function progress_callback, Napi::Env env)
@ -177,27 +158,15 @@ class ProgressWorker : public Napi::AsyncWorker {
void OnOK() override {
Napi::HandleScope scope(Env());
if (params.detect_language) {
Napi::Object resultObj = Napi::Object::New(Env());
resultObj.Set("language", Napi::String::New(Env(), result.language));
Callback().Call({Env().Null(), resultObj});
}
Napi::Object returnObj = Napi::Object::New(Env());
if (!result.language.empty()) {
returnObj.Set("language", Napi::String::New(Env(), result.language));
}
Napi::Array transcriptionArray = Napi::Array::New(Env(), result.segments.size());
for (uint64_t i = 0; i < result.segments.size(); ++i) {
Napi::Object res = Napi::Array::New(Env(), result.size());
for (uint64_t i = 0; i < result.size(); ++i) {
Napi::Object tmp = Napi::Array::New(Env(), 3);
for (uint64_t j = 0; j < 3; ++j) {
tmp[j] = Napi::String::New(Env(), result.segments[i][j]);
tmp[j] = Napi::String::New(Env(), result[i][j]);
}
transcriptionArray[i] = tmp;
}
returnObj.Set("transcription", transcriptionArray);
Callback().Call({Env().Null(), returnObj});
res[i] = tmp;
}
Callback().Call({Env().Null(), res});
}
// Progress callback function - using thread-safe function
@ -214,12 +183,12 @@ class ProgressWorker : public Napi::AsyncWorker {
private:
whisper_params params;
whisper_result result;
std::vector<std::vector<std::string>> result;
Napi::Env env;
Napi::ThreadSafeFunction tsfn;
// Custom run function with progress callback support
int run_with_progress(whisper_params &params, whisper_result & result) {
int run_with_progress(whisper_params &params, std::vector<std::vector<std::string>> &result) {
if (params.no_prints) {
whisper_log_set(cb_log_disable, NULL);
}
@ -308,8 +277,7 @@ class ProgressWorker : public Napi::AsyncWorker {
wparams.print_timestamps = !params.no_timestamps;
wparams.print_special = params.print_special;
wparams.translate = params.translate;
wparams.language = params.detect_language ? "auto" : params.language.c_str();
wparams.detect_language = params.detect_language;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
wparams.offset_ms = params.offset_t_ms;
@ -344,38 +312,34 @@ class ProgressWorker : public Napi::AsyncWorker {
};
wparams.progress_callback_user_data = this;
// Set VAD parameters
wparams.vad = params.vad;
wparams.vad_model_path = params.vad_model.c_str();
// Abort mechanism example
{
static bool is_aborted = false; // Note: this should be atomic to avoid data races
wparams.vad_params.threshold = params.vad_threshold;
wparams.vad_params.min_speech_duration_ms = params.vad_min_speech_duration_ms;
wparams.vad_params.min_silence_duration_ms = params.vad_min_silence_duration_ms;
wparams.vad_params.max_speech_duration_s = params.vad_max_speech_duration_s;
wparams.vad_params.speech_pad_ms = params.vad_speech_pad_ms;
wparams.vad_params.samples_overlap = params.vad_samples_overlap;
wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
bool is_aborted = *(bool*)user_data;
return !is_aborted;
};
wparams.encoder_begin_callback_user_data = &is_aborted;
}
if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
fprintf(stderr, "failed to process audio\n");
return 10;
}
}
}
}
if (params.detect_language || params.language == "auto") {
result.language = whisper_lang_str(whisper_full_lang_id(ctx));
}
const int n_segments = whisper_full_n_segments(ctx);
result.segments.resize(n_segments);
result.resize(n_segments);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
result.segments[i].emplace_back(to_timestamp(t0, params.comma_in_time));
result.segments[i].emplace_back(to_timestamp(t1, params.comma_in_time));
result.segments[i].emplace_back(text);
result[i].emplace_back(to_timestamp(t0, params.comma_in_time));
result[i].emplace_back(to_timestamp(t1, params.comma_in_time));
result[i].emplace_back(text);
}
whisper_print_timings(ctx);
@ -396,46 +360,13 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
std::string language = whisper_params.Get("language").As<Napi::String>();
std::string model = whisper_params.Get("model").As<Napi::String>();
std::string input = whisper_params.Get("fname_inp").As<Napi::String>();
bool use_gpu = true;
if (whisper_params.Has("use_gpu") && whisper_params.Get("use_gpu").IsBoolean()) {
use_gpu = whisper_params.Get("use_gpu").As<Napi::Boolean>();
}
bool flash_attn = false;
if (whisper_params.Has("flash_attn") && whisper_params.Get("flash_attn").IsBoolean()) {
flash_attn = whisper_params.Get("flash_attn").As<Napi::Boolean>();
}
bool no_prints = false;
if (whisper_params.Has("no_prints") && whisper_params.Get("no_prints").IsBoolean()) {
no_prints = whisper_params.Get("no_prints").As<Napi::Boolean>();
}
bool no_timestamps = false;
if (whisper_params.Has("no_timestamps") && whisper_params.Get("no_timestamps").IsBoolean()) {
no_timestamps = whisper_params.Get("no_timestamps").As<Napi::Boolean>();
}
bool detect_language = false;
if (whisper_params.Has("detect_language") && whisper_params.Get("detect_language").IsBoolean()) {
detect_language = whisper_params.Get("detect_language").As<Napi::Boolean>();
}
int32_t audio_ctx = 0;
if (whisper_params.Has("audio_ctx") && whisper_params.Get("audio_ctx").IsNumber()) {
audio_ctx = whisper_params.Get("audio_ctx").As<Napi::Number>();
}
bool comma_in_time = true;
if (whisper_params.Has("comma_in_time") && whisper_params.Get("comma_in_time").IsBoolean()) {
comma_in_time = whisper_params.Get("comma_in_time").As<Napi::Boolean>();
}
int32_t max_len = 0;
if (whisper_params.Has("max_len") && whisper_params.Get("max_len").IsNumber()) {
max_len = whisper_params.Get("max_len").As<Napi::Number>();
}
bool use_gpu = whisper_params.Get("use_gpu").As<Napi::Boolean>();
bool flash_attn = whisper_params.Get("flash_attn").As<Napi::Boolean>();
bool no_prints = whisper_params.Get("no_prints").As<Napi::Boolean>();
bool no_timestamps = whisper_params.Get("no_timestamps").As<Napi::Boolean>();
int32_t audio_ctx = whisper_params.Get("audio_ctx").As<Napi::Number>();
bool comma_in_time = whisper_params.Get("comma_in_time").As<Napi::Boolean>();
int32_t max_len = whisper_params.Get("max_len").As<Napi::Number>();
// Add support for max_context
int32_t max_context = -1;
@ -451,7 +382,7 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
// Add support for print_progress
bool print_progress = false;
if (whisper_params.Has("print_progress") && whisper_params.Get("print_progress").IsBoolean()) {
if (whisper_params.Has("print_progress")) {
print_progress = whisper_params.Get("print_progress").As<Napi::Boolean>();
}
// Add support for progress_callback
@ -460,47 +391,6 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
progress_callback = whisper_params.Get("progress_callback").As<Napi::Function>();
}
// Add support for VAD parameters
bool vad = false;
if (whisper_params.Has("vad") && whisper_params.Get("vad").IsBoolean()) {
vad = whisper_params.Get("vad").As<Napi::Boolean>();
}
std::string vad_model = "";
if (whisper_params.Has("vad_model") && whisper_params.Get("vad_model").IsString()) {
vad_model = whisper_params.Get("vad_model").As<Napi::String>();
}
float vad_threshold = 0.5f;
if (whisper_params.Has("vad_threshold") && whisper_params.Get("vad_threshold").IsNumber()) {
vad_threshold = whisper_params.Get("vad_threshold").As<Napi::Number>();
}
int vad_min_speech_duration_ms = 250;
if (whisper_params.Has("vad_min_speech_duration_ms") && whisper_params.Get("vad_min_speech_duration_ms").IsNumber()) {
vad_min_speech_duration_ms = whisper_params.Get("vad_min_speech_duration_ms").As<Napi::Number>();
}
int vad_min_silence_duration_ms = 100;
if (whisper_params.Has("vad_min_silence_duration_ms") && whisper_params.Get("vad_min_silence_duration_ms").IsNumber()) {
vad_min_silence_duration_ms = whisper_params.Get("vad_min_silence_duration_ms").As<Napi::Number>();
}
float vad_max_speech_duration_s = FLT_MAX;
if (whisper_params.Has("vad_max_speech_duration_s") && whisper_params.Get("vad_max_speech_duration_s").IsNumber()) {
vad_max_speech_duration_s = whisper_params.Get("vad_max_speech_duration_s").As<Napi::Number>();
}
int vad_speech_pad_ms = 30;
if (whisper_params.Has("vad_speech_pad_ms") && whisper_params.Get("vad_speech_pad_ms").IsNumber()) {
vad_speech_pad_ms = whisper_params.Get("vad_speech_pad_ms").As<Napi::Number>();
}
float vad_samples_overlap = 0.1f;
if (whisper_params.Has("vad_samples_overlap") && whisper_params.Get("vad_samples_overlap").IsNumber()) {
vad_samples_overlap = whisper_params.Get("vad_samples_overlap").As<Napi::Number>();
}
Napi::Value pcmf32Value = whisper_params.Get("pcmf32");
std::vector<float> pcmf32_vec;
if (pcmf32Value.IsTypedArray()) {
@ -526,17 +416,6 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
params.max_context = max_context;
params.print_progress = print_progress;
params.prompt = prompt;
params.detect_language = detect_language;
// Set VAD parameters
params.vad = vad;
params.vad_model = vad_model;
params.vad_threshold = vad_threshold;
params.vad_min_speech_duration_ms = vad_min_speech_duration_ms;
params.vad_min_silence_duration_ms = vad_min_silence_duration_ms;
params.vad_max_speech_duration_s = vad_max_speech_duration_s;
params.vad_speech_pad_ms = vad_speech_pad_ms;
params.vad_samples_overlap = vad_samples_overlap;
Napi::Function callback = info[1].As<Napi::Function>();
// Create a new Worker class with progress callback support

View File

@ -17,7 +17,6 @@ const whisperParams = {
comma_in_time: false,
translate: true,
no_timestamps: false,
detect_language: false,
audio_ctx: 0,
max_len: 0,
progress_callback: (progress) => {
@ -32,8 +31,6 @@ const params = Object.fromEntries(
const [key, value] = item.slice(2).split("=");
if (key === "audio_ctx") {
whisperParams[key] = parseInt(value);
} else if (key === "detect_language") {
whisperParams[key] = value === "true";
} else {
whisperParams[key] = value;
}

View File

@ -1,132 +0,0 @@
const path = require("path");
const { whisper } = require(path.join(
__dirname,
"../../build/Release/addon.node"
));
const { promisify } = require("util");
const whisperAsync = promisify(whisper);
// Example with VAD enabled
const vadParams = {
language: "en",
model: path.join(__dirname, "../../models/ggml-base.en.bin"),
fname_inp: path.join(__dirname, "../../samples/jfk.wav"),
use_gpu: true,
flash_attn: false,
no_prints: false,
comma_in_time: true,
translate: false,
no_timestamps: false,
detect_language: false,
audio_ctx: 0,
max_len: 0,
// VAD parameters
vad: true,
vad_model: path.join(__dirname, "../../models/ggml-silero-v5.1.2.bin"), // You need to download this model
vad_threshold: 0.5,
vad_min_speech_duration_ms: 250,
vad_min_silence_duration_ms: 100,
vad_max_speech_duration_s: 30.0,
vad_speech_pad_ms: 30,
vad_samples_overlap: 0.1,
progress_callback: (progress) => {
console.log(`VAD Transcription progress: ${progress}%`);
}
};
// Example without VAD (traditional approach)
const traditionalParams = {
language: "en",
model: path.join(__dirname, "../../models/ggml-base.en.bin"),
fname_inp: path.join(__dirname, "../../samples/jfk.wav"),
use_gpu: true,
flash_attn: false,
no_prints: false,
comma_in_time: true,
translate: false,
no_timestamps: false,
detect_language: false,
audio_ctx: 0,
max_len: 0,
vad: false, // Explicitly disable VAD
progress_callback: (progress) => {
console.log(`Traditional transcription progress: ${progress}%`);
}
};
async function runVADExample() {
try {
console.log("=== Whisper.cpp Node.js VAD Example ===\n");
// Check if VAD model exists
const fs = require('fs');
if (!fs.existsSync(vadParams.vad_model)) {
console.log("⚠️ VAD model not found. Please download the VAD model first:");
console.log(" ./models/download-vad-model.sh silero-v5.1.2");
console.log(" Or run: python models/convert-silero-vad-to-ggml.py");
console.log("\n Falling back to traditional transcription without VAD...\n");
// Run without VAD
console.log("🎵 Running traditional transcription...");
const traditionalResult = await whisperAsync(traditionalParams);
console.log("\n📝 Traditional transcription result:");
console.log(traditionalResult);
return;
}
console.log("🎵 Running transcription with VAD enabled...");
console.log("VAD Parameters:");
console.log(` - Threshold: ${vadParams.vad_threshold}`);
console.log(` - Min speech duration: ${vadParams.vad_min_speech_duration_ms}ms`);
console.log(` - Min silence duration: ${vadParams.vad_min_silence_duration_ms}ms`);
console.log(` - Max speech duration: ${vadParams.vad_max_speech_duration_s}s`);
console.log(` - Speech padding: ${vadParams.vad_speech_pad_ms}ms`);
console.log(` - Samples overlap: ${vadParams.vad_samples_overlap}\n`);
const startTime = Date.now();
const vadResult = await whisperAsync(vadParams);
const vadDuration = Date.now() - startTime;
console.log("\n✅ VAD transcription completed!");
console.log(`⏱️ Processing time: ${vadDuration}ms`);
console.log("\n📝 VAD transcription result:");
console.log(vadResult);
// Compare with traditional approach
console.log("\n🔄 Running traditional transcription for comparison...");
const traditionalStartTime = Date.now();
const traditionalResult = await whisperAsync(traditionalParams);
const traditionalDuration = Date.now() - traditionalStartTime;
console.log("\n✅ Traditional transcription completed!");
console.log(`⏱️ Processing time: ${traditionalDuration}ms`);
console.log("\n📝 Traditional transcription result:");
console.log(traditionalResult);
// Performance comparison
console.log("\n📊 Performance Comparison:");
console.log(`VAD: ${vadDuration}ms`);
console.log(`Traditional: ${traditionalDuration}ms`);
const speedup = traditionalDuration / vadDuration;
if (speedup > 1) {
console.log(`🚀 VAD is ${speedup.toFixed(2)}x faster!`);
} else {
console.log(` Traditional approach was ${(1/speedup).toFixed(2)}x faster in this case.`);
}
} catch (error) {
console.error("❌ Error during transcription:", error);
}
}
// Run the example
if (require.main === module) {
runVADExample();
}
module.exports = {
runVADExample,
vadParams,
traditionalParams
};

View File

@ -2,7 +2,7 @@
Benchmark the performance of whisper.cpp in the browser using WebAssembly
Link: https://ggml.ai/whisper.cpp/bench.wasm/
Link: https://ggerganov.github.io/whisper.cpp/bench.wasm
Terminal version: [examples/bench](/examples/bench)
@ -32,16 +32,6 @@ cp bin/libbench.js /path/to/html/
cp bin/libbench.worker.js /path/to/html/
```
> 📝 **Note:** By default this example is built with `WHISPER_WASM_SINGLE_FILE=ON`
> which means that that a separate .wasm file will not be generated. Instead, the
> WASM module is embedded in the main JS file as a base64 encoded string. To
> generate a separate .wasm file, you need to disable this option by passing
> `-DWHISPER_WASM_SINGLE_FILE=OFF`:
> ```console
> emcmake cmake .. -DWHISPER_WASM_SINGLE_FILE=OFF
> ```
> This will generate a `libbench.wasm` file in the build/bin directory.
> 📝 **Note:** As of Emscripten 3.1.58 (April 2024), separate worker.js files are no
> longer generated and the worker is embedded in the main JS file. So the worker
> file will not be geneated for versions later than `3.1.58`.

View File

@ -191,15 +191,15 @@
function loadWhisper(model) {
let urls = {
'tiny.en': 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.en.bin',
'base.en': 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin',
'small.en': 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-small.en.bin',
'tiny.en': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin',
'base.en': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en.bin',
'small.en': 'https://whisper.ggerganov.com/ggml-model-whisper-small.en.bin',
'tiny-en-q5_1': 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.en-q5_1.bin',
'base-en-q5_1': 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en-q5_1.bin',
'small-en-q5_1': 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-small.en-q5_1.bin',
'medium-en-q5_0':'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-medium.en-q5_0.bin',
'large-q5_0': 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-q5_0.bin',
'tiny-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en-q5_1.bin',
'base-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en-q5_1.bin',
'small-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-small.en-q5_1.bin',
'medium-en-q5_0':'https://whisper.ggerganov.com/ggml-model-whisper-medium.en-q5_0.bin',
'large-q5_0': 'https://whisper.ggerganov.com/ggml-model-whisper-large-q5_0.bin',
};
let sizes = {

View File

@ -66,12 +66,13 @@ static int whisper_bench_full(const whisper_params & params) {
cparams.use_gpu = params.use_gpu;
cparams.flash_attn = params.flash_attn;
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
{
fprintf(stderr, "\n");
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", params.n_threads, std::thread::hardware_concurrency(), whisper_print_system_info());
}
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
if (ctx == nullptr) {
fprintf(stderr, "error: failed to initialize whisper context\n");
return 2;
@ -155,8 +156,6 @@ static int whisper_bench_full(const whisper_params & params) {
}
int main(int argc, char ** argv) {
ggml_backend_load_all();
whisper_params params;
if (whisper_params_parse(argc, argv, params) == false) {

View File

@ -70,7 +70,6 @@ struct whisper_params {
bool no_prints = false;
bool print_special = false;
bool print_colors = false;
bool print_confidence= false;
bool print_progress = false;
bool no_timestamps = false;
bool log_score = false;
@ -180,7 +179,6 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
else if (arg == "-np" || arg == "--no-prints") { params.no_prints = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
else if ( arg == "--print-confidence"){ params.print_confidence= true; }
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
else if (arg == "-l" || arg == "--language") { params.language = whisper_param_turn_lowercase(ARGV_NEXT); }
@ -202,7 +200,7 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
else if ( arg == "--vad") { params.vad = true; }
else if (arg == "-vm" || arg == "--vad-model") { params.vad_model = ARGV_NEXT; }
else if (arg == "-vt" || arg == "--vad-threshold") { params.vad_threshold = std::stof(ARGV_NEXT); }
else if (arg == "-vspd" || arg == "--vad-min-speech-duration-ms") { params.vad_min_speech_duration_ms = std::stoi(ARGV_NEXT); }
else if (arg == "-vsd" || arg == "--vad-min-speech-duration-ms") { params.vad_min_speech_duration_ms = std::stoi(ARGV_NEXT); }
else if (arg == "-vsd" || arg == "--vad-min-silence-duration-ms") { params.vad_min_speech_duration_ms = std::stoi(ARGV_NEXT); }
else if (arg == "-vmsd" || arg == "--vad-max-speech-duration-s") { params.vad_max_speech_duration_s = std::stof(ARGV_NEXT); }
else if (arg == "-vp" || arg == "--vad-speech-pad-ms") { params.vad_speech_pad_ms = std::stoi(ARGV_NEXT); }
@ -259,7 +257,6 @@ static void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params
fprintf(stderr, " -np, --no-prints [%-7s] do not print anything other than the results\n", params.no_prints ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
fprintf(stderr, " --print-confidence [%-7s] print confidence\n", params.print_confidence ? "true" : "false");
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
@ -389,26 +386,6 @@ static void whisper_print_segment_callback(struct whisper_context * ctx, struct
printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
}
} else if (params.print_confidence) {
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
if (params.print_special == false) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
if (id >= whisper_token_eot(ctx)) {
continue;
}
}
const char * text = whisper_full_get_token_text(ctx, i, j);
const float p = whisper_full_get_token_p (ctx, i, j);
int style_idx = 2; // High confidence - dim
if (p < 0.33) {
style_idx = 0; // Low confidence - inverse (highlighted)
} else if (p < 0.66) {
style_idx = 1; // Medium confidence - underlined
}
printf("%s%s%s%s", speaker.c_str(), k_styles[style_idx].c_str(), text, "\033[0m");
}
} else {
const char * text = whisper_full_get_segment_text(ctx, i);
@ -909,8 +886,6 @@ static void output_lrc(struct whisper_context * ctx, std::ofstream & fout, const
static void cb_log_disable(enum ggml_log_level , const char * , void * ) { }
int main(int argc, char ** argv) {
ggml_backend_load_all();
#if defined(_WIN32)
// Set the console output code page to UTF-8, while command line arguments
// are still encoded in the system's code page. In this way, we can print
@ -990,6 +965,7 @@ int main(int argc, char ** argv) {
}
// whisper init
struct whisper_context_params cparams = whisper_context_default_params();
cparams.use_gpu = params.use_gpu;
@ -1139,8 +1115,6 @@ int main(int argc, char ** argv) {
if (params.print_colors) {
fprintf(stderr, "%s: color scheme: red (low confidence), yellow (medium), green (high confidence)\n", __func__);
} else if (params.print_confidence) {
fprintf(stderr, "%s: confidence: highlighted (low confidence), underlined (medium), dim (high confidence)\n", __func__);
}
fprintf(stderr, "\n");
}

View File

@ -3,7 +3,7 @@
This is a basic Voice Assistant example that accepts voice commands from the microphone.
It runs in fully in the browser via WebAseembly.
Online demo: https://ggml.ai/whisper.cpp/command.wasm/
Online demo: https://ggerganov.github.io/whisper.cpp/command.wasm
Terminal version: [examples/command](/examples/command)
@ -32,16 +32,6 @@ cp bin/libcommand.js /path/to/html/
cp bin/libcommand.worker.js /path/to/html/
```
> 📝 **Note:** By default this example is built with `WHISPER_WASM_SINGLE_FILE=ON`
> which means that that a separate .wasm file will not be generated. Instead, the
> WASM module is embedded in the main JS file as a base64 encoded string. To
> generate a separate .wasm file, you need to disable this option by passing
> `-DWHISPER_WASM_SINGLE_FILE=OFF`:
> ```console
> emcmake cmake .. -DWHISPER_WASM_SINGLE_FILE=OFF
> ```
> This will generate a `libcommand.wasm` file in the build/bin directory.
> 📝 **Note:** As of Emscripten 3.1.58 (April 2024), separate worker.js files are no
> longer generated and the worker is embedded in the main JS file. So the worker
> file will not be geneated for versions later than `3.1.58`.

View File

@ -174,11 +174,11 @@
function loadWhisper(model) {
let urls = {
'tiny.en': 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.en.bin',
'base.en': 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin',
'tiny.en': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin',
'base.en': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en.bin',
'tiny-en-q5_1': 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.en-q5_1.bin',
'base-en-q5_1': 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en-q5_1.bin',
'tiny-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en-q5_1.bin',
'base-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en-q5_1.bin',
};
let sizes = {

View File

@ -251,7 +251,7 @@ static std::vector<std::string> get_words(const std::string &txt) {
// command-list mode
// guide the transcription to match the most likely command from a provided list
static int process_command_list(struct whisper_context * ctx, audio_async &audio, const whisper_params &params, std::ofstream &fout) {
static int process_command_list(struct whisper_context * ctx, audio_async &audio, const whisper_params &params) {
fprintf(stderr, "\n");
fprintf(stderr, "%s: guided mode\n", __func__);
@ -444,16 +444,12 @@ static int process_command_list(struct whisper_context * ctx, audio_async &audio
const float prob = probs_id[0].first;
const int index = probs_id[0].second;
const char * best_command = allowed_commands[index].c_str();
fprintf(stdout, "\n");
fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__,
"\033[1m", best_command, "\033[0m", prob,
"\033[1m", allowed_commands[index].c_str(), "\033[0m", prob,
(int) std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count());
fprintf(stdout, "\n");
if (fout.is_open()) {
fout << best_command << std::endl;
}
}
}
@ -466,7 +462,7 @@ static int process_command_list(struct whisper_context * ctx, audio_async &audio
// always-prompt mode
// transcribe the voice into text after valid prompt
static int always_prompt_transcription(struct whisper_context * ctx, audio_async & audio, const whisper_params & params, std::ofstream & fout) {
static int always_prompt_transcription(struct whisper_context * ctx, audio_async & audio, const whisper_params & params) {
bool is_running = true;
bool ask_prompt = true;
@ -532,9 +528,6 @@ static int always_prompt_transcription(struct whisper_context * ctx, audio_async
if ((sim > 0.7f) && (command.size() > 0)) {
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
if (fout.is_open()) {
fout << command << std::endl;
}
}
fprintf(stdout, "\n");
@ -549,7 +542,7 @@ static int always_prompt_transcription(struct whisper_context * ctx, audio_async
// general-purpose mode
// freely transcribe the voice into text
static int process_general_transcription(struct whisper_context * ctx, audio_async & audio, const whisper_params & params, std::ofstream & fout) {
static int process_general_transcription(struct whisper_context * ctx, audio_async & audio, const whisper_params & params) {
bool is_running = true;
bool have_prompt = false;
bool ask_prompt = true;
@ -669,10 +662,8 @@ static int process_general_transcription(struct whisper_context * ctx, audio_asy
} else {
// cut the prompt from the decoded text
const std::string command = ::trim(txt.substr(best_len));
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
if (fout.is_open()) {
fout << command << std::endl;
}
}
fprintf(stdout, "\n");
@ -687,8 +678,6 @@ static int process_general_transcription(struct whisper_context * ctx, audio_asy
}
int main(int argc, char ** argv) {
ggml_backend_load_all();
whisper_params params;
if (whisper_params_parse(argc, argv, params) == false) {
@ -709,10 +698,6 @@ int main(int argc, char ** argv) {
cparams.flash_attn = params.flash_attn;
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
if (ctx == nullptr) {
fprintf(stderr, "error: failed to initialize whisper context\n");
return 2;
}
// print some info about the processing
{
@ -772,22 +757,13 @@ int main(int argc, char ** argv) {
}
}
std::ofstream fout;
if (params.fname_out.length() > 0) {
fout.open(params.fname_out);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open output file '%s'!\n", __func__, params.fname_out.c_str());
return 1;
}
}
if (ret_val == 0) {
if (!params.commands.empty()) {
ret_val = process_command_list(ctx, audio, params, fout);
ret_val = process_command_list(ctx, audio, params);
} else if (!params.prompt.empty() && params.grammar_parsed.rules.empty()) {
ret_val = always_prompt_transcription(ctx, audio, params, fout);
ret_val = always_prompt_transcription(ctx, audio, params);
} else {
ret_val = process_general_transcription(ctx, audio, params, fout);
ret_val = process_general_transcription(ctx, audio, params);
}
}

View File

@ -112,20 +112,13 @@ bool read_audio_data(const std::string & fname, std::vector<float>& pcmf32, std:
}
if (stereo) {
std::vector<float> stereo_data = pcmf32;
pcmf32.resize(frame_count);
for (uint64_t i = 0; i < frame_count; i++) {
pcmf32[i] = (stereo_data[2*i] + stereo_data[2*i + 1]);
}
pcmf32s.resize(2);
pcmf32s[0].resize(frame_count);
pcmf32s[1].resize(frame_count);
for (uint64_t i = 0; i < frame_count; i++) {
pcmf32s[0][i] = stereo_data[2*i];
pcmf32s[1][i] = stereo_data[2*i + 1];
}
pcmf32s.resize(2);
pcmf32s[0].resize(frame_count);
pcmf32s[1].resize(frame_count);
for (uint64_t i = 0; i < frame_count; i++) {
pcmf32s[0][i] = pcmf32[2*i];
pcmf32s[1][i] = pcmf32[2*i + 1];
}
}
ma_decoder_uninit(&decoder);

View File

@ -294,26 +294,6 @@ const std::vector<std::string> k_colors = {
set_xterm256_foreground( 78, 178, 101),
};
// ANSI formatting codes
static std::string set_inverse() {
return "\033[7m";
}
static std::string set_underline() {
return "\033[4m";
}
static std::string set_dim() {
return "\033[2m";
}
// Style scheme for different confidence levels
const std::vector<std::string> k_styles = {
set_inverse(), // Low confidence - inverse (highlighted)
set_underline(), // Medium confidence - underlined
set_dim(), // High confidence - dim
};
//
// Other utils
//

View File

@ -424,8 +424,6 @@ static void process_loop(struct whisper_context * ctx, audio_async &audio, const
}
int main(int argc, char ** argv) {
ggml_backend_load_all();
whisper_params params;
if (whisper_params_parse(argc, argv, params) == false) {
return 1;

View File

@ -1,5 +1,4 @@
#include "ggml.h"
#include "ggml-backend.h"
#include "common.h"
#include "common-ggml.h"
@ -177,8 +176,6 @@ static bool whisper_model_quantize(const std::string & fname_inp, const std::str
}
int main(int argc, char ** argv) {
ggml_backend_load_all();
if (argc != 4) {
fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type\n", argv[0]);
ggml_print_ftypes(stderr);

View File

@ -1,6 +1,3 @@
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(TARGET whisper-server)
add_executable(${TARGET} server.cpp httplib.h)

View File

@ -23,7 +23,6 @@ options:
-sow, --split-on-word [false ] split on word rather than on token
-bo N, --best-of N [2 ] number of best candidates to keep
-bs N, --beam-size N [-1 ] beam size for beam search
-ac N, --audio-ctx N [0 ] audio context size (0 - all)
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
@ -42,28 +41,9 @@ options:
--prompt PROMPT [ ] initial prompt
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
-oved D, --ov-e-device DNAME [CPU ] the OpenVINO device used for encode inference
-dtw MODEL --dtw MODEL [ ] compute token-level timestamps
--host HOST, [127.0.0.1] Hostname/ip-adress for the server
--port PORT, [8080 ] Port number for the server
--public PATH, [examples/server/public] Path to the public folder
--request-path PATH, [ ] Request path for all requests
--inference-path PATH, [/inference] Inference path for all requests
--convert, [false ] Convert audio to WAV, requires ffmpeg on the server
-sns, --suppress-nst [false ] suppress non-speech tokens
-nth N, --no-speech-thold N [0.60 ] no speech threshold
-nc, --no-context [false ] do not use previous audio context
-ng, --no-gpu [false ] do not use gpu
-fa, --flash-attn [false ] flash attention
Voice Activity Detection (VAD) options:
--vad [false ] enable Voice Activity Detection (VAD)
-vm FNAME, --vad-model FNAME [ ] VAD model path
-vt N, --vad-threshold N [0.50 ] VAD threshold for speech recognition
-vspd N, --vad-min-speech-duration-ms N [250 ] VAD min speech duration (0.0-1.0)
-vsd N, --vad-min-silence-duration-ms N [100 ] VAD min silence duration (to split segments)
-vmsd N, --vad-max-speech-duration-s N [FLT_MAX] VAD max speech duration (auto-split longer)
-vp N, --vad-speech-pad-ms N [30 ] VAD speech padding (extend segments)
-vo N, --vad-samples-overlap N [0.10 ] VAD samples overlap (seconds between segments)
```
> [!WARNING]
@ -87,35 +67,3 @@ curl 127.0.0.1:8080/load \
-H "Content-Type: multipart/form-data" \
-F model="<path-to-model-file>"
```
## Load testing with k6
> **Note:** Install [k6](https://k6.io/docs/get-started/installation/) before running the benchmark script.
You can benchmark the Whisper server using the provided bench.js script with [k6](https://k6.io/). This script sends concurrent multipart requests to the /inference endpoint and is fully configurable via environment variables.
**Example usage:**
```
k6 run bench.js \
--env FILE_PATH=/absolute/path/to/samples/jfk.wav \
--env BASE_URL=http://127.0.0.1:8080 \
--env ENDPOINT=/inference \
--env CONCURRENCY=4 \
--env TEMPERATURE=0.0 \
--env TEMPERATURE_INC=0.2 \
--env RESPONSE_FORMAT=json
```
**Environment variables:**
- `FILE_PATH`: Path to the audio file to send (must be absolute or relative to the k6 working directory)
- `BASE_URL`: Server base URL (default: `http://127.0.0.1:8080`)
- `ENDPOINT`: API endpoint (default: `/inference`)
- `CONCURRENCY`: Number of concurrent requests (default: 4)
- `TEMPERATURE`: Decoding temperature (default: 0.0)
- `TEMPERATURE_INC`: Temperature increment (default: 0.2)
- `RESPONSE_FORMAT`: Response format (default: `json`)
**Note:**
- The server must be running and accessible at the specified `BASE_URL` and `ENDPOINT`.
- The script is located in the same directory as this README: `bench.js`.

View File

@ -1,29 +0,0 @@
import http from 'k6/http'
import { check } from 'k6'
export let options = {
vus: parseInt(__ENV.CONCURRENCY) || 4,
iterations: parseInt(__ENV.CONCURRENCY) || 4,
}
const filePath = __ENV.FILE_PATH
const baseURL = __ENV.BASE_URL || 'http://127.0.0.1:8080'
const endpoint = __ENV.ENDPOINT || '/inference'
const temperature = __ENV.TEMPERATURE || '0.0'
const temperatureInc = __ENV.TEMPERATURE_INC || '0.2'
const responseFormat = __ENV.RESPONSE_FORMAT || 'json'
// Read the file ONCE at init time
const fileBin = open(filePath, 'b')
export default function () {
const payload = {
file: http.file(fileBin, filePath),
temperature: temperature,
temperature_inc: temperatureInc,
response_format: responseFormat,
}
const res = http.post(`${baseURL}${endpoint}`, payload)
check(res, { 'status is 200': r => r.status === 200 })
}

View File

@ -5,7 +5,6 @@
#include "httplib.h"
#include "json.hpp"
#include <cfloat>
#include <chrono>
#include <cmath>
#include <cstdio>
@ -14,23 +13,10 @@
#include <string>
#include <thread>
#include <vector>
#include <memory>
#include <csignal>
#include <atomic>
#include <functional>
#include <cstdlib>
#if defined (_WIN32)
#include <windows.h>
#endif
using namespace httplib;
using json = nlohmann::ordered_json;
enum server_state {
SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet
SERVER_STATE_READY, // Server is ready and model is loaded
};
namespace {
// output formats
@ -40,20 +26,6 @@ const std::string srt_format = "srt";
const std::string vjson_format = "verbose_json";
const std::string vtt_format = "vtt";
std::function<void(int)> shutdown_handler;
std::atomic_flag is_terminating = ATOMIC_FLAG_INIT;
inline void signal_handler(int signal) {
if (is_terminating.test_and_set()) {
// in case it hangs, we can force terminate the server by hitting Ctrl+C twice
// this is for better developer experience, we can remove when the server is stable enough
fprintf(stderr, "Received second interrupt, terminating immediately.\n");
exit(1);
}
shutdown_handler(signal);
}
struct server_params
{
std::string hostname = "127.0.0.1";
@ -104,7 +76,6 @@ struct whisper_params {
bool flash_attn = false;
bool suppress_nst = false;
bool no_context = false;
bool no_language_probabilities = false;
std::string language = "en";
std::string prompt = "";
@ -119,16 +90,6 @@ struct whisper_params {
std::string openvino_encode_device = "CPU";
std::string dtw = "";
// Voice Activity Detection (VAD) parameters
bool vad = false;
std::string vad_model = "";
float vad_threshold = 0.5f;
int vad_min_speech_duration_ms = 250;
int vad_min_silence_duration_ms = 100;
float vad_max_speech_duration_s = FLT_MAX;
int vad_speech_pad_ms = 30;
float vad_samples_overlap = 0.1f;
};
void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params, const server_params& sparams) {
@ -178,20 +139,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -nth N, --no-speech-thold N [%-7.2f] no speech threshold\n", params.no_speech_thold);
fprintf(stderr, " -nc, --no-context [%-7s] do not use previous audio context\n", params.no_context ? "true" : "false");
fprintf(stderr, " -ng, --no-gpu [%-7s] do not use gpu\n", params.use_gpu ? "false" : "true");
fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false");
fprintf(stderr, " -nlp, --no-language-probabilities [%-7s] exclude language probabilities from verbose_json output\n", params.no_language_probabilities ? "true" : "false");
// Voice Activity Detection (VAD) parameters
fprintf(stderr, "\nVoice Activity Detection (VAD) options:\n");
fprintf(stderr, " --vad [%-7s] enable Voice Activity Detection (VAD)\n", params.vad ? "true" : "false");
fprintf(stderr, " -vm FNAME, --vad-model FNAME [%-7s] VAD model path\n", params.vad_model.c_str());
fprintf(stderr, " -vt N, --vad-threshold N [%-7.2f] VAD threshold for speech recognition\n", params.vad_threshold);
fprintf(stderr, " -vspd N, --vad-min-speech-duration-ms N [%-7d] VAD min speech duration (0.0-1.0)\n", params.vad_min_speech_duration_ms);
fprintf(stderr, " -vsd N, --vad-min-silence-duration-ms N [%-7d] VAD min silence duration (to split segments)\n", params.vad_min_silence_duration_ms);
fprintf(stderr, " -vmsd N, --vad-max-speech-duration-s N [%-7s] VAD max speech duration (auto-split longer)\n", params.vad_max_speech_duration_s == FLT_MAX ?
std::string("FLT_MAX").c_str() :
std::to_string(params.vad_max_speech_duration_s).c_str());
fprintf(stderr, " -vp N, --vad-speech-pad-ms N [%-7d] VAD speech padding (extend segments)\n", params.vad_speech_pad_ms);
fprintf(stderr, " -vo N, --vad-samples-overlap N [%-7.2f] VAD samples overlap (seconds between segments)\n", params.vad_samples_overlap);
fprintf(stderr, "\n");
}
@ -239,7 +186,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve
else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; }
else if (arg == "-nth" || arg == "--no-speech-thold") { params.no_speech_thold = std::stof(argv[++i]); }
else if (arg == "-nc" || arg == "--no-context") { params.no_context = true; }
else if (arg == "-nlp" || arg == "--no-language-probabilities") { params.no_language_probabilities = true; }
// server params
else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); }
@ -248,16 +194,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve
else if ( arg == "--request-path") { sparams.request_path = argv[++i]; }
else if ( arg == "--inference-path") { sparams.inference_path = argv[++i]; }
else if ( arg == "--convert") { sparams.ffmpeg_converter = true; }
// Voice Activity Detection (VAD)
else if ( arg == "--vad") { params.vad = true; }
else if (arg == "-vm" || arg == "--vad-model") { params.vad_model = argv[++i]; }
else if (arg == "-vt" || arg == "--vad-threshold") { params.vad_threshold = std::stof(argv[++i]); }
else if (arg == "-vspd" || arg == "--vad-min-speech-duration-ms") { params.vad_min_speech_duration_ms = std::stoi(argv[++i]); }
else if (arg == "-vsd" || arg == "--vad-min-silence-duration-ms") { params.vad_min_speech_duration_ms = std::stoi(argv[++i]); }
else if (arg == "-vmsd" || arg == "--vad-max-speech-duration-s") { params.vad_max_speech_duration_s = std::stof(argv[++i]); }
else if (arg == "-vp" || arg == "--vad-speech-pad-ms") { params.vad_speech_pad_ms = std::stoi(argv[++i]); }
else if (arg == "-vo" || arg == "--vad-samples-overlap") { params.vad_samples_overlap = std::stof(argv[++i]); }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params, sparams);
@ -574,45 +510,11 @@ void get_req_parameters(const Request & req, whisper_params & params)
{
params.no_context = parse_str_to_bool(req.get_file_value("no_context").content);
}
if (req.has_file("vad"))
{
params.vad = parse_str_to_bool(req.get_file_value("vad").content);
}
if (req.has_file("vad_threshold"))
{
params.vad_threshold = std::stof(req.get_file_value("vad_threshold").content);
}
if (req.has_file("vad_min_speech_duration_ms"))
{
params.vad_min_speech_duration_ms = std::stof(req.get_file_value("vad_min_speech_duration_ms").content);
}
if (req.has_file("vad_min_silence_duration_ms"))
{
params.vad_min_silence_duration_ms = std::stof(req.get_file_value("vad_min_silence_duration_ms").content);
}
if (req.has_file("vad_max_speech_duration_s"))
{
params.vad_max_speech_duration_s = std::stof(req.get_file_value("vad_max_speech_duration_s").content);
}
if (req.has_file("vad_speech_pad_ms"))
{
params.vad_speech_pad_ms = std::stoi(req.get_file_value("vad_speech_pad_ms").content);
}
if (req.has_file("vad_samples_overlap"))
{
params.vad_samples_overlap = std::stof(req.get_file_value("vad_samples_overlap").content);
}
if (req.has_file("no_language_probabilities"))
{
params.no_language_probabilities = parse_str_to_bool(req.get_file_value("no_language_probabilities").content);
}
}
} // namespace
int main(int argc, char ** argv) {
ggml_backend_load_all();
whisper_params params;
server_params sparams;
@ -681,19 +583,13 @@ int main(int argc, char ** argv) {
if (params.dtw == "large.v3") {
cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V3;
}
if (params.dtw == "large.v3.turbo") {
cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V3_TURBO;
}
if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE) {
fprintf(stderr, "error: unknown DTW preset '%s'\n", params.dtw.c_str());
return 3;
}
}
std::unique_ptr<httplib::Server> svr = std::make_unique<httplib::Server>();
std::atomic<server_state> state{SERVER_STATE_LOADING_MODEL};
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
if (ctx == nullptr) {
@ -703,10 +599,9 @@ int main(int argc, char ** argv) {
// initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured
whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);
state.store(SERVER_STATE_READY);
svr->set_default_headers({{"Server", "whisper.cpp"},
Server svr;
svr.set_default_headers({{"Server", "whisper.cpp"},
{"Access-Control-Allow-Origin", "*"},
{"Access-Control-Allow-Headers", "content-type, authorization"}});
@ -785,15 +680,15 @@ int main(int argc, char ** argv) {
whisper_params default_params = params;
// this is only called if no index.html is found in the public --path
svr->Get(sparams.request_path + "/", [&](const Request &, Response &res){
svr.Get(sparams.request_path + "/", [&default_content](const Request &, Response &res){
res.set_content(default_content, "text/html");
return false;
});
svr->Options(sparams.request_path + sparams.inference_path, [&](const Request &, Response &){
svr.Options(sparams.request_path + sparams.inference_path, [&](const Request &, Response &){
});
svr->Post(sparams.request_path + sparams.inference_path, [&](const Request &req, Response &res){
svr.Post(sparams.request_path + sparams.inference_path, [&](const Request &req, Response &res){
// acquire whisper model mutex lock
std::lock_guard<std::mutex> lock(whisper_mutex);
@ -931,16 +826,6 @@ int main(int argc, char ** argv) {
wparams.suppress_nst = params.suppress_nst;
wparams.vad = params.vad;
wparams.vad_model_path = params.vad_model.c_str();
wparams.vad_params.threshold = params.vad_threshold;
wparams.vad_params.min_speech_duration_ms = params.vad_min_speech_duration_ms;
wparams.vad_params.min_silence_duration_ms = params.vad_min_silence_duration_ms;
wparams.vad_params.max_speech_duration_s = params.vad_max_speech_duration_s;
wparams.vad_params.speech_pad_ms = params.vad_speech_pad_ms;
wparams.vad_params.samples_overlap = params.vad_samples_overlap;
whisper_print_user_data user_data = { &params, &pcmf32s, 0 };
// this callback is called on each new segment
@ -1031,25 +916,23 @@ int main(int argc, char ** argv) {
} else if (params.response_format == vjson_format) {
/* try to match openai/whisper's Python format */
std::string results = output_str(ctx, params, pcmf32s);
// Get language probabilities
std::vector<float> lang_probs(whisper_lang_max_id() + 1, 0.0f);
const auto detected_lang_id = whisper_lang_auto_detect(ctx, 0, params.n_threads, lang_probs.data());
json jres = json{
{"task", params.translate ? "translate" : "transcribe"},
{"language", whisper_lang_str_full(whisper_full_lang_id(ctx))},
{"duration", float(pcmf32.size())/WHISPER_SAMPLE_RATE},
{"text", results},
{"segments", json::array()}
{"segments", json::array()},
{"detected_language", whisper_lang_str_full(detected_lang_id)},
{"detected_language_probability", lang_probs[detected_lang_id]},
{"language_probabilities", json::object()}
};
// Only compute language probabilities if requested (expensive operation)
if (!params.no_language_probabilities) {
std::vector<float> lang_probs(whisper_lang_max_id() + 1, 0.0f);
const auto detected_lang_id = whisper_lang_auto_detect(ctx, 0, params.n_threads, lang_probs.data());
jres["detected_language"] = whisper_lang_str_full(detected_lang_id);
jres["detected_language_probability"] = lang_probs[detected_lang_id];
jres["language_probabilities"] = json::object();
// Add all language probabilities
for (int i = 0; i <= whisper_lang_max_id(); ++i) {
if (lang_probs[i] > 0.001f) { // Only include non-negligible probabilities
jres["language_probabilities"][whisper_lang_str(i)] = lang_probs[i];
}
// Add all language probabilities
for (int i = 0; i <= whisper_lang_max_id(); ++i) {
if (lang_probs[i] > 0.001f) { // Only include non-negligible probabilities
jres["language_probabilities"][whisper_lang_str(i)] = lang_probs[i];
}
}
const int n_segments = whisper_full_n_segments(ctx);
@ -1111,9 +994,8 @@ int main(int argc, char ** argv) {
// reset params to their defaults
params = default_params;
});
svr->Post(sparams.request_path + "/load", [&](const Request &req, Response &res){
svr.Post(sparams.request_path + "/load", [&](const Request &req, Response &res){
std::lock_guard<std::mutex> lock(whisper_mutex);
state.store(SERVER_STATE_LOADING_MODEL);
if (!req.has_file("model"))
{
fprintf(stderr, "error: no 'model' field in the request\n");
@ -1145,25 +1027,18 @@ int main(int argc, char ** argv) {
// initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured
whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);
state.store(SERVER_STATE_READY);
const std::string success = "Load was successful!";
res.set_content(success, "application/text");
// check if the model is in the file system
});
svr->Get(sparams.request_path + "/health", [&](const Request &, Response &res){
server_state current_state = state.load();
if (current_state == SERVER_STATE_READY) {
const std::string health_response = "{\"status\":\"ok\"}";
res.set_content(health_response, "application/json");
} else {
res.set_content("{\"status\":\"loading model\"}", "application/json");
res.status = 503;
}
svr.Get(sparams.request_path + "/health", [&](const Request &, Response &res){
const std::string health_response = "{\"status\":\"ok\"}";
res.set_content(health_response, "application/json");
});
svr->set_exception_handler([](const Request &, Response &res, std::exception_ptr ep) {
svr.set_exception_handler([](const Request &, Response &res, std::exception_ptr ep) {
const char fmt[] = "500 Internal Server Error\n%s";
char buf[BUFSIZ];
try {
@ -1177,7 +1052,7 @@ int main(int argc, char ** argv) {
res.status = 500;
});
svr->set_error_handler([](const Request &req, Response &res) {
svr.set_error_handler([](const Request &req, Response &res) {
if (res.status == 400) {
res.set_content("Invalid request", "text/plain");
} else if (res.status != 500) {
@ -1187,10 +1062,10 @@ int main(int argc, char ** argv) {
});
// set timeouts and change hostname and port
svr->set_read_timeout(sparams.read_timeout);
svr->set_write_timeout(sparams.write_timeout);
svr.set_read_timeout(sparams.read_timeout);
svr.set_write_timeout(sparams.write_timeout);
if (!svr->bind_to_port(sparams.hostname, sparams.port))
if (!svr.bind_to_port(sparams.hostname, sparams.port))
{
fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n",
sparams.hostname.c_str(), sparams.port);
@ -1198,50 +1073,18 @@ int main(int argc, char ** argv) {
}
// Set the base directory for serving static files
svr->set_base_dir(sparams.public_path);
svr.set_base_dir(sparams.public_path);
// to make it ctrl+clickable:
printf("\nwhisper server listening at http://%s:%d\n\n", sparams.hostname.c_str(), sparams.port);
shutdown_handler = [&](int signal) {
printf("\nCaught signal %d, shutting down gracefully...\n", signal);
if (svr) {
svr->stop();
}
};
if (!svr.listen_after_bind())
{
return 1;
}
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
struct sigaction sigint_action;
sigint_action.sa_handler = signal_handler;
sigemptyset (&sigint_action.sa_mask);
sigint_action.sa_flags = 0;
sigaction(SIGINT, &sigint_action, NULL);
sigaction(SIGTERM, &sigint_action, NULL);
#elif defined (_WIN32)
auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;
};
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
#endif
// clean up function, to be called before exit
auto clean_up = [&]() {
whisper_print_timings(ctx);
whisper_free(ctx);
};
std::thread t([&] {
if (!svr->listen_after_bind()) {
fprintf(stderr, "error: server listen failed\n");
}
});
svr->wait_until_ready();
t.join();
clean_up();
whisper_print_timings(ctx);
whisper_free(ctx);
return 0;
}

View File

@ -2,7 +2,7 @@
Real-time transcription in the browser using WebAssembly
Online demo: https://ggml.ai/whisper.cpp/stream.wasm/
Online demo: https://whisper.ggerganov.com/stream/
## Build instructions
@ -30,16 +30,6 @@ cp bin/libstream.js /path/to/html/
cp bin/libstream.worker.js /path/to/html/
```
> 📝 **Note:** By default this example is built with `WHISPER_WASM_SINGLE_FILE=ON`
> which means that that a separate .wasm file will not be generated. Instead, the
> WASM module is embedded in the main JS file as a base64 encoded string. To
> generate a separate .wasm file, you need to disable this option by passing
> `-DWHISPER_WASM_SINGLE_FILE=OFF`:
> ```console
> emcmake cmake .. -DWHISPER_WASM_SINGLE_FILE=OFF
> ```
> This will generate a `libstream.wasm` file in the build/bin directory.
> 📝 **Note:** As of Emscripten 3.1.58 (April 2024), separate worker.js files are no
> longer generated and the worker is embedded in the main JS file. So the worker
> file will not be geneated for versions later than `3.1.58`.

View File

@ -31,11 +31,10 @@ void stream_set_status(const std::string & status) {
g_status = status;
}
void stream_main(size_t index, const std::string & lang) {
void stream_main(size_t index) {
stream_set_status("loading data ...");
struct whisper_full_params wparams = whisper_full_default_params(whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY);
bool is_multilingual = whisper_is_multilingual(g_contexts[index]);
wparams.n_threads = std::min(N_THREAD, (int) std::thread::hardware_concurrency());
wparams.offset_ms = 0;
@ -53,7 +52,7 @@ void stream_main(size_t index, const std::string & lang) {
// disable temperature fallback
wparams.temperature_inc = -1.0f;
wparams.language = is_multilingual ? lang.c_str() : "en";
wparams.language = "en";
printf("stream: using %d threads\n", wparams.n_threads);
@ -128,8 +127,9 @@ void stream_main(size_t index, const std::string & lang) {
g_contexts[index] = nullptr;
}
}
EMSCRIPTEN_BINDINGS(stream) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model, const std::string & lang) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
for (size_t i = 0; i < g_contexts.size(); ++i) {
if (g_contexts[i] == nullptr) {
g_contexts[i] = whisper_init_from_file_with_params(path_model.c_str(), whisper_context_default_params());
@ -138,8 +138,8 @@ EMSCRIPTEN_BINDINGS(stream) {
if (g_worker.joinable()) {
g_worker.join();
}
g_worker = std::thread([i, lang]() {
stream_main(i, lang);
g_worker = std::thread([i]() {
stream_main(i);
});
return i + 1;

View File

@ -55,7 +55,6 @@
Whisper model: <span id="model-whisper-status"></span>
<button id="fetch-whisper-tiny-en" onclick="loadWhisper('tiny.en')">tiny.en (75 MB)</button>
<button id="fetch-whisper-base-en" onclick="loadWhisper('base.en')">base.en (142 MB)</button>
<button id="fetch-whisper-base" onclick="loadWhisper('base')">base (142 MB)</button>
<br><br>
Quantized models:<br><br>
<button id="fetch-whisper-tiny-en-q5_1" onclick="loadWhisper('tiny-en-q5_1')">tiny.en (Q5_1, 31 MB)</button>
@ -67,77 +66,6 @@
-->
</div>
<table>
<tr>
<td>
Language:
<select id="language" name="language">
<option value="en">English</option>
<option value="ar">Arabic</option>
<option value="hy">Armenian</option>
<option value="az">Azerbaijani</option>
<option value="eu">Basque</option>
<option value="be">Belarusian</option>
<option value="bn">Bengali</option>
<option value="bg">Bulgarian</option>
<option value="ca">Catalan</option>
<option value="zh">Chinese</option>
<option value="hr">Croatian</option>
<option value="cs">Czech</option>
<option value="da">Danish</option>
<option value="nl">Dutch</option>
<option value="en">English</option>
<option value="et">Estonian</option>
<option value="tl">Filipino</option>
<option value="fi">Finnish</option>
<option value="fr">French</option>
<option value="gl">Galician</option>
<option value="ka">Georgian</option>
<option value="de">German</option>
<option value="el">Greek</option>
<option value="gu">Gujarati</option>
<option value="iw">Hebrew</option>
<option value="hi">Hindi</option>
<option value="hu">Hungarian</option>
<option value="is">Icelandic</option>
<option value="id">Indonesian</option>
<option value="ga">Irish</option>
<option value="it">Italian</option>
<option value="ja">Japanese</option>
<option value="kn">Kannada</option>
<option value="ko">Korean</option>
<option value="la">Latin</option>
<option value="lv">Latvian</option>
<option value="lt">Lithuanian</option>
<option value="mk">Macedonian</option>
<option value="ms">Malay</option>
<option value="mt">Maltese</option>
<option value="no">Norwegian</option>
<option value="fa">Persian</option>
<option value="pl">Polish</option>
<option value="pt">Portuguese</option>
<option value="ro">Romanian</option>
<option value="ru">Russian</option>
<option value="sr">Serbian</option>
<option value="sk">Slovak</option>
<option value="sl">Slovenian</option>
<option value="es">Spanish</option>
<option value="sw">Swahili</option>
<option value="sv">Swedish</option>
<option value="ta">Tamil</option>
<option value="te">Telugu</option>
<option value="th">Thai</option>
<option value="tr">Turkish</option>
<option value="uk">Ukrainian</option>
<option value="ur">Urdu</option>
<option value="vi">Vietnamese</option>
<option value="cy">Welsh</option>
<option value="yi">Yiddish</option>
</select>
</td>
</tr>
</table>
<br>
<div id="input">
@ -246,18 +174,16 @@
function loadWhisper(model) {
let urls = {
'tiny.en': 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.en.bin',
'base.en': 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin',
'base' : 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.bin',
'tiny.en': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin',
'base.en': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en.bin',
'tiny-en-q5_1': 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.en-q5_1.bin',
'base-en-q5_1': 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en-q5_1.bin',
'tiny-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en-q5_1.bin',
'base-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en-q5_1.bin',
};
let sizes = {
'tiny.en': 75,
'base.en': 142,
'base': 142,
'tiny-en-q5_1': 31,
'base-en-q5_1': 57,
@ -271,7 +197,6 @@
document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
document.getElementById('fetch-whisper-base-en').style.display = 'none';
document.getElementById('fetch-whisper-base').style.display = 'none';
document.getElementById('fetch-whisper-tiny-en-q5_1').style.display = 'none';
document.getElementById('fetch-whisper-base-en-q5_1').style.display = 'none';
@ -287,7 +212,6 @@
var el;
el = document.getElementById('fetch-whisper-tiny-en'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-base-en'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-base'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-tiny-en-q5_1'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-base-en-q5_1'); if (el) el.style.display = 'inline-block';
@ -444,7 +368,7 @@
function onStart() {
if (!instance) {
instance = Module.init('whisper.bin', document.getElementById('language').value);
instance = Module.init('whisper.bin');
if (instance) {
printTextarea("js: whisper initialized, instance: " + instance);

View File

@ -116,8 +116,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
}
int main(int argc, char ** argv) {
ggml_backend_load_all();
whisper_params params;
if (whisper_params_parse(argc, argv, params) == false) {
@ -163,10 +161,6 @@ int main(int argc, char ** argv) {
cparams.flash_attn = params.flash_attn;
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
if (ctx == nullptr) {
fprintf(stderr, "error: failed to initialize whisper context\n");
return 2;
}
std::vector<float> pcmf32 (n_samples_30s, 0.0f);
std::vector<float> pcmf32_old;

View File

@ -16,10 +16,7 @@ if (WHISPER_SDL2)
llama-hparams.cpp
llama-impl.cpp
llama-io.cpp
llama-kv-cache-unified.cpp
llama-kv-cache-unified-iswa.cpp
llama-memory-recurrent.cpp
llama-memory-hybrid.cpp
llama-kv-cache.cpp
llama-memory.cpp
llama-mmap.cpp
llama-model-loader.cpp

View File

@ -20,7 +20,6 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_BERT, "bert" },
{ LLM_ARCH_NOMIC_BERT, "nomic-bert" },
{ LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" },
{ LLM_ARCH_NEO_BERT, "neo-bert" },
{ LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" },
{ LLM_ARCH_BLOOM, "bloom" },
{ LLM_ARCH_STABLELM, "stablelm" },
@ -34,7 +33,6 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_PHI3, "phi3" },
{ LLM_ARCH_PHIMOE, "phimoe" },
{ LLM_ARCH_PLAMO, "plamo" },
{ LLM_ARCH_PLAMO2, "plamo2" },
{ LLM_ARCH_CODESHELL, "codeshell" },
{ LLM_ARCH_ORION, "orion" },
{ LLM_ARCH_INTERNLM2, "internlm2" },
@ -43,12 +41,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_GEMMA, "gemma" },
{ LLM_ARCH_GEMMA2, "gemma2" },
{ LLM_ARCH_GEMMA3, "gemma3" },
{ LLM_ARCH_GEMMA3N, "gemma3n" },
{ LLM_ARCH_STARCODER2, "starcoder2" },
{ LLM_ARCH_MAMBA, "mamba" },
{ LLM_ARCH_MAMBA2, "mamba2" },
{ LLM_ARCH_JAMBA, "jamba" },
{ LLM_ARCH_FALCON_H1, "falcon-h1" },
{ LLM_ARCH_XVERSE, "xverse" },
{ LLM_ARCH_COMMAND_R, "command-r" },
{ LLM_ARCH_COHERE2, "cohere2" },
@ -68,26 +62,16 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_JAIS, "jais" },
{ LLM_ARCH_NEMOTRON, "nemotron" },
{ LLM_ARCH_EXAONE, "exaone" },
{ LLM_ARCH_EXAONE4, "exaone4" },
{ LLM_ARCH_RWKV6, "rwkv6" },
{ LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" },
{ LLM_ARCH_RWKV7, "rwkv7" },
{ LLM_ARCH_ARWKV7, "arwkv7" },
{ LLM_ARCH_GRANITE, "granite" },
{ LLM_ARCH_GRANITE_MOE, "granitemoe" },
{ LLM_ARCH_GRANITE_HYBRID, "granitehybrid" },
{ LLM_ARCH_CHAMELEON, "chameleon" },
{ LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
{ LLM_ARCH_PLM, "plm" },
{ LLM_ARCH_BAILINGMOE, "bailingmoe" },
{ LLM_ARCH_DOTS1, "dots1" },
{ LLM_ARCH_ARCEE, "arcee" },
{ LLM_ARCH_ERNIE4_5, "ernie4_5" },
{ LLM_ARCH_ERNIE4_5_MOE, "ernie4_5-moe" },
{ LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
{ LLM_ARCH_SMOLLM3, "smollm3" },
{ LLM_ARCH_LFM2, "lfm2" },
{ LLM_ARCH_DREAM, "dream" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
@ -180,7 +164,6 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" },
{ LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" },
{ LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" },
{ LLM_KV_SSM_GROUP_COUNT, "%s.ssm.group_count" },
{ LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" },
{ LLM_KV_WKV_HEAD_SIZE, "%s.wkv.head_size" },
@ -191,10 +174,6 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_CONVNEXT_EMBEDDING_LENGTH, "%s.convnext.embedding_length" },
{ LLM_KV_CONVNEXT_BLOCK_COUNT, "%s.convnext.block_count" },
{ LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" },
{ LLM_KV_SHORTCONV_L_CACHE, "%s.shortconv.l_cache" },
{ LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
{ LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" },
{ LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" },
@ -213,13 +192,13 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_TOKENIZER_MASK_ID, "tokenizer.ggml.mask_token_id" },
{ LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" },
{ LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" },
{ LLM_KV_TOKENIZER_ADD_SEP, "tokenizer.ggml.add_sep_token" },
{ LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" },
{ LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, "tokenizer.ggml.remove_extra_whitespaces" },
{ LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" },
{ LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" },
{ LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" },
{ LLM_KV_TOKENIZER_CHAT_TEMPLATE, "tokenizer.chat_template" },
{ LLM_KV_TOKENIZER_CHAT_TEMPLATE_N, "tokenizer.chat_template.%s" },
{ LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" },
{ LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" },
{ LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" },
@ -263,24 +242,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_ARCEE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_LLAMA4,
{
@ -487,7 +448,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_TOKEN_TYPES, "token_types" },
{ LLM_TENSOR_POS_EMBD, "position_embd" },
{ LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
@ -532,21 +492,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_NEO_BERT,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" },
{ LLM_TENSOR_CLS, "cls" },
{ LLM_TENSOR_CLS_OUT, "cls.output" },
},
},
{
LLM_ARCH_JINA_BERT_V2,
{
@ -788,36 +733,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_PLAMO2,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
{ LLM_TENSOR_SSM_X, "blk.%d.ssm_x" },
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
{ LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
{ LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
{ LLM_TENSOR_SSM_DT_NORM, "blk.%d.ssm_dt_norm" },
{ LLM_TENSOR_SSM_B_NORM, "blk.%d.ssm_b_norm" },
{ LLM_TENSOR_SSM_C_NORM, "blk.%d.ssm_c_norm" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
},
},
{
LLM_ARCH_CODESHELL,
{
@ -977,42 +892,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
},
},
{
LLM_ARCH_GEMMA3N,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
{ LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "per_layer_token_embd" },
{ LLM_TENSOR_PER_LAYER_MODEL_PROJ, "per_layer_model_proj" },
{ LLM_TENSOR_PER_LAYER_PROJ_NORM, "per_layer_proj_norm" },
{ LLM_TENSOR_ALTUP_UNEMBD_PROJ, "altup_unembd_proj" },
{ LLM_TENSOR_ALTUP_PROJ, "altup_proj" },
{ LLM_TENSOR_PER_LAYER_INP_GATE, "blk.%d.inp_gate" },
{ LLM_TENSOR_PER_LAYER_PROJ, "blk.%d.proj" },
{ LLM_TENSOR_PER_LAYER_POST_NORM, "blk.%d.post_norm" },
{ LLM_TENSOR_ALTUP_CORRECT_COEF, "blk.%d.altup_correct_coef" },
{ LLM_TENSOR_ALTUP_CORRECT_SCALE, "blk.%d.altup_correct_scale" },
{ LLM_TENSOR_ALTUP_PREDICT_COEF, "blk.%d.altup_predict_coef" },
{ LLM_TENSOR_ALTUP_ROUTER, "blk.%d.altup_router" },
{ LLM_TENSOR_ALTUP_ROUTER_NORM, "blk.%d.altup_router_norm" },
{ LLM_TENSOR_LAUREL_L, "blk.%d.laurel_l" },
{ LLM_TENSOR_LAUREL_R, "blk.%d.laurel_r" },
{ LLM_TENSOR_LAUREL_POST_NORM, "blk.%d.laurel_post_norm" },
},
},
{
LLM_ARCH_STARCODER2,
{
@ -1047,77 +926,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
},
},
{
LLM_ARCH_MAMBA2,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
{ LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
{ LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
{ LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
},
},
{
LLM_ARCH_JAMBA,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
{ LLM_TENSOR_SSM_X, "blk.%d.ssm_x" },
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
{ LLM_TENSOR_SSM_DT_NORM, "blk.%d.ssm_dt_norm" },
{ LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
{ LLM_TENSOR_SSM_B_NORM, "blk.%d.ssm_b_norm" },
{ LLM_TENSOR_SSM_C_NORM, "blk.%d.ssm_c_norm" },
{ LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_FALCON_H1,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
{ LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
{ LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
{ LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_XVERSE,
{
@ -1511,26 +1319,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_EXAONE4,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
}
},
{
LLM_ARCH_RWKV6,
{
@ -1693,46 +1481,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
},
},
{
LLM_ARCH_GRANITE_HYBRID,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
// mamba(2) ssm layers
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
{ LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
{ LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
{ LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
// attention layers
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
// dense FFN
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
// moe FFN
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
// shared expert
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
},
},
{
@ -1802,154 +1550,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
},
},
{
LLM_ARCH_DOTS1,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
}
},
{
LLM_ARCH_ERNIE4_5,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_ERNIE4_5_MOE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
},
},
{
LLM_ARCH_HUNYUAN_MOE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_SMOLLM3,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_LFM2,
{
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_SHORTCONV_CONV, "blk.%d.shortconv.conv" },
{ LLM_TENSOR_SHORTCONV_INPROJ, "blk.%d.shortconv.in_proj" },
{ LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" },
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
}
},
{
LLM_ARCH_DREAM,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_UNKNOWN,
{
@ -2034,11 +1634,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_FFN_ACT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_DIV}},
{LLM_TENSOR_SSM_CONV1D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}},
{LLM_TENSOR_SSM_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_SCAN}},
{LLM_TENSOR_SSM_DT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_SSM_B_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_SSM_C_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_SSM_D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_SSM_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_TIME_MIX_LERP_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
@ -2082,23 +1678,6 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_FFN_GATE_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
{LLM_TENSOR_FFN_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
{LLM_TENSOR_FFN_EXP_PROBS_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
// altup / laurel (gemma 3n)
{LLM_TENSOR_PER_LAYER_TOKEN_EMBD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_PER_LAYER_MODEL_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_PER_LAYER_PROJ_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_ALTUP_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ALTUP_UNEMBD_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_PER_LAYER_INP_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_PER_LAYER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_PER_LAYER_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_ALTUP_CORRECT_COEF, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ALTUP_CORRECT_SCALE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_ALTUP_PREDICT_COEF, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ALTUP_ROUTER, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ALTUP_ROUTER_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_LAUREL_L, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_LAUREL_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_LAUREL_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
// this tensor is loaded for T5, but never used
{LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
{LLM_TENSOR_CONV1D, {LLM_TENSOR_LAYER_INPUT, GGML_OP_IM2COL}},
@ -2117,22 +1696,13 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_CONVNEXT_PW1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CONVNEXT_PW2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CONVNEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_SHORTCONV_CONV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}},
{LLM_TENSOR_SHORTCONV_INPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_SHORTCONV_OUTPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
};
LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}
std::string LLM_KV::operator()(llm_kv kv) const {
std::string name = ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
if (suffix != nullptr) {
name += ".";
name += suffix;
}
return name;
return suffix ? ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch), suffix)
: ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
}
std::string LLM_TN_IMPL::str() const {
@ -2171,39 +1741,3 @@ llm_arch llm_arch_from_string(const std::string & name) {
const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) {
return LLM_TENSOR_INFOS.at(tensor);
}
bool llm_arch_is_recurrent(const llm_arch & arch) {
switch (arch) {
case LLM_ARCH_MAMBA:
case LLM_ARCH_MAMBA2:
case LLM_ARCH_RWKV6:
case LLM_ARCH_RWKV6QWEN2:
case LLM_ARCH_RWKV7:
case LLM_ARCH_ARWKV7:
return true;
default:
return false;
}
}
bool llm_arch_is_hybrid(const llm_arch & arch) {
switch (arch) {
case LLM_ARCH_JAMBA:
case LLM_ARCH_FALCON_H1:
case LLM_ARCH_PLAMO2:
case LLM_ARCH_GRANITE_HYBRID:
case LLM_ARCH_LFM2:
return true;
default:
return false;
}
}
bool llm_arch_is_diffusion(const llm_arch & arch) {
switch (arch) {
case LLM_ARCH_DREAM:
return true;
default:
return false;
}
}

View File

@ -24,7 +24,6 @@ enum llm_arch {
LLM_ARCH_BERT,
LLM_ARCH_NOMIC_BERT,
LLM_ARCH_NOMIC_BERT_MOE,
LLM_ARCH_NEO_BERT,
LLM_ARCH_JINA_BERT_V2,
LLM_ARCH_BLOOM,
LLM_ARCH_STABLELM,
@ -38,7 +37,6 @@ enum llm_arch {
LLM_ARCH_PHI3,
LLM_ARCH_PHIMOE,
LLM_ARCH_PLAMO,
LLM_ARCH_PLAMO2,
LLM_ARCH_CODESHELL,
LLM_ARCH_ORION,
LLM_ARCH_INTERNLM2,
@ -47,12 +45,8 @@ enum llm_arch {
LLM_ARCH_GEMMA,
LLM_ARCH_GEMMA2,
LLM_ARCH_GEMMA3,
LLM_ARCH_GEMMA3N,
LLM_ARCH_STARCODER2,
LLM_ARCH_MAMBA,
LLM_ARCH_MAMBA2,
LLM_ARCH_JAMBA,
LLM_ARCH_FALCON_H1,
LLM_ARCH_XVERSE,
LLM_ARCH_COMMAND_R,
LLM_ARCH_COHERE2,
@ -72,26 +66,16 @@ enum llm_arch {
LLM_ARCH_JAIS,
LLM_ARCH_NEMOTRON,
LLM_ARCH_EXAONE,
LLM_ARCH_EXAONE4,
LLM_ARCH_RWKV6,
LLM_ARCH_RWKV6QWEN2,
LLM_ARCH_RWKV7,
LLM_ARCH_ARWKV7,
LLM_ARCH_GRANITE,
LLM_ARCH_GRANITE_MOE,
LLM_ARCH_GRANITE_HYBRID,
LLM_ARCH_CHAMELEON,
LLM_ARCH_WAVTOKENIZER_DEC,
LLM_ARCH_PLM,
LLM_ARCH_BAILINGMOE,
LLM_ARCH_DOTS1,
LLM_ARCH_ARCEE,
LLM_ARCH_ERNIE4_5,
LLM_ARCH_ERNIE4_5_MOE,
LLM_ARCH_HUNYUAN_MOE,
LLM_ARCH_SMOLLM3,
LLM_ARCH_LFM2,
LLM_ARCH_DREAM,
LLM_ARCH_UNKNOWN,
};
@ -184,7 +168,6 @@ enum llm_kv {
LLM_KV_SSM_CONV_KERNEL,
LLM_KV_SSM_STATE_SIZE,
LLM_KV_SSM_TIME_STEP_RANK,
LLM_KV_SSM_GROUP_COUNT,
LLM_KV_SSM_DT_B_C_RMS,
LLM_KV_WKV_HEAD_SIZE,
@ -207,13 +190,13 @@ enum llm_kv {
LLM_KV_TOKENIZER_MASK_ID,
LLM_KV_TOKENIZER_ADD_BOS,
LLM_KV_TOKENIZER_ADD_EOS,
LLM_KV_TOKENIZER_ADD_SEP,
LLM_KV_TOKENIZER_ADD_PREFIX,
LLM_KV_TOKENIZER_REMOVE_EXTRA_WS,
LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP,
LLM_KV_TOKENIZER_HF_JSON,
LLM_KV_TOKENIZER_RWKV,
LLM_KV_TOKENIZER_CHAT_TEMPLATE,
LLM_KV_TOKENIZER_CHAT_TEMPLATE_N,
LLM_KV_TOKENIZER_FIM_PRE_ID,
LLM_KV_TOKENIZER_FIM_SUF_ID,
LLM_KV_TOKENIZER_FIM_MID_ID,
@ -230,10 +213,6 @@ enum llm_kv {
LLM_KV_CONVNEXT_EMBEDDING_LENGTH,
LLM_KV_CONVNEXT_BLOCK_COUNT,
LLM_KV_CLASSIFIER_OUTPUT_LABELS,
LLM_KV_SHORTCONV_L_CACHE,
// deprecated:
LLM_KV_TOKENIZER_PREFIX_ID,
LLM_KV_TOKENIZER_SUFFIX_ID,
@ -284,32 +263,12 @@ enum llm_tensor {
LLM_TENSOR_LAYER_OUT_NORM,
LLM_TENSOR_POST_ATTN_NORM,
LLM_TENSOR_POST_MLP_NORM,
LLM_TENSOR_PER_LAYER_TOKEN_EMBD, // gemma3n
LLM_TENSOR_PER_LAYER_MODEL_PROJ, // gemma3n
LLM_TENSOR_PER_LAYER_INP_GATE, // gemma3n
LLM_TENSOR_PER_LAYER_PROJ, // gemma3n
LLM_TENSOR_PER_LAYER_PROJ_NORM, // gemma3n
LLM_TENSOR_PER_LAYER_POST_NORM, // gemma3n
LLM_TENSOR_ALTUP_PROJ, // gemma3n
LLM_TENSOR_ALTUP_UNEMBD_PROJ, // gemma3n
LLM_TENSOR_ALTUP_CORRECT_COEF, // gemma3n
LLM_TENSOR_ALTUP_CORRECT_SCALE, // gemma3n
LLM_TENSOR_ALTUP_PREDICT_COEF, // gemma3n
LLM_TENSOR_ALTUP_ROUTER, // gemma3n
LLM_TENSOR_ALTUP_ROUTER_NORM, // gemma3n
LLM_TENSOR_LAUREL_L, // gemma3n
LLM_TENSOR_LAUREL_R, // gemma3n
LLM_TENSOR_LAUREL_POST_NORM, // gemma3n
LLM_TENSOR_SSM_IN,
LLM_TENSOR_SSM_CONV1D,
LLM_TENSOR_SSM_X,
LLM_TENSOR_SSM_DT,
LLM_TENSOR_SSM_DT_NORM,
LLM_TENSOR_SSM_A,
LLM_TENSOR_SSM_B_NORM,
LLM_TENSOR_SSM_C_NORM,
LLM_TENSOR_SSM_D,
LLM_TENSOR_SSM_NORM,
LLM_TENSOR_SSM_OUT,
LLM_TENSOR_TIME_MIX_W0,
LLM_TENSOR_TIME_MIX_W1,
@ -403,9 +362,6 @@ enum llm_tensor {
LLM_TENSOR_POS_NET_ATTN_K,
LLM_TENSOR_POS_NET_ATTN_V,
LLM_TENSOR_POS_NET_ATTN_OUT,
LLM_TENSOR_SHORTCONV_CONV,
LLM_TENSOR_SHORTCONV_INPROJ,
LLM_TENSOR_SHORTCONV_OUTPROJ,
};
enum llm_tensor_layer {
@ -479,7 +435,3 @@ const char * llm_arch_name(llm_arch arch);
llm_arch llm_arch_from_string(const std::string & name);
const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor);
bool llm_arch_is_recurrent(const llm_arch & arch);
bool llm_arch_is_hybrid (const llm_arch & arch);
bool llm_arch_is_diffusion(const llm_arch & arch);

File diff suppressed because it is too large Load Diff

View File

@ -2,159 +2,88 @@
#include "llama.h"
#include "llama-cparams.h"
#include <array>
#include <vector>
#include <set>
#include <bitset>
#include <memory>
#include <unordered_map>
// keep this struct lightweight
// very similar to llama_batch,
// but has more metadata about sequences
struct llama_ubatch {
bool equal_seqs() const {
return b_equal_seqs != 0;
}
uint32_t b_equal_seqs; // note: this is a boolean, but we use an int32_t for alignment
// otherwise address sanitizer complains
bool equal_seqs;
// TODO: whole_seqs for embeddings?
uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
uint32_t n_seq_tokens; // tokens per sequence set
uint32_t n_seqs; // sequence sets in the ubatch
uint32_t n_seqs_unq; // unique sequence ids in the ubatch
uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
uint32_t n_seq_tokens; // tokens per sequence
uint32_t n_seqs;
// seq_id_unq: unique sequence ids in the ubatch
// seq_idx: indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
// used for extracting sequence pooled embeddings
// // size | idx | val
llama_token * token; // [n_tokens] | i | id, token
float * embd; // [n_embd, n_tokens] | i | embd
llama_pos * pos; // [n_tokens] | i | pos
int32_t * n_seq_id; // [n_tokens] | i | -
llama_seq_id ** seq_id; // [n_tokens] | s | s0, s1, seq_id
llama_seq_id * seq_id_unq; // [n_seqs_unq] | s | seq_id
int32_t * seq_idx; // [LLAMA_MAX_SEQ] | - | seq_idx
int8_t * output; // [n_tokens] | i | -
struct data_t {
std::vector<llama_token> token;
std::vector<float> embd;
std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id *> seq_id;
std::vector<llama_seq_id> seq_id_unq;
std::vector<int32_t> seq_idx;
std::vector<int8_t> output;
};
// the llama_ubatch pointers above point to this data if set. otherwise - points to non-owning data
std::shared_ptr<data_t> data;
llama_token * token; // [n_tokens]
float * embd; // [n_embd, n_tokens]
llama_pos * pos; // [n_tokens]
int32_t * n_seq_id; // [n_seqs]
llama_seq_id ** seq_id; // [n_seqs]
int8_t * output; // [n_tokens]
};
// a helper for sanitizing, fulfilling and splitting a batch
class llama_batch_allocr {
public:
llama_batch_allocr(uint32_t n_pos_per_embd);
struct llama_sbatch_seq {
int32_t n_seq_id;
// sanitize and auto-gen missing data in the input batch
// memory is optional. if provided will be used to check for sequence continuity and to determine the positions
bool init(
const llama_batch & batch_inp,
const llama_vocab & vocab,
const llama_memory_i * memory,
uint32_t n_embd,
uint32_t n_seq_max,
bool output_all);
llama_seq_id * seq_id;
const llama_batch & get_batch() const;
size_t offset;
size_t length;
};
uint32_t get_n_tokens() const;
uint32_t get_n_outputs() const;
uint32_t get_n_used() const;
// sequence-length-aware batch splitting
struct llama_sbatch {
// tokens left in this batch
size_t n_tokens;
// the array of output indices in the order they were encountered during the ubatch splitting
std::vector<int32_t> & get_out_ids();
size_t n_embd;
// min/max positions of each sequence in the current ubatch
llama_pos seq_pos_min(llama_seq_id seq_id) const;
llama_pos seq_pos_max(llama_seq_id seq_id) const;
bool logits_all; // TODO: remove once lctx.logits_all is removed too
// call once before splitting the batch to reset the internal state
void split_reset();
// sorted indices into the batch
std::vector<int64_t> ids;
// batch indices of the output
std::vector<int64_t> out_ids;
std::vector<llama_sbatch_seq> seq;
// simple split, unknown number of sequence sets of unequal lengths
llama_ubatch split_simple(uint32_t n_ubatch);
const llama_batch * batch = nullptr;
// make ubatches of equal-length sequences sets
// if sequential == true, the tokens in the ubatch will have increasing sequential sequence ids
llama_ubatch split_equal(uint32_t n_ubatch, bool sequential);
// buffers for the ubatch
std::vector<llama_token> ubatch_token;
std::vector<float> ubatch_embd;
std::vector<llama_pos> ubatch_pos;
std::vector<int32_t> ubatch_n_seq_id;
std::vector<llama_seq_id *> ubatch_seq_id;
std::vector<int8_t> ubatch_output;
// sequence-set-wise split - each ubatch contains a single sequence-set
llama_ubatch split_seq(uint32_t n_ubatch);
llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false);
// a helper method for creating a well-defined ubatch of tokens
// TODO: support embeddings if needed in the future
llama_ubatch ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs);
void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length);
private:
void clear();
// simple split, unknown number of sequences of unequal lengths
llama_ubatch split_simple(size_t n_ubatch);
// create the next ubatch based on the provided batch indices (idxs) and the number of sequence sets (n_seqs)
// return llama_ubatch.n_tokens == 0 if the entire batch was consumed
llama_ubatch ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs);
// make batches of equal-length sequences
llama_ubatch split_equal(size_t n_ubatch);
// for debugging, start with LLAMA_BATCH_DEBUG=2
void ubatch_print(const llama_ubatch & ubatch, int debug);
// sequence-wise split
llama_ubatch split_seq(size_t n_ubatch);
llama_batch batch;
llama_sbatch() = default;
llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
};
// only for debugging purposes
const llama_vocab * vocab;
// TODO: this is more of a temporary solution until we have a better way to handle multiple positions per token/embd
// ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
const uint32_t n_pos_per_embd;
uint32_t n_embd;
uint32_t n_seq_max;
uint32_t n_outputs;
// temporary allocate memory for the input batch if needed
struct llama_batch_allocr {
struct llama_batch batch;
std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id *> seq_id;
std::vector<llama_seq_id> seq_id_unq;
std::vector<int32_t> seq_idx;
std::vector<int8_t> output;
std::vector<int8_t> logits;
using pos_set_t = std::set<llama_pos>;
using seq_cpl_t = std::vector<bool>;
// helper flag to quickly determine if there are any coupled sequences in the batch
bool has_cpl = false;
std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
using idx_vec_t = std::vector<int32_t>;
using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
std::vector<seq_set_t> seq_set; // seq_set[i]: the sequence set of token i
std::unordered_map<seq_set_t, idx_vec_t> seq_set_map; // the indices at which the sequence set appears
// batch indices of the output
std::vector<int32_t> out_ids;
uint32_t n_used;
// used[i] indicates if token i has already been used in a previous ubatch
std::vector<bool> used;
int debug;
// optionally fulfill the batch returned by llama_batch_get_one
llama_batch_allocr(struct llama_batch in_batch, llama_pos p0);
};

View File

@ -56,7 +56,6 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "glmedge", LLM_CHAT_TEMPLATE_GLMEDGE },
{ "minicpm", LLM_CHAT_TEMPLATE_MINICPM },
{ "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 },
{ "exaone4", LLM_CHAT_TEMPLATE_EXAONE_4 },
{ "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD },
{ "granite", LLM_CHAT_TEMPLATE_GRANITE },
{ "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT },
@ -65,8 +64,6 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "bailing", LLM_CHAT_TEMPLATE_BAILING },
{ "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
{ "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM },
{ "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE },
{ "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 },
};
llm_chat_template llm_chat_template_from_str(const std::string & name) {
@ -169,13 +166,10 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
} else if (tmpl_contains(LU8("<Assistant>")) && tmpl_contains(LU8("<User>")) && tmpl_contains(LU8("<end▁of▁sentence>"))) {
return LLM_CHAT_TEMPLATE_DEEPSEEK_3;
} else if (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]")) {
if (tmpl_contains("[|tool|]")) {
return LLM_CHAT_TEMPLATE_EXAONE_4;
}
// ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb
// EXAONE-3.0-7.8B-Instruct
return LLM_CHAT_TEMPLATE_EXAONE_3;
} else if (tmpl_contains("rwkv-world") || tmpl_contains("{{- 'User: ' + message['content']|trim + '\\n\\n' -}}")) {
} else if (tmpl_contains("rwkv-world")) {
return LLM_CHAT_TEMPLATE_RWKV_WORLD;
} else if (tmpl_contains("<|start_of_role|>")) {
return LLM_CHAT_TEMPLATE_GRANITE;
@ -189,12 +183,6 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
return LLM_CHAT_TEMPLATE_BAILING;
} else if (tmpl_contains("<|header_start|>") && tmpl_contains("<|header_end|>")) {
return LLM_CHAT_TEMPLATE_LLAMA4;
} else if (tmpl_contains("<|endofuserprompt|>")) {
return LLM_CHAT_TEMPLATE_DOTS1;
} else if (tmpl_contains("<|startoftext|>") && tmpl_contains("<|extra_4|>")) {
return LLM_CHAT_TEMPLATE_HUNYUAN_MOE;
} else if (tmpl_contains("<|im_assistant|>assistant<|im_middle|>")) {
return LLM_CHAT_TEMPLATE_KIMI_K2;
}
return LLM_CHAT_TEMPLATE_UNKNOWN;
}
@ -343,7 +331,7 @@ int32_t llm_chat_apply_template(
std::string role(message->role);
if (role == "system") {
// there is no system message for gemma, but we will merge it with user prompt, so nothing is broken
system_prompt += trim(message->content);
system_prompt = trim(message->content);
continue;
}
// in gemma, "assistant" is "model"
@ -365,7 +353,7 @@ int32_t llm_chat_apply_template(
std::string role(message->role);
if (role == "system") {
// there is no system message support, we will merge it with user prompt
system_prompt += message->content;
system_prompt = message->content;
continue;
} else if (role == "user") {
ss << "Human: ";
@ -536,35 +524,14 @@ int32_t llm_chat_apply_template(
if (add_ass) {
ss << "[|assistant|]";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_4) {
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
ss << "[|system|]" << trim(message->content) << "[|endofturn|]\n";
} else if (role == "user") {
ss << "[|user|]" << trim(message->content) << "\n";
} else if (role == "assistant") {
ss << "[|assistant|]" << trim(message->content) << "[|endofturn|]\n";
} else if (role == "tool") {
ss << "[|tool|]" << trim(message->content) << "[|endofturn|]\n";
}
}
if (add_ass) {
ss << "[|assistant|]";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) {
// this template requires the model to have "\n\n" as EOT token
for (size_t i = 0; i < chat.size(); i++) {
std::string role(chat[i]->role);
if (role == "system") {
ss << "System: " << trim(chat[i]->content) << "\n\n";
} else if (role == "user") {
ss << "User: " << trim(chat[i]->content) << "\n\n";
if (i == chat.size() - 1) {
ss << "Assistant:";
}
} else if (role == "assistant") {
ss << "Assistant: " << trim(chat[i]->content) << "\n\n";
for (auto message : chat) {
std::string role(message->role);
if (role == "user") {
ss << "User: " << message->content << "\n\nAssistant:";
} else {
ss << message->content << "\n\n";
}
}
} else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE) {
@ -676,52 +643,6 @@ int32_t llm_chat_apply_template(
if (add_ass) {
ss << "Assistant:";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_DOTS1) {
// dots.llm1.inst (DOTS1)
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
ss << "<|system|>" << message->content << "<|endofsystem|>";
} else if (role == "user") {
ss << "<|userprompt|>" << message->content << "<|endofuserprompt|>";
} else {
ss << "<|response|>" << message->content << "<|endofresponse|>";
}
}
if (add_ass) {
ss << "<|response|>";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_MOE) {
// tencent/Hunyuan-A13B-Instruct
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
ss << "<|startoftext|>" << message->content << "<|extra_4|>";
} else if (role == "assistant") {
ss << "<|startoftext|>" << message->content << "<|eos|>";
} else {
ss << "<|startoftext|>" << message->content << "<|extra_0|>";
}
}
} else if (tmpl == LLM_CHAT_TEMPLATE_KIMI_K2) {
// moonshotai/Kimi-K2-Instruct
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
ss << "<|im_system|>system<|im_middle|>";
} else if (role == "user") {
ss << "<|im_user|>user<|im_middle|>";
} else if (role == "assistant") {
ss << "<|im_assistant|>assistant<|im_middle|>";
} else if (role == "tool") {
ss << "<|im_system|>tool<|im_middle|>";
}
ss << message->content << "<|im_end|>";
}
if (add_ass) {
ss << "<|im_assistant|>assistant<|im_middle|>";
}
} else {
// template not supported
return -1;

View File

@ -35,7 +35,6 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_GLMEDGE,
LLM_CHAT_TEMPLATE_MINICPM,
LLM_CHAT_TEMPLATE_EXAONE_3,
LLM_CHAT_TEMPLATE_EXAONE_4,
LLM_CHAT_TEMPLATE_RWKV_WORLD,
LLM_CHAT_TEMPLATE_GRANITE,
LLM_CHAT_TEMPLATE_GIGACHAT,
@ -44,9 +43,6 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_BAILING,
LLM_CHAT_TEMPLATE_LLAMA4,
LLM_CHAT_TEMPLATE_SMOLVLM,
LLM_CHAT_TEMPLATE_DOTS1,
LLM_CHAT_TEMPLATE_HUNYUAN_MOE,
LLM_CHAT_TEMPLATE_KIMI_K2,
LLM_CHAT_TEMPLATE_UNKNOWN,
};

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,7 @@
#pragma once
#include "llama.h"
#include "llama-batch.h"
#include "llama-cparams.h"
#include "llama-graph.h"
#include "llama-adapter.h"
@ -12,14 +13,11 @@
#include <vector>
struct llama_model;
class llama_batch_allocr;
struct llama_kv_cache;
class llama_io_read_i;
class llama_io_write_i;
struct llama_memory_i;
struct llama_memory_context_i;
struct llama_context {
// init scheduler and compute buffers, reserve worst-case graphs
llama_context(
@ -35,6 +33,8 @@ struct llama_context {
ggml_backend_sched_t get_sched() const;
ggml_context * get_ctx_compute() const;
uint32_t n_ctx() const;
uint32_t n_ctx_per_seq() const;
uint32_t n_batch() const;
@ -44,12 +44,10 @@ struct llama_context {
uint32_t n_threads() const;
uint32_t n_threads_batch() const;
llama_memory_t get_memory() const;
llama_kv_cache * get_kv_self();
const llama_kv_cache * get_kv_self() const;
// return true of the KV cache was updated
// TODO: remove
bool kv_self_update(bool optimize);
void kv_self_defrag_sched();
void kv_self_update();
enum llama_pooling_type pooling_type() const;
@ -90,18 +88,8 @@ struct llama_context {
int32_t il_start,
int32_t il_end);
// process a single ubatch with a specific graph type
// if memory_context is provided, it will be applied first to the context's memory
// ret contains the status of the graph computation
// returns nullptr only if ret != GGML_STATUS_SUCCESS
llm_graph_result * process_ubatch(
const llama_ubatch & ubatch,
llm_graph_type gtype,
llama_memory_context_i * mctx,
ggml_status & ret);
int encode(const llama_batch & batch_inp);
int decode(const llama_batch & batch_inp);
int encode(llama_batch & inp_batch);
int decode(llama_batch & inp_batch);
//
// state save/load
@ -179,32 +167,29 @@ private:
// Make sure enough space is available for outputs.
// Returns max number of outputs for which space was reserved.
uint32_t output_reserve(int32_t n_outputs);
void output_reorder();
int32_t output_reserve(int32_t n_outputs);
//
// graph
//
public:
uint32_t graph_max_nodes() const;
int32_t graph_max_nodes() const;
// can reuse the llm_graph_result instance of the context (for example to update a memory module)
llm_graph_result * get_gf_res_reserve() const;
// zero-out inputs and create the ctx_compute for the compute graph
ggml_cgraph * graph_init();
// returns the result of ggml_backend_sched_graph_compute_async execution
ggml_status graph_compute(ggml_cgraph * gf, bool batched);
// reserve a graph with a dummy ubatch of the specified size
ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
ggml_status graph_compute(
ggml_cgraph * gf,
bool batched);
private:
llm_graph_params graph_params(
llm_graph_result * res,
const llama_ubatch & ubatch,
const llama_memory_context_i * mctx,
llm_graph_type gtype) const;
llm_graph_result_ptr graph_build(
ggml_context * ctx,
ggml_cgraph * gf,
const llama_ubatch & ubatch,
llm_graph_type gtype);
llm_graph_cb graph_get_cb() const;
@ -229,9 +214,6 @@ private:
std::unique_ptr<llama_memory_i> memory;
// TODO: temporary, until the llama_kv_self_defrag() API is removed
bool memory_force_optimize = false;
// decode output (2-dimensional array: [n_outputs][n_vocab])
size_t logits_size = 0; // capacity (of floats) for logits
float * logits = nullptr;
@ -245,25 +227,18 @@ private:
// populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
std::map<llama_seq_id, std::vector<float>> embd_seq;
// reuse the batch_allocr to avoid unnecessary memory allocations
std::unique_ptr<llama_batch_allocr> balloc;
uint32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
int32_t n_outputs_max = 0; // capacity (of tokens positions) for the output buffers
std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
struct swap_info {
uint32_t i0;
uint32_t i1;
};
std::vector<swap_info> output_swaps;
ggml_backend_sched_ptr sched;
ggml_backend_t backend_cpu = nullptr;
std::vector<ggml_backend_ptr> backends;
ggml_context_ptr ctx_compute;
// training
ggml_opt_context_t opt_ctx = nullptr;
@ -279,18 +254,14 @@ private:
std::vector<ggml_backend_t> backend_ptrs;
std::vector<ggml_backend_buffer_type_t> backend_buft;
llm_graph_result_ptr gf_res_prev;
llm_graph_result_ptr gf_res_reserve;
// memory buffers used to evaluate the model
std::vector<uint8_t> buf_compute_meta;
// host buffer for the model output (logits and embeddings)
ggml_backend_buffer_ptr buf_output;
bool has_evaluated_once = false;
// env: LLAMA_SET_ROWS (temporary)
// ref: https://github.com/ggml-org/llama.cpp/pull/14285
bool supports_set_rows = false;
// perf
mutable int64_t t_start_us = 0;
mutable int64_t t_load_us = 0;
@ -302,6 +273,4 @@ private:
mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
mutable int32_t n_eval = 0; // number of eval calls
mutable int32_t n_reused = 0; // number of times the previous graph was reused
};

View File

@ -1,5 +1 @@
#include "llama-cparams.h"
size_t llama_max_parallel_sequences(void) {
return LLAMA_MAX_SEQ;
}

View File

@ -4,15 +4,13 @@
#include <cstdint>
#define LLAMA_MAX_SEQ 64
struct llama_cparams {
uint32_t n_ctx; // context size used during inference
uint32_t n_batch;
uint32_t n_ubatch;
uint32_t n_seq_max;
int32_t n_threads; // number of threads to use for generation
int32_t n_threads_batch; // number of threads to use for batch processing
int n_threads; // number of threads to use for generation
int n_threads_batch; // number of threads to use for batch processing
float rope_freq_base;
float rope_freq_scale;
@ -33,7 +31,6 @@ struct llama_cparams {
bool no_perf;
bool warmup;
bool op_offload;
bool kv_unified;
enum llama_pooling_type pooling_type;

View File

@ -1177,18 +1177,8 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
for (const auto & trigger_pattern : grammar.trigger_patterns) {
if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) {
grammar.awaiting_trigger = false;
// get from the first matched capturing group to the end of the string
size_t start = std::string::npos;
for (auto i = 1u; i < match.size(); i++) {
if (match.length(i) > 0) {
start = match.position(i);
break;
}
}
if (start == std::string::npos) {
start = match.position(0);
}
auto constrained_str = grammar.trigger_buffer.substr(start);
// get from the first match to the end of the string
auto constrained_str = grammar.trigger_buffer.substr(match.position(1));
// std::string constrained_str(match[1].first, grammar.trigger_buffer.end());
grammar.trigger_buffer.clear();
llama_grammar_accept_str(grammar, constrained_str);

File diff suppressed because it is too large Load Diff

View File

@ -1,7 +1,6 @@
#pragma once
#include "llama-arch.h"
#include "llama-batch.h"
#include "llama-hparams.h"
#include "llama-adapter.h"
@ -15,14 +14,12 @@ struct ggml_cgraph;
struct ggml_context;
struct ggml_tensor;
struct llama_ubatch;
struct llama_cparams;
struct llama_memory_context_i;
class llama_kv_cache_unified_context;
class llama_kv_cache_unified_iswa_context;
class llama_memory_recurrent_context;
class llama_memory_hybrid_context;
class llama_memory_i;
class llama_kv_cache_unified;
class llama_kv_cache_recurrent;
// certain models (typically multi-modal) can produce different types of graphs
enum llm_graph_type {
@ -37,8 +34,6 @@ enum llm_ffn_op_type {
LLM_FFN_RELU,
LLM_FFN_RELU_SQR,
LLM_FFN_SWIGLU,
LLM_FFN_GEGLU,
LLM_FFN_REGLU,
};
enum llm_ffn_gate_type {
@ -69,8 +64,6 @@ struct llama_cross {
std::vector<std::set<llama_seq_id>> seq_ids_enc;
};
struct llm_graph_params;
//
// llm_graph_input
//
@ -80,19 +73,11 @@ public:
virtual ~llm_graph_input_i() = default;
virtual void set_input(const llama_ubatch * ubatch) = 0;
// return true if the resulting input tensors using the provided graph parameters would be
// the same as the previous input tensors that we have currently stored in the object
virtual bool can_reuse(const llm_graph_params & params) {
// returning false here by default will prevent from reusing the graph if the check
// for the input type has not been implemented yet
GGML_UNUSED(params);
return false;
}
};
using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
class llm_graph_input_embd : public llm_graph_input_i {
public:
llm_graph_input_embd() = default;
@ -100,24 +85,20 @@ public:
void set_input(const llama_ubatch * ubatch) override;
bool can_reuse(const llm_graph_params & params) override;
ggml_tensor * tokens = nullptr; // I32 [n_batch]
ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch]
};
class llm_graph_input_pos : public llm_graph_input_i {
public:
llm_graph_input_pos(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
llm_graph_input_pos(int64_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
virtual ~llm_graph_input_pos() = default;
void set_input(const llama_ubatch * ubatch) override;
bool can_reuse(const llm_graph_params & params) override;
ggml_tensor * pos = nullptr; // I32 [n_batch]
const uint32_t n_pos_per_embd = 1;
const int64_t n_pos_per_embd = 1;
};
// temperature tuning, used by llama4
@ -151,7 +132,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
public:
llm_graph_input_pos_bucket_kv(
const llama_hparams & hparams,
const llama_kv_cache_unified_context * mctx) : hparams(hparams), mctx(mctx) {}
const llama_kv_cache_unified * kv_self) : hparams(hparams), kv_self(kv_self) {}
virtual ~llm_graph_input_pos_bucket_kv() = default;
void set_input(const llama_ubatch * ubatch) override;
@ -159,8 +140,7 @@ public:
ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
const llama_hparams & hparams;
const llama_kv_cache_unified_context * mctx;
const llama_kv_cache_unified * kv_self;
};
class llm_graph_input_out_ids : public llm_graph_input_i {
@ -168,19 +148,17 @@ public:
llm_graph_input_out_ids(
const llama_hparams & hparams,
const llama_cparams & cparams,
uint32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {}
int32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {}
virtual ~llm_graph_input_out_ids() = default;
void set_input(const llama_ubatch * ubatch) override;
bool can_reuse(const llm_graph_params & params) override;
ggml_tensor * out_ids; // I32 [n_outputs]
const llama_hparams & hparams;
const llama_cparams & cparams;
const uint32_t n_outputs;
const int32_t n_outputs;
};
class llm_graph_input_mean : public llm_graph_input_i {
@ -207,16 +185,28 @@ public:
const llama_cparams & cparams;
};
class llm_graph_input_rs : public llm_graph_input_i {
class llm_graph_input_s_copy : public llm_graph_input_i {
public:
llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
virtual ~llm_graph_input_rs() = default;
llm_graph_input_s_copy(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
virtual ~llm_graph_input_s_copy() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * s_copy; // I32 [kv_size]
const llama_memory_recurrent_context * mctx;
const llama_kv_cache_recurrent * kv_self;
};
class llm_graph_input_s_mask : public llm_graph_input_i {
public:
llm_graph_input_s_mask(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
virtual ~llm_graph_input_s_mask() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * s_mask; // F32 [1, n_kv]
const llama_kv_cache_recurrent * kv_self;
};
class llm_graph_input_cross_embd : public llm_graph_input_i {
@ -244,8 +234,8 @@ public:
ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }
ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch, 1, 1]
ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch, 1, 1]
ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch]
ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch]
const llama_hparams & hparams;
const llama_cparams & cparams;
@ -256,72 +246,27 @@ public:
llm_graph_input_attn_kv_unified(
const llama_hparams & hparams,
const llama_cparams & cparams,
const llama_kv_cache_unified_context * mctx) :
const llama_kv_cache_unified * kv_self) :
hparams(hparams),
cparams(cparams),
mctx(mctx) {
kv_self(kv_self) {
}
~llm_graph_input_attn_kv_unified() = default;
void set_input(const llama_ubatch * ubatch) override;
bool can_reuse(const llm_graph_params & params) override;
ggml_tensor * get_k_idxs() const { return self_k_idxs; }
ggml_tensor * get_v_idxs() const { return self_v_idxs; }
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
const llama_hparams & hparams;
const llama_cparams & cparams;
const llama_kv_cache_unified_context * mctx;
};
class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
public:
llm_graph_input_attn_kv_unified_iswa(
const llama_hparams & hparams,
const llama_cparams & cparams,
const llama_kv_cache_unified_iswa_context * mctx) :
hparams(hparams),
cparams(cparams),
mctx(mctx) {
}
~llm_graph_input_attn_kv_unified_iswa() = default;
void set_input(const llama_ubatch * ubatch) override;
bool can_reuse(const llm_graph_params & params) override;
ggml_tensor * get_k_idxs() const { return self_k_idxs; }
ggml_tensor * get_v_idxs() const { return self_v_idxs; }
ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
ggml_tensor * get_v_idxs_swa() const { return self_v_idxs_swa; }
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch]
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch]
const llama_hparams & hparams;
const llama_cparams & cparams;
const llama_kv_cache_unified_iswa_context * mctx;
const llama_kv_cache_unified * kv_self;
};
class llm_graph_input_attn_cross : public llm_graph_input_i {
@ -333,34 +278,12 @@ public:
ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }
ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch]
ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch]
const llama_cross * cross = nullptr;
};
class llm_graph_input_mem_hybrid : public llm_graph_input_i {
public:
llm_graph_input_mem_hybrid(
std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn,
std::unique_ptr<llm_graph_input_rs> inp_rs,
const llama_memory_hybrid_context * mctx) :
inp_attn(std::move(inp_attn)),
inp_rs(std::move(inp_rs)),
mctx(mctx) { }
virtual ~llm_graph_input_mem_hybrid() = default;
void set_input(const llama_ubatch * ubatch) override;
std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn;
std::unique_ptr<llm_graph_input_rs> inp_rs;
llm_graph_input_attn_kv_unified * get_attn() const { return inp_attn.get(); }
llm_graph_input_rs * get_recr() const { return inp_rs.get(); }
const llama_memory_hybrid_context * mctx;
};
//
// llm_graph_result
//
@ -371,108 +294,40 @@ public:
// along with the input tensors, the object also provides commonly used outputs tensors, such as logits, embeddings, etc.
// these are used by the llama_context to extact the relevant data, based on the compute parameters
// callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
class llm_graph_result_i {
public:
virtual ~llm_graph_result_i() = default;
class llm_graph_result;
virtual ggml_tensor * get_tokens() = 0;
virtual ggml_tensor * get_logits() = 0;
virtual ggml_tensor * get_embd() = 0;
virtual ggml_tensor * get_embd_pooled() = 0;
struct llm_graph_params {
llm_arch arch = LLM_ARCH_UNKNOWN;
llama_hparams hparams;
llama_cparams cparams;
llama_ubatch ubatch; // note: intentionally make a copy
llm_graph_type gtype;
ggml_backend_sched_t sched;
ggml_backend_t backend_cpu;
const llama_adapter_cvec * cvec;
const llama_adapter_loras * loras;
const llama_memory_context_i * mctx;
const llama_cross * cross;
uint32_t n_outputs;
llm_graph_cb cb;
llm_graph_result * res;
// return true if the "other" params would result in a graph with the same topology as with the current params
// having the same topology allows us to reuse the graph in some cases
bool allow_reuse(const llm_graph_params & other) const {
// first check the ubatch
bool can_reuse_ubatch =
ubatch.equal_seqs() == other.ubatch.equal_seqs() &&
ubatch.n_tokens == other.ubatch.n_tokens &&
ubatch.n_seq_tokens == other.ubatch.n_seq_tokens &&
ubatch.n_seqs == other.ubatch.n_seqs &&
ubatch.n_seqs_unq == other.ubatch.n_seqs_unq &&
(
(!ubatch.token && !other.ubatch.token) ||
(!ubatch.embd && !other.ubatch.embd)
);
if (can_reuse_ubatch && !ubatch.equal_seqs()) {
if (!ubatch.data) {
// if the old ubatch does not own it's data, then we cannot guarantee that it is still alive, and
// therefore we cannot perform the sequence id check. normally should never happen
can_reuse_ubatch = false;
} else {
for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
can_reuse_ubatch &= ubatch.seq_id_unq[s] == other.ubatch.seq_id_unq[s];
}
}
}
if (!can_reuse_ubatch) {
return false;
}
return
cparams.embeddings == other.cparams.embeddings &&
cparams.causal_attn == other.cparams.causal_attn &&
arch == other.arch &&
gtype == other.gtype &&
cvec == other.cvec &&
loras == other.loras &&
cross == other.cross &&
n_outputs == other.n_outputs;
}
virtual void set_inputs(const llama_ubatch * ubatch) = 0;
};
class llm_graph_result {
public:
llm_graph_result(int64_t max_nodes);
using llm_graph_result_ptr = std::unique_ptr<llm_graph_result_i>;
class llm_graph_result : public llm_graph_result_i {
public:
virtual ~llm_graph_result() = default;
ggml_tensor * get_tokens() const { return t_tokens; }
ggml_tensor * get_logits() const { return t_logits; }
ggml_tensor * get_embd() const { return t_embd; }
ggml_tensor * get_embd_pooled() const { return t_embd_pooled; }
ggml_tensor * get_tokens() override { return t_tokens; }
ggml_tensor * get_logits() override { return t_logits; }
ggml_tensor * get_embd() override { return t_embd; }
ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
ggml_cgraph * get_gf() const { return gf; }
ggml_context * get_ctx() const { return ctx_compute.get(); }
void set_inputs(const llama_ubatch * ubatch) override {
for (auto & input : inputs) {
input->set_input(ubatch);
}
}
int64_t get_max_nodes() const;
void reset();
void set_inputs(const llama_ubatch * ubatch);
// try to update the existing graph result using the new graph parameters in order to reuse it
// this can only be done if we determine that the resulting graph using the new graph parameters
// would be identical to the existing graph. in that case, we simply have to update the memory
// contexts of the input tensors of the graph and we can reuse it for another computation
// return true if the graph was updated and can be reused
bool can_reuse(const llm_graph_params & params);
llm_graph_input_i * add_input(llm_graph_input_ptr input);
void set_params(const llm_graph_params & params);
llm_graph_input_i * add_input(llm_graph_input_ptr input) {
inputs.emplace_back(std::move(input));
return inputs.back().get();
}
// important graph nodes
ggml_tensor * t_tokens = nullptr;
@ -481,34 +336,36 @@ public:
ggml_tensor * t_embd_pooled = nullptr;
std::vector<llm_graph_input_ptr> inputs;
ggml_context_ptr ctx_compute;
// memory buffers used to evaluate the model
std::vector<uint8_t> buf_compute_meta;
ggml_cgraph * gf;
int64_t max_nodes;
private:
// keep a copy of the previous graph parameters
// we will use this to determine whether the graph can be reused by comparing them with the new parameters
// note: these are updated after constructing the new graph
llm_graph_params params;
// env: LLAMA_GRAPH_RESULT_DEBUG
int debug = 0;
};
using llm_graph_result_ptr = std::unique_ptr<llm_graph_result>;
//
// llm_graph_context
//
// used in build_rs to properly order writes and avoid unnecessary copies
using llm_graph_get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>;
// callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
struct llm_graph_params {
ggml_context * ctx;
const llm_arch arch;
const llama_hparams & hparams;
const llama_cparams & cparams;
const llama_ubatch & ubatch;
ggml_backend_sched_t sched;
ggml_backend_t backend_cpu;
const llama_adapter_cvec * cvec;
const llama_adapter_loras * loras;
const llama_memory_i * memory;
const llama_cross * cross;
int32_t n_outputs;
const llm_graph_cb & cb;
};
struct llm_graph_context {
const llm_arch arch;
@ -521,6 +378,7 @@ struct llm_graph_context {
const int64_t n_layer;
const int64_t n_rot;
const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train)
const int64_t n_ctx_per_seq;
const int64_t n_head;
const int64_t n_head_kv;
const int64_t n_embd_head_k;
@ -539,31 +397,31 @@ struct llm_graph_context {
const float norm_eps;
const float norm_rms_eps;
const int64_t n_tokens;
const int64_t n_outputs;
const int32_t n_tokens;
const int32_t n_outputs;
const int32_t n_ctx_orig; // yarn
const enum llama_pooling_type pooling_type;
const enum llama_rope_type rope_type;
ggml_context * ctx0 = nullptr;
ggml_backend_sched_t sched;
ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
const llama_adapter_cvec * cvec;
const llama_adapter_loras * loras;
const llama_memory_context_i * mctx;
const llama_cross * cross;
const llama_adapter_cvec * cvec;
const llama_adapter_loras * loras;
const llama_memory_i * memory;
const llama_cross * cross;
const llm_graph_cb & cb_func;
llm_graph_result * res;
ggml_context * ctx0 = nullptr;
ggml_cgraph * gf = nullptr;
std::unique_ptr<llm_graph_result> res;
llm_graph_context(const llm_graph_params & params);
virtual ~llm_graph_context() = default;
int64_t n_pos_per_embd() const;
void cb(ggml_tensor * cur, const char * name, int il) const;
@ -635,6 +493,8 @@ struct llm_graph_context {
ggml_tensor * build_inp_out_ids() const;
ggml_tensor * build_inp_mean() const;
ggml_tensor * build_inp_cls() const;
ggml_tensor * build_inp_s_copy() const;
ggml_tensor * build_inp_s_mask() const;
ggml_tensor * build_inp_cross_embd() const;
ggml_tensor * build_inp_pos_bucket_enc() const;
@ -646,18 +506,21 @@ struct llm_graph_context {
//
ggml_tensor * build_attn_mha(
ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens]
ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens]
ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
ggml_cgraph * gf,
ggml_tensor * q, // [n_embd_head_q, n_tokens, n_head_q]
ggml_tensor * k, // [n_embd_head_k, n_tokens, n_head_k]
ggml_tensor * v, // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false)
ggml_tensor * kq_b,
ggml_tensor * kq_mask,
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
bool v_trans,
float kq_scale) const;
llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
ggml_tensor * build_attn(
llm_graph_input_attn_no_cache * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@ -672,6 +535,7 @@ struct llm_graph_context {
ggml_tensor * build_attn(
llm_graph_input_attn_kv_unified * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@ -682,25 +546,11 @@ struct llm_graph_context {
float kq_scale,
int il) const;
llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
// note: if k_cur or v_cur are not provided, they will not be stored in the memory
ggml_tensor * build_attn(
llm_graph_input_attn_kv_unified_iswa * inp,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
ggml_tensor * kq_b,
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
float kq_scale,
int il) const;
llm_graph_input_attn_cross * build_attn_inp_cross() const;
ggml_tensor * build_attn(
llm_graph_input_attn_cross * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@ -715,57 +565,34 @@ struct llm_graph_context {
// recurrent
//
// TODO: avoid notion of "kv"
// TODO: move this implementation to llama_memory_recurrent.
// this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
// when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
// implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
// `llama_memory_recurrent`
ggml_tensor * build_rs(
ggml_tensor * s,
ggml_tensor * state_copy,
int32_t state_size,
int32_t n_seqs,
uint32_t n_kv,
uint32_t kv_head,
uint32_t kv_size,
int32_t rs_zero,
const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
llm_graph_input_rs * build_rs_inp() const;
ggml_tensor * build_rs(
llm_graph_input_rs * inp,
ggml_tensor * s,
int32_t state_size,
int32_t n_seqs,
const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
ggml_tensor * build_copy_mask_state(
ggml_cgraph * gf,
ggml_tensor * s,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
int32_t n_state,
int32_t n_seqs) const;
ggml_tensor * build_rwkv_token_shift_load(
llm_graph_input_rs * inp,
const llama_ubatch & ubatch,
int il) const;
ggml_cgraph * gf,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il) const;
ggml_tensor * build_rwkv_token_shift_store(
ggml_tensor * token_shift,
const llama_ubatch & ubatch,
int il) const;
//
// hybrid
//
llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
//
// pooling
//
void build_pooling(
ggml_cgraph * gf,
ggml_tensor * cls,
ggml_tensor * cls_b,
ggml_tensor * cls_out,
ggml_tensor * cls_out_b) const;
};
// TODO: better name
int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional);

View File

@ -2,22 +2,6 @@
#include "ggml.h"
void llama_hparams::set_swa_pattern(uint32_t n_pattern) {
for (uint32_t il = 0; il < n_layer; ++il) {
swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1));
}
}
bool llama_hparams::is_swa_any() const {
for (uint32_t il = 0; il < n_layer; ++il) {
if (swa_layers[il]) {
return true;
}
}
return false;
}
uint32_t llama_hparams::n_head(uint32_t il) const {
if (il < n_layer) {
return n_head_arr[il];
@ -65,64 +49,18 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
return n_embd_head_v * n_head_kv;
}
bool llama_hparams::is_n_embd_k_gqa_variable() const {
const uint32_t val = n_embd_k_gqa();
for (uint32_t il = 0; il < n_layer; ++il) {
if (val != n_embd_k_gqa(il)) {
return true;
}
}
return false;
}
bool llama_hparams::is_n_embd_v_gqa_variable() const {
const uint32_t val = n_embd_v_gqa();
for (uint32_t il = 0; il < n_layer; ++il) {
if (val != n_embd_v_gqa(il)) {
return true;
}
}
return false;
}
uint32_t llama_hparams::n_embd_k_gqa_max() const {
uint32_t val = n_embd_k_gqa();
for (uint32_t il = 0; il < n_layer; ++il) {
val = std::max(val, n_embd_k_gqa(il));
}
return val;
}
uint32_t llama_hparams::n_embd_v_gqa_max() const {
uint32_t val = n_embd_v_gqa();
for (uint32_t il = 0; il < n_layer; ++il) {
val = std::max(val, n_embd_v_gqa(il));
}
return val;
}
uint32_t llama_hparams::n_embd_r() const {
uint32_t llama_hparams::n_embd_k_s() const {
if (wkv_head_size != 0) {
// for RWKV models
return token_shift_count * n_embd;
}
if (n_shortconv_l_cache != 0) {
// for LFM2 models
return n_embd * (n_shortconv_l_cache - 1);
}
// TODO: maybe support other convolution strides than 1
// NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
// Corresponds to Mamba's conv_states size
return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state);
return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
}
uint32_t llama_hparams::n_embd_s() const {
uint32_t llama_hparams::n_embd_v_s() const {
if (wkv_head_size != 0) {
// corresponds to RWKV's wkv_states size
return n_embd * wkv_head_size;
@ -132,17 +70,9 @@ uint32_t llama_hparams::n_embd_s() const {
return ssm_d_state * ssm_d_inner;
}
bool llama_hparams::is_recurrent(uint32_t il) const {
return recurrent_layer_arr[il];
}
uint32_t llama_hparams::n_pos_per_embd() const {
return rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1;
}
bool llama_hparams::is_swa(uint32_t il) const {
if (il < n_layer) {
return swa_layers[il];
return n_swa > 0 && n_swa_pattern > 0 && il % n_swa_pattern < (n_swa_pattern - 1);
}
GGML_ABORT("fatal error");

View File

@ -6,7 +6,7 @@
// bump if necessary
#define LLAMA_MAX_LAYERS 512
#define LLAMA_MAX_EXPERTS 384 // Kimi-K2
#define LLAMA_MAX_EXPERTS 256 // DeepSeekV3
enum llama_expert_gating_func_type {
LLAMA_EXPERT_GATING_FUNC_TYPE_NONE = 0,
@ -14,12 +14,6 @@ enum llama_expert_gating_func_type {
LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2,
};
enum llama_swa_type {
LLAMA_SWA_TYPE_NONE = 0,
LLAMA_SWA_TYPE_STANDARD = 1,
LLAMA_SWA_TYPE_CHUNKED = 2,
};
struct llama_hparams_posnet {
uint32_t n_embd;
uint32_t n_layer;
@ -41,6 +35,8 @@ struct llama_hparams {
uint32_t n_embd_features = 0;
uint32_t n_layer;
uint32_t n_rot;
uint32_t n_swa = 0; // sliding window attention (SWA)
uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention
uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
uint32_t n_expert = 0;
@ -55,8 +51,6 @@ struct llama_hparams {
struct llama_hparams_posnet posnet;
struct llama_hparams_convnext convnext;
uint32_t n_shortconv_l_cache = 0;
std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_arr;
std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_kv_arr;
std::array<uint32_t, LLAMA_MAX_LAYERS> n_ff_arr;
@ -98,28 +92,15 @@ struct llama_hparams {
float rope_freq_scale_train;
float rope_freq_scale_train_swa;
uint32_t n_ctx_orig_yarn;
float rope_yarn_log_mul = 0.0f;
float rope_yarn_log_mul;
std::array<int, 4> rope_sections;
// Sliding Window Attention (SWA)
llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
// the size of the sliding window (0 - no SWA)
uint32_t n_swa = 0;
// if swa_layers[il] == true, then layer il is SWA
// if swa_layers[il] == false, then layer il is dense (i.e. non-SWA)
// by default, all layers are dense
std::array<bool, LLAMA_MAX_LAYERS> swa_layers;
// for State Space Models
uint32_t ssm_d_conv = 0;
uint32_t ssm_d_inner = 0;
uint32_t ssm_d_state = 0;
uint32_t ssm_dt_rank = 0;
uint32_t ssm_n_group = 0;
// for hybrid state space models
std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;
bool ssm_dt_b_c_rms = false;
@ -135,23 +116,15 @@ struct llama_hparams {
bool causal_attn = true;
bool use_alibi = false;
bool attn_soft_cap = false;
bool use_kq_norm = true;
// for Classifiers
uint32_t n_cls_out = 1;
// llama4
uint32_t n_moe_layer_step = 0;
bool use_kq_norm = true;
uint32_t n_attn_chunk = 0;
// values below seems to be fixed on llama4
uint32_t n_no_rope_layer_step = 4;
uint32_t n_attn_temp_floor_scale = 8192;
float f_attn_temp_scale = 0.1;
// gemma3n altup
uint32_t n_altup = 4; // altup_num_inputs
uint32_t i_altup_act = 0; // altup_active_idx
uint32_t laurel_rank = 64;
uint32_t n_embd_altup = 256;
// needed by encoder-decoder models (e.g. T5, FLAN-T5)
// ref: https://github.com/ggerganov/llama.cpp/pull/8141
llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
@ -160,23 +133,6 @@ struct llama_hparams {
enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
// this value n_pattern means that every nth layer is dense (i.e. non-SWA)
// note that if n_pattern == 0, all layers are SWA
// if n_pattern == 1, all layers are dense
// example: n_pattern = 3
// il == 0: swa
// il == 1: swa
// il == 2: dense
// il == 3: swa
// il == 4: swa
// il == 5: dense
// il == 6: swa
// etc ...
void set_swa_pattern(uint32_t n_pattern);
// return true if one of the layers is SWA
bool is_swa_any() const;
uint32_t n_head(uint32_t il = 0) const;
uint32_t n_head_kv(uint32_t il = 0) const;
@ -191,25 +147,12 @@ struct llama_hparams {
// dimension of value embeddings across all k-v heads
uint32_t n_embd_v_gqa(uint32_t il = 0) const;
// true if any layer has a different n_embd_k_gqa/n_embd_v_gqa
bool is_n_embd_k_gqa_variable() const;
bool is_n_embd_v_gqa_variable() const;
// return the maximum n_embd_k_gqa/n_embd_v_gqa across all layers
uint32_t n_embd_k_gqa_max() const;
uint32_t n_embd_v_gqa_max() const;
// dimension of the rolling state embeddings
// corresponds to Mamba's conv_states size or RWKV's token_shift states size
uint32_t n_embd_r() const;
uint32_t n_embd_k_s() const;
// dimension of the recurrent state embeddings
uint32_t n_embd_s() const;
// whether or not the given layer is recurrent (for hybrid models)
bool is_recurrent(uint32_t il) const;
uint32_t n_pos_per_embd() const;
uint32_t n_embd_v_s() const;
bool is_swa(uint32_t il) const;
};

View File

@ -1,295 +0,0 @@
#include "llama-kv-cache-unified-iswa.h"
#include "llama-impl.h"
#include "llama-batch.h"
#include "llama-model.h"
#include <algorithm>
#include <cassert>
//
// llama_kv_cache_unified_iswa
//
llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
const llama_model & model,
ggml_type type_k,
ggml_type type_v,
bool v_trans,
bool offload,
bool swa_full,
bool unified,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_ubatch,
uint32_t n_pad) : hparams(model.hparams), unified(unified) {
llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
const uint32_t size_base = kv_size;
uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch, n_pad));
// when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
if (swa_full) {
LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n",
__func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
size_swa = size_base;
}
LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
kv_base = std::make_unique<llama_kv_cache_unified>(
model, std::move(filter_base), type_k, type_v,
v_trans, offload, unified, size_base, n_seq_max, n_pad,
0, LLAMA_SWA_TYPE_NONE);
LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
kv_swa = std::make_unique<llama_kv_cache_unified>(
model, std::move(filter_swa), type_k, type_v,
v_trans, offload, unified, size_swa, n_seq_max, n_pad,
hparams.n_swa, hparams.swa_type);
}
void llama_kv_cache_unified_iswa::clear(bool data) {
kv_base->clear(data);
kv_swa ->clear(data);
}
bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
bool res = true;
res = res & kv_base->seq_rm(seq_id, p0, p1);
res = res & kv_swa ->seq_rm(seq_id, p0, p1);
return res;
}
void llama_kv_cache_unified_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1);
kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
}
void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) {
kv_base->seq_keep(seq_id);
kv_swa ->seq_keep(seq_id);
}
void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
kv_base->seq_add(seq_id, p0, p1, shift);
kv_swa ->seq_add(seq_id, p0, p1, shift);
}
void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
kv_base->seq_div(seq_id, p0, p1, d);
kv_swa ->seq_div(seq_id, p0, p1, d);
}
llama_pos llama_kv_cache_unified_iswa::seq_pos_min(llama_seq_id seq_id) const {
// the base cache is a superset of the SWA cache, so we can just check the SWA cache
return kv_swa->seq_pos_min(seq_id);
}
llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
return kv_swa->seq_pos_max(seq_id);
}
llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
GGML_UNUSED(embd_all);
// first try simple split
do {
if (!unified) {
// requires equal splits, so we skip the simple split
break;
}
balloc.split_reset();
std::vector<llama_ubatch> ubatches;
while (true) {
auto ubatch = balloc.split_simple(n_ubatch);
if (ubatch.n_tokens == 0) {
break;
}
ubatches.push_back(std::move(ubatch)); // NOLINT
}
if (balloc.get_n_used() < balloc.get_n_tokens()) {
// failed to find a suitable split
break;
}
auto sinfos_base = kv_base->prepare(ubatches);
if (sinfos_base.empty()) {
break;
}
auto sinfos_swa = kv_swa->prepare(ubatches);
if (sinfos_swa.empty()) {
break;
}
assert(sinfos_base.size() == sinfos_swa.size());
return std::make_unique<llama_kv_cache_unified_iswa_context>(
this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
} while (false);
// if it fails, try equal split
do {
balloc.split_reset();
std::vector<llama_ubatch> ubatches;
while (true) {
auto ubatch = balloc.split_equal(n_ubatch, !unified);
if (ubatch.n_tokens == 0) {
break;
}
ubatches.push_back(std::move(ubatch)); // NOLINT
}
if (balloc.get_n_used() < balloc.get_n_tokens()) {
// failed to find a suitable split
break;
}
auto sinfos_base = kv_base->prepare(ubatches);
if (sinfos_base.empty()) {
break;
}
auto sinfos_swa = kv_swa->prepare(ubatches);
if (sinfos_swa.empty()) {
break;
}
assert(sinfos_base.size() == sinfos_swa.size());
return std::make_unique<llama_kv_cache_unified_iswa_context>(
this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
} while (false);
// TODO: if we fail again, we should attempt different splitting strategies
// but to do that properly, we first have to refactor the batches to be more flexible
return std::make_unique<llama_kv_cache_unified_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}
llama_memory_context_ptr llama_kv_cache_unified_iswa::init_full() {
return std::make_unique<llama_kv_cache_unified_iswa_context>(this);
}
llama_memory_context_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
return std::make_unique<llama_kv_cache_unified_iswa_context>(this, lctx, optimize);
}
bool llama_kv_cache_unified_iswa::get_can_shift() const {
return kv_base->get_size() == kv_swa->get_size();
}
void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
kv_base->state_write(io, seq_id);
kv_swa ->state_write(io, seq_id);
}
void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
kv_base->state_read(io, seq_id);
kv_swa ->state_read(io, seq_id);
}
llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const {
return kv_base.get();
}
llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
return kv_swa.get();
}
//
// llama_kv_cache_unified_iswa_context
//
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(llama_memory_status status) : status(status) {}
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv) :
ctx_base(kv->get_base()->init_full()),
ctx_swa (kv->get_swa ()->init_full()),
status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
}
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv,
llama_context * lctx,
bool optimize) :
ctx_base(kv->get_base()->init_update(lctx, optimize)),
ctx_swa (kv->get_swa ()->init_update(lctx, optimize)),
status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
}
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv,
slot_info_vec_t sinfos_base,
slot_info_vec_t sinfos_swa,
std::vector<llama_ubatch> ubatches) :
ubatches(std::move(ubatches)),
// note: here we copy the ubatches. not sure if this is ideal
ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)),
status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
}
llama_kv_cache_unified_iswa_context:: ~llama_kv_cache_unified_iswa_context() = default;
bool llama_kv_cache_unified_iswa_context::next() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
ctx_base->next();
ctx_swa ->next();
if (++i_next >= ubatches.size()) {
return false;
}
return true;
}
bool llama_kv_cache_unified_iswa_context::apply() {
assert(!llama_memory_status_is_fail(status));
bool res = true;
res = res & ctx_base->apply();
res = res & ctx_swa ->apply();
return res;
}
llama_memory_status llama_kv_cache_unified_iswa_context::get_status() const {
return status;
}
const llama_ubatch & llama_kv_cache_unified_iswa_context::get_ubatch() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return ubatches[i_next];
}
const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_base() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return static_cast<const llama_kv_cache_unified_context *>(ctx_base.get());
}
const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_swa() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return static_cast<const llama_kv_cache_unified_context *>(ctx_swa.get());
}

View File

@ -1,133 +0,0 @@
#pragma once
#include "llama-kv-cache-unified.h"
#include <vector>
//
// llama_kv_cache_unified_iswa
//
// utilizes two instances of llama_kv_cache_unified
// the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
class llama_kv_cache_unified_iswa : public llama_memory_i {
public:
llama_kv_cache_unified_iswa(
const llama_model & model,
ggml_type type_k,
ggml_type type_v,
bool v_trans,
bool offload,
bool swa_full,
bool unified,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_ubatch,
uint32_t n_pad);
~llama_kv_cache_unified_iswa() = default;
//
// llama_memory_i
//
llama_memory_context_ptr init_batch(
llama_batch_allocr & balloc,
uint32_t n_ubatch,
bool embd_all) override;
llama_memory_context_ptr init_full() override;
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
bool get_can_shift() const override;
void clear(bool data) override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id) override;
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
// state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
//
// llama_kv_cache_unified_iswa specific API
//
llama_kv_cache_unified * get_base() const;
llama_kv_cache_unified * get_swa () const;
private:
const llama_hparams & hparams;
const bool unified;
std::unique_ptr<llama_kv_cache_unified> kv_base;
std::unique_ptr<llama_kv_cache_unified> kv_swa;
};
class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
public:
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
// used for errors
llama_kv_cache_unified_iswa_context(llama_memory_status status);
// used to create a full-cache context
llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv);
// used to create an update context
llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv,
llama_context * lctx,
bool optimize);
// used to create a batch processing context from a batch
llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv,
slot_info_vec_t sinfos_base,
slot_info_vec_t sinfos_swa,
std::vector<llama_ubatch> ubatches);
virtual ~llama_kv_cache_unified_iswa_context();
//
// llama_memory_context_i
//
bool next() override;
bool apply() override;
llama_memory_status get_status() const override;
const llama_ubatch & get_ubatch() const override;
//
// llama_kv_cache_unified_iswa_context specific API
//
const llama_kv_cache_unified_context * get_base() const;
const llama_kv_cache_unified_context * get_swa() const;
private:
//llama_kv_cache_unified_iswa * kv;
// the index of the next ubatch to process
size_t i_next = 0;
std::vector<llama_ubatch> ubatches;
const llama_memory_context_ptr ctx_base;
const llama_memory_context_ptr ctx_swa;
const llama_memory_status status;
};

File diff suppressed because it is too large Load Diff

View File

@ -1,399 +0,0 @@
#pragma once
#include "llama-batch.h"
#include "llama-graph.h"
#include "llama-kv-cells.h"
#include "llama-memory.h"
#include <unordered_map>
#include <vector>
struct llama_cparams;
struct llama_hparams;
struct llama_model;
struct llama_context;
//
// llama_kv_cache_unified
//
class llama_kv_cache_unified : public llama_memory_i {
public:
static uint32_t get_padding(const llama_cparams & cparams);
// this callback is used to filter out layers that should not be included in the cache
using layer_filter_cb = std::function<bool(int32_t il)>;
struct defrag_info {
bool empty() const {
return ids.empty();
}
// contains information about which cell moves where:
// - cell i moves to ids[i]
// - if ids[i] == i || ids[i] == ids.size(), then cell i is not moved
std::vector<uint32_t> ids;
};
struct stream_copy_info {
bool empty() const {
assert(ssrc.size() == sdst.size());
return ssrc.empty();
}
std::vector<uint32_t> ssrc;
std::vector<uint32_t> sdst;
};
// for each ubatch, create a slot_info that contains information about where the ubatch should be inserted in the
// KV cells. for example, cell indices for each token, such that: token[i] -> goes to cells[idxs[i]]
struct slot_info {
// data for ggml_set_rows
using idx_vec_t = std::vector<uint32_t>;
// number of streams: ns = s1 - s0 + 1
llama_seq_id s0;
llama_seq_id s1;
std::vector<llama_seq_id> strm; // [ns]
std::vector<idx_vec_t> idxs; // [ns]
uint32_t head() const {
GGML_ASSERT(idxs.size() == 1);
GGML_ASSERT(!idxs[0].empty());
return idxs[0][0];
}
void resize(size_t n) {
strm.resize(n);
idxs.resize(n);
}
size_t size() const {
GGML_ASSERT(idxs.size() == strm.size());
GGML_ASSERT(!idxs.empty());
return idxs[0].size();
}
size_t n_stream() const {
return strm.size();
}
bool empty() const {
return idxs.empty();
}
void clear() {
idxs.clear();
}
};
using slot_info_vec_t = std::vector<slot_info>;
llama_kv_cache_unified(
const llama_model & model,
layer_filter_cb && filter,
ggml_type type_k,
ggml_type type_v,
bool v_trans,
bool offload,
bool unified,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type);
~llama_kv_cache_unified() = default;
//
// llama_memory_i
//
llama_memory_context_ptr init_batch(
llama_batch_allocr & balloc,
uint32_t n_ubatch,
bool embd_all) override;
llama_memory_context_ptr init_full() override;
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
bool get_can_shift() const override;
void clear(bool data) override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id) override;
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
// state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
//
// llama_kv_cache_unified specific API
//
uint32_t get_size() const;
uint32_t get_n_stream() const;
bool get_has_shift() const;
//
// graph_build API
//
uint32_t get_n_kv() const;
// TODO: temporary
bool get_supports_set_rows() const;
// get views of the current state of the cache
ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
// store k_cur and v_cur in the cache based on the provided head location
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const;
//
// preparation API
//
// find places for the provided ubatches in the cache, returns the slot infos
// return empty vector on failure
slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);
bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info);
// find a slot of kv cells that can hold the ubatch
// if cont == true, then the slot must be continuous
// return empty slot_info on failure
slot_info find_slot(const llama_ubatch & ubatch, bool cont) const;
// emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]]
void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch);
//
// input API
//
ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
void set_input_k_shift(ggml_tensor * dst) const;
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
private:
const llama_model & model;
const llama_hparams & hparams;
struct kv_layer {
// layer index in the model
// note: can be different from the layer index in the KV cache
uint32_t il;
ggml_tensor * k;
ggml_tensor * v;
std::vector<ggml_tensor *> k_stream;
std::vector<ggml_tensor *> v_stream;
};
bool v_trans = true; // the value tensor is transposed
const uint32_t n_seq_max = 1;
const uint32_t n_stream = 1;
// required padding
const uint32_t n_pad = 1;
// SWA
const uint32_t n_swa = 0;
// env: LLAMA_KV_CACHE_DEBUG
int debug = 0;
// env: LLAMA_SET_ROWS (temporary)
// ref: https://github.com/ggml-org/llama.cpp/pull/14285
bool supports_set_rows = false;
const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs;
// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
std::vector<uint32_t> v_heads;
std::vector<llama_kv_cells_unified> v_cells;
// maps from a sequence id to a stream id
std::vector<uint32_t> seq_to_stream;
// pending stream copies that will be applied during the next update
stream_copy_info sc_info;
std::vector<kv_layer> layers;
// model layer id -> KV cache layer id
std::unordered_map<int32_t, int32_t> map_layer_ids;
// return non-empty vector if cells have been moved
defrag_info defrag_prepare(int32_t n_max_nodes) const;
size_t total_size() const;
size_t size_k_bytes() const;
size_t size_v_bytes() const;
bool is_masked_swa(llama_pos p0, llama_pos p1) const;
ggml_tensor * build_rope_shift(
const llama_cparams & cparams,
ggml_context * ctx,
ggml_tensor * cur,
ggml_tensor * shift,
ggml_tensor * factors,
float freq_base,
float freq_scale) const;
ggml_cgraph * build_graph_shift(
llm_graph_result * res,
llama_context * lctx) const;
ggml_cgraph * build_graph_defrag(
llm_graph_result * res,
llama_context * lctx,
const defrag_info & dinfo) const;
struct cell_ranges_t {
uint32_t strm;
std::vector<std::pair<uint32_t, uint32_t>> data; // ranges, from inclusive, to exclusive
};
void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const;
void state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const;
bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count);
};
class llama_kv_cache_unified_context : public llama_memory_context_i {
public:
// some shorthands
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
using defrag_info = llama_kv_cache_unified::defrag_info;
using stream_copy_info = llama_kv_cache_unified::stream_copy_info;
// used for errors
llama_kv_cache_unified_context(llama_memory_status status);
// used to create a full-cache context
llama_kv_cache_unified_context(
llama_kv_cache_unified * kv);
// used to create an update context
llama_kv_cache_unified_context(
llama_kv_cache_unified * kv,
llama_context * lctx,
bool do_shift,
defrag_info dinfo,
stream_copy_info sc_info);
// used to create a batch procesing context from a batch
llama_kv_cache_unified_context(
llama_kv_cache_unified * kv,
slot_info_vec_t sinfos,
std::vector<llama_ubatch> ubatches);
virtual ~llama_kv_cache_unified_context();
//
// llama_memory_context_i
//
bool next() override;
bool apply() override;
llama_memory_status get_status() const override;
const llama_ubatch & get_ubatch() const override;
//
// llama_kv_cache_unified_context specific API
//
uint32_t get_n_kv() const;
// TODO: temporary
bool get_supports_set_rows() const;
// get views of the current state of the cache
ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
// store k_cur and v_cur in the cache based on the provided head location
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const;
ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
void set_input_k_shift (ggml_tensor * dst) const;
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
private:
llama_memory_status status;
llama_kv_cache_unified * kv;
llama_context * lctx;
//
// update context
//
bool do_shift = false;
defrag_info dinfo;
stream_copy_info sc_info;
//
// batch processing context
//
// the index of the cur ubatch to process
size_t i_cur = 0;
slot_info_vec_t sinfos;
std::vector<llama_ubatch> ubatches;
//
// data needed for building the compute graph for the current ubatch:
//
// a heuristic, to avoid attending the full cache if it is not yet utilized
// as the cache gets filled, the benefit from this heuristic disappears
int32_t n_kv;
};

File diff suppressed because it is too large Load Diff

View File

@ -2,36 +2,57 @@
#include "llama.h"
#include "llama-io.h"
#include "llama-graph.h"
#include "llama-memory.h"
#include "ggml-cpp.h"
#include <set>
#include <vector>
struct llama_cparams;
struct llama_hparams;
struct llama_ubatch;
struct llama_sbatch;
struct llama_model;
struct llama_context;
struct llama_kv_cache : public llama_memory_i {
virtual ~llama_kv_cache() = default;
// split the input batch into a set of ubatches and verify that they can fit into the cache
// return a state object containing the ubatches and KV cache state required to process them
// check the llama_memory_state_i::get_status() for the result
virtual llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled,
bool logits_all) = 0;
// call if batch processing fails - restores the cache state
virtual void restore() = 0;
// simulate full cache, used for allocating worst-case compute buffers
virtual llama_memory_state_ptr init_full() = 0;
// call after successful batch processing - clears any pending state
virtual void commit() = 0;
// process any pending defrag/shift/etc. operations
// optionally call once before processing a new batch
// return true if any operations were performed
virtual bool update(llama_context & lctx) = 0;
// schedule a defrag if the fragmentation threshold is exceeded. otherwise, do nothing
// TODO: change to
// llama_memory_state_ptr init_defrag(float thold) = 0;
//
virtual void defrag_sched(float thold) = 0;
// simulate full cache, used for allocating worst-case compute buffers
virtual void set_full() = 0;
//
// batch processing
//
virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0;
// different KV caches require different batch splitting strategies
virtual llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const = 0;
// find an empty slot of size "n_tokens" in the cache
virtual bool find_slot(const llama_ubatch & batch) = 0;
// getters
virtual bool get_can_shift() const = 0;
virtual int32_t get_n_tokens() const = 0;
virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
virtual llama_pos get_pos_max() const = 0;
virtual bool get_can_shift() const = 0;
bool get_can_edit() const override { return get_can_shift(); }
@ -42,3 +63,343 @@ struct llama_kv_cache : public llama_memory_i {
virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0;
};
//
// llama_kv_cache_guard
//
struct llama_kv_cache_guard {
llama_kv_cache_guard(llama_kv_cache * kv) : kv(kv) {}
~llama_kv_cache_guard() {
kv->restore();
}
void commit() {
kv->commit();
}
private:
llama_kv_cache * kv;
};
//
// llama_kv_cache_unified
//
// TODO: add notion of max sequences
class llama_kv_cache_unified : public llama_kv_cache {
public:
struct kv_cell {
llama_pos pos = -1;
llama_pos delta = 0;
std::set<llama_seq_id> seq_id;
bool has_seq_id(const llama_seq_id & id) const {
return seq_id.find(id) != seq_id.end();
}
bool is_empty() const {
return seq_id.empty();
}
bool is_same_seq(const kv_cell & other) const {
return seq_id == other.seq_id;
}
};
static uint32_t get_padding(const llama_cparams & cparams);
llama_kv_cache_unified(
const llama_model & model,
ggml_type type_k,
ggml_type type_v,
bool v_trans,
bool offload,
uint32_t kv_size,
uint32_t padding);
~llama_kv_cache_unified() = default;
//
// llama_memory_i
//
void clear() override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id) override;
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
//
// llama_kv_cache
//
void restore() override;
void commit() override;
bool update(llama_context & ctx) override;
void defrag_sched(float thold) override;
void set_full() override;
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
// updates the cache head
// Note: On success, it's important that cache.head points
// to the first cell of the slot.
bool find_slot(const llama_ubatch & batch) override;
int32_t get_n_tokens() const override;
int32_t get_used_cells() const override;
// TODO: better data structures to reduce the cost of this operation
llama_pos get_pos_max() const override;
bool get_can_shift() const override;
// state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
// Note: The value of head isn't only used to optimize searching
// for a free KV slot. llama_decode_impl also uses it, so it
// cannot be freely changed after a slot has been allocated.
uint32_t head = 0;
uint32_t size = 0;
uint32_t used = 0; // used cells (i.e. at least one seq_id)
// computed before each graph build
uint32_t n = 0;
std::vector<kv_cell> cells;
std::vector<ggml_tensor *> k_l; // per layer
std::vector<ggml_tensor *> v_l;
private:
const llama_model & model;
const llama_hparams & hparams;
bool has_shift = false;
bool do_defrag = false;
bool v_trans = true; // the value tensor is transposed
bool can_shift = false;
// required padding
uint32_t padding = 1;
ggml_type type_k = GGML_TYPE_F16;
ggml_type type_v = GGML_TYPE_F16;
std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs;
// defrag
struct {
std::vector<uint32_t> ids;
} defrag_info;
// return true if cells have been moved
bool defrag_prepare(int32_t n_max_nodes);
// commit/restore cache
struct slot_range {
uint32_t c0 = 0; // note: these are cell indices, not sequence positions
uint32_t c1 = 0;
};
// pending cell updates that are not yet committed
struct {
std::vector<slot_range> ranges;
} pending;
// find how many cells are currently in use
uint32_t cell_max() const;
size_t total_size() const;
size_t size_k_bytes() const;
size_t size_v_bytes() const;
ggml_tensor * build_rope_shift(
const llama_cparams & cparams,
ggml_context * ctx,
ggml_tensor * cur,
ggml_tensor * shift,
ggml_tensor * factors,
float freq_base,
float freq_scale) const;
llm_graph_result_ptr build_graph_shift(
const llama_cparams & cparams,
ggml_context * ctx,
ggml_cgraph * gf) const;
llm_graph_result_ptr build_graph_defrag(
const llama_cparams & cparams,
ggml_context * ctx,
ggml_cgraph * gf) const;
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
};
//
// llama_kv_cache_recurrent
//
class llama_kv_cache_recurrent : public llama_kv_cache {
public:
struct kv_cell {
llama_pos pos = -1;
int32_t src = -1; // used to copy states
int32_t tail = -1;
std::set<llama_seq_id> seq_id;
bool has_seq_id(const llama_seq_id & id) const {
return seq_id.find(id) != seq_id.end();
}
bool is_empty() const {
return seq_id.empty();
}
bool is_same_seq(const kv_cell & other) const {
return seq_id == other.seq_id;
}
};
llama_kv_cache_recurrent(
const llama_model & model,
ggml_type type_k,
ggml_type type_v,
bool offload,
uint32_t kv_size);
~llama_kv_cache_recurrent() = default;
//
// llama_memory_i
//
void clear() override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id) override;
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
//
// llama_kv_cache
//
void restore() override;
void commit() override;
bool update(llama_context & lctx) override;
void defrag_sched(float thold) override;
void set_full() override;
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
bool find_slot(const llama_ubatch & batch) override;
int32_t get_n_tokens() const override;
int32_t get_used_cells() const override;
// TODO: better data structures to reduce the cost of this operation
llama_pos get_pos_max() const override;
bool get_can_shift() const override;
// TODO: temporary methods - they are not really const as they do const_cast<>, fix this
int32_t s_copy(int i) const;
float s_mask(int i) const;
// state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
// Note: The value of head isn't only used to optimize searching
// for a free KV slot. llama_decode_impl also uses it, so it
// cannot be freely changed after a slot has been allocated.
uint32_t head = 0;
uint32_t size = 0;
uint32_t used = 0; // used cells (i.e. at least one seq_id)
// computed before each graph build
uint32_t n = 0;
std::vector<kv_cell> cells;
std::vector<ggml_tensor *> k_l; // per layer
std::vector<ggml_tensor *> v_l;
private:
//const llama_model & model;
const llama_hparams & hparams;
// commit/restore cache
// TODO: rework for recurrent cache
struct slot_range {
uint32_t c0 = 0; // note: these are cell indices, not sequence positions
uint32_t c1 = 0;
};
// pending cell updates that are not yet committed
struct {
std::vector<slot_range> ranges;
} pending;
ggml_type type_k = GGML_TYPE_F16;
ggml_type type_v = GGML_TYPE_F16;
std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs;
// find how many cells are currently in use
uint32_t cell_max() const;
size_t total_size() const;
size_t size_k_bytes() const;
size_t size_v_bytes() const;
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
};
//
// kv cache view
//
llama_kv_cache_view llama_kv_cache_view_init(const llama_kv_cache & kv, int32_t n_seq_max);
void llama_kv_cache_view_update(llama_kv_cache_view * view, const llama_kv_cache * kv);

View File

@ -1,491 +0,0 @@
#pragma once
#include "llama.h"
#include "llama-cparams.h"
#include <bitset>
#include <cassert>
#include <vector>
#include <set>
#include <map>
// meta information about KV cells that can be part of multiple sequences at the same time
// TODO: add unit tests
class llama_kv_cells_unified {
public:
void reset() {
for (uint32_t i = 0; i < pos.size(); ++i) {
pos[i] = -1;
shift[i] = 0;
seq[i].reset();
}
has_shift = false;
used.clear();
for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
seq_pos[s].clear();
}
}
void reset_shift() {
has_shift = false;
for (uint32_t i = 0; i < shift.size(); ++i) {
shift[i] = 0;
}
}
uint32_t size() const {
return pos.size();
}
void resize(uint32_t n) {
pos.resize(n);
shift.resize(n);
seq.resize(n);
reset();
}
bool is_empty(uint32_t i) const {
assert(i < pos.size());
assert((pos[i] < 0 && pos[i] == -1) || pos[i] >= 0);
return pos[i] == -1;
}
uint32_t get_used() const {
return used.size();
}
// the index of the first cell that is used
// return 0 if no cells are used
uint32_t used_min() const {
return used.empty() ? 0 : *used.begin();
}
// the index of the last cell that is used + 1
// return 0 if no cells are used
uint32_t used_max_p1() const {
return used.empty() ? 0 : *used.rbegin() + 1;
}
bool get_has_shift() const {
return has_shift;
}
// move cell isrc to idst (used during defrag)
void mv(uint32_t isrc, uint32_t idst) {
assert(isrc < pos.size());
assert(idst < pos.size());
assert(pos[idst] == -1);
assert(pos[isrc] != -1);
pos [idst] = pos [isrc];
shift[idst] = shift[isrc];
seq [idst] = seq [isrc];
pos [isrc] = -1;
shift[isrc] = 0;
seq [isrc].reset();
used.erase (isrc);
used.insert(idst);
}
// copy the state of cells [i, i + n) (used for save/restore the state of the cells)
llama_kv_cells_unified cp(uint32_t i, uint32_t n) const {
assert(i + n <= pos.size());
llama_kv_cells_unified res;
res.resize(n);
for (uint32_t j = 0; j < n; ++j) {
const auto idx = i + j;
res.pos[j] = pos[idx];
res.seq[j] = seq[idx];
assert(shift[idx] == 0);
}
return res;
}
// copy the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
llama_kv_cells_unified cp(const std::vector<uint32_t> & idxs) const {
llama_kv_cells_unified res;
res.resize(idxs.size());
for (uint32_t j = 0; j < idxs.size(); ++j) {
const auto idx = idxs[j];
res.pos[j] = pos[idx];
res.seq[j] = seq[idx];
assert(shift[idx] == 0);
}
return res;
}
// set the state of cells [i, i + other.pos.size()) (used for save/restore the state of the cells)
void set(uint32_t i, const llama_kv_cells_unified & other) {
assert(i + other.pos.size() <= pos.size());
for (uint32_t j = 0; j < other.pos.size(); ++j) {
const auto idx = i + j;
if (pos[idx] == -1 && other.pos[j] != -1) {
used.insert(i + j);
}
if (pos[idx] != -1 && other.pos[j] == -1) {
used.erase(i + j);
}
if (pos[idx] != -1) {
seq_pos_rm(i + j);
}
pos[idx] = other.pos[j];
seq[idx] = other.seq[j];
if (pos[idx] != -1) {
seq_pos_add(i + j);
}
assert(shift[idx] == 0);
}
}
// set the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
void set(const std::vector<uint32_t> & idxs, const llama_kv_cells_unified & other) {
assert(idxs.size() == other.pos.size());
for (uint32_t j = 0; j < other.pos.size(); ++j) {
const auto idx = idxs[j];
if (pos[idx] == -1 && other.pos[j] != -1) {
used.insert(idx);
}
if (pos[idx] != -1 && other.pos[j] == -1) {
used.erase(idx);
}
if (pos[idx] != -1) {
seq_pos_rm(idx);
}
pos[idx] = other.pos[j];
seq[idx] = other.seq[j];
if (pos[idx] != -1) {
seq_pos_add(idx);
}
assert(shift[idx] == 0);
}
}
// clear a non-empty cell
void rm(uint32_t i) {
assert(i < pos.size());
assert(pos[i] != -1);
seq_pos_rm(i);
seq[i].reset();
pos[i] = -1;
shift[i] = 0;
used.erase(i);
}
// note: call only if the cell has seq_id
// return true if the cell becomes empty
bool seq_rm(uint32_t i, llama_seq_id seq_id) {
assert(i < pos.size());
assert(seq[i].test(seq_id));
assert(pos[i] != -1);
assert(seq_id >= 0);
seq[i].reset(seq_id);
seq_pos_dec(seq_id, pos[i]);
if (seq[i].none()) {
pos[i] = -1;
shift[i] = 0;
used.erase(i);
return true;
}
return false;
}
// return true if the cell becomes empty (i.e. it did not contain seq_id before the call)
bool seq_keep(uint32_t i, llama_seq_id seq_id) {
assert(i < pos.size());
if (seq[i].test(seq_id)) {
seq_pos_rm(i);
seq[i].reset();
seq[i].set(seq_id);
seq_pos_inc(seq_id, pos[i]);
return false;
}
if (seq[i].any()) {
seq_pos_rm(i);
seq[i].reset();
pos[i] = -1;
shift[i] = 0;
used.erase(i);
return true;
}
assert(pos[i] == -1);
return false;
}
// number of different sequences in the cell
int seq_count(uint32_t i) const {
assert(i < pos.size());
assert(pos[i] != -1);
return seq[i].count();
}
// check if the cell contains seq_id
bool seq_has(uint32_t i, llama_seq_id seq_id) const {
assert(i < pos.size());
assert(seq_id >= 0);
return seq[i].test(seq_id);
}
// note: call only if the cell is not empty and the seq_id is not in the cell
void seq_add(uint32_t i, llama_seq_id seq_id) {
assert(i < pos.size());
assert(pos[i] != -1);
assert(!seq[i].test(seq_id));
seq[i].set(seq_id);
seq_pos_inc(seq_id, pos[i]);
}
// return the sequence id of this cell
// note: call only for cells with exactly one sequence
llama_seq_id seq_get(uint32_t i) const {
assert(seq[i].count() == 1);
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
if (seq[i].test(s)) {
return s;
}
}
return -1;
}
// the minimum position of sequence seq_id currently present in any of the cells
// return -1 if the sequence is not present
llama_pos seq_pos_min(llama_seq_id seq_id) const {
assert(seq_id >= 0);
assert(seq_id < LLAMA_MAX_SEQ);
if (seq_pos[seq_id].empty()) {
return -1;
}
assert(seq_pos[seq_id].begin()->second > 0);
return seq_pos[seq_id].begin()->first;
}
// the maximum position of sequence seq_id currently present in any of the cells
// return -1 if the sequence is not present
llama_pos seq_pos_max(llama_seq_id seq_id) const {
assert(seq_id >= 0);
assert(seq_id < LLAMA_MAX_SEQ);
if (seq_pos[seq_id].empty()) {
return -1;
}
assert(seq_pos[seq_id].rbegin()->second > 0);
return seq_pos[seq_id].rbegin()->first;
}
// note: call only if the cell is not empty
llama_pos pos_get(uint32_t i) const {
assert(i < pos.size());
assert(pos[i] != -1);
return pos[i];
}
// note: call only if the cell is not empty
llama_pos get_shift(uint32_t i) const {
assert(i < pos.size());
assert(pos[i] != -1);
return shift[i];
}
// check if a cell is not empty and its position is within [p0, p1)
bool pos_in(uint32_t i, llama_pos p0, llama_pos p1) const {
assert(i < pos.size());
return pos[i] >= p0 && pos[i] < p1;
}
// set the position of an empty cell
// does not modify "has_shift"
// note: call only if the cell is empty
void pos_set(uint32_t i, llama_pos p) {
assert(i < pos.size());
assert(pos[i] == -1);
assert(seq[i].none());
pos[i] = p;
used.insert(i);
}
// pos[i] = pos[i] + d
// sets "has_shift" to true
// note: call only if the cell is not empty
bool pos_add(uint32_t i, llama_pos d) {
assert(i < pos.size());
assert(pos[i] != -1);
seq_pos_rm(i);
pos[i] += d;
shift[i] += d;
has_shift = true;
if (pos[i] < 0) {
seq[i].reset();
pos[i] = -1;
shift[i] = 0;
used.erase(i);
return true;
}
seq_pos_add(i);
return false;
}
// pos[i] = pos[i] / d
// sets "has_shift" to true
// note: call only if the cell is not empty
void pos_div(uint32_t i, int d) {
assert(i < pos.size());
assert(pos[i] != -1);
const llama_pos p_old = pos[i];
seq_pos_rm(i);
pos[i] /= d;
shift[i] += p_old - pos[i];
seq_pos_add(i);
has_shift = true;
}
private:
bool has_shift = false;
// set of indices of used cells (i.e. pos[i] != -1, allowed to not have any seq_id)
std::set<uint32_t> used;
std::vector<llama_pos> pos;
// this array accumulates any applied shifts to the pos array since the last reset_shift() call
// this is used to queue multiple updates to the pos array, which in the end can be applied in one go:
//
// cells.pos_add(x, shift_x);
// cells.pos_div(y, shift_y);
// ...
//
// if (cells.has_shift()) {
// for (int i = 0; i < n; ++i) {
// auto shift_i = cells.get_shift(i);
// ...
// }
// cells.reset_shift();
// }
//
std::vector<llama_pos> shift;
using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
// the bitset seq[i] tells us which sequences are currently occupying the i-th cell
std::vector<seq_set_t> seq;
// the set seq_pos[s][p] tells us how many times the position p is currently present for sequence s
// if the position p is not present, seq_pos[s][p] is not set
// this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
//
// note that we cannot a use an std::set because in some cases a position can occur more than once for the same seq:
// - during performing a cache reuse via (rm + add)
// - some vision models have input embeddings with repeating positions
//
std::map<llama_pos, int> seq_pos[LLAMA_MAX_SEQ];
// helper functions for updating `seq_pos`, once cell at a time:
void seq_pos_dec(llama_seq_id s, llama_pos p) {
auto it = seq_pos[s].find(p);
assert(it != seq_pos[s].end());
if (--it->second == 0) {
seq_pos[s].erase(it);
}
}
void seq_pos_inc(llama_seq_id s, llama_pos p) {
seq_pos[s][p]++;
}
// remove cell i
void seq_pos_rm(uint32_t i) {
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
if (seq[i].test(s)) {
seq_pos_dec(s, pos[i]);
}
}
}
// add cell i
void seq_pos_add(uint32_t i) {
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
if (seq[i].test(s)) {
seq_pos_inc(s, pos[i]);
}
}
}
};

Some files were not shown because too many files have changed in this diff Show More