Compare commits


3 Commits

e400aeb770  examples : add new sources (ggml-ci)  2025-04-02 15:52:29 +03:00

cb9a21b957  sync : ggml  2025-04-02 15:52:29 +03:00

dacb7caed6  cpu: move all the operators into a separate c++ file (except mul_mat) (ggml/1167)  2025-04-02 15:52:28 +03:00

    * cpu: refactor SIMD mappings and vectorized op functions into separate files
    * Fix warning for ggml_float to float
    * Fix warnings
    * cpu: move all the operations (except mul_mat) to a separate c++ file
    * fix whitespace
    * Update ggml/src/ggml-cpu/vec.h
      Co-authored-by: Diego Devesa <slarengh@gmail.com>
    * Fix PR comments - use GGML_UNUSED, use cassert in ops.cpp
    * Reverse the order of import for ops.h and vec.h, to match what was present in ggml-cpu.c previously

    Co-authored-by: Diego Devesa <slarengh@gmail.com>
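The dacb7caed6 message describes the shape of the refactor: vectorized helpers live in vec.h, operator implementations in ops.cpp, with ops.h included before vec.h and the ggml_float-to-float narrowing made explicit. A minimal self-contained sketch of that pattern follows; it is illustrative only, not the actual ggml sources, and the helper names here are hypothetical:

```cpp
// Sketch only (hypothetical names): the real files are ggml/src/ggml-cpu/vec.h
// and ggml/src/ggml-cpu/ops.cpp.
#include <cassert>

typedef double ggml_float; // ggml's wider accumulator type

// "vec.h" side: a vectorized primitive (scalar fallback shown; in ggml the
// SIMD mappings select per-architecture variants)
static void vec_sum_f32(const int n, float * s, const float * x) {
    ggml_float sum = 0.0;
    for (int i = 0; i < n; ++i) {
        sum += (ggml_float) x[i];
    }
    *s = (float) sum; // explicit cast: the ggml_float -> float warning fix
}

// "ops.cpp" side: an operator built on the helper; the unused parameter is
// marked (ggml spells this GGML_UNUSED(params))
static void op_sum_f32(const int n, float * dst, const float * src, void * params) {
    (void) params;
    assert(n >= 0); // PR feedback: use cassert in ops.cpp
    vec_sum_f32(n, dst, src);
}

int main() {
    const float x[4] = {1.0f, 2.0f, 3.0f, 4.0f};
    float s = 0.0f;
    op_sum_f32(4, &s, x, nullptr);
    assert(s == 10.0f);
    return 0;
}
```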
474 changed files with 47983 additions and 95377 deletions

.devops/main-cuda.Dockerfile

@@ -13,10 +13,11 @@ WORKDIR /app
 ARG CUDA_DOCKER_ARCH=all
 # Set nvcc architecture
 ENV CUDA_DOCKER_ARCH=${CUDA_DOCKER_ARCH}
+# Enable cuBLAS
+ENV GGML_CUDA=1
 
 RUN apt-get update && \
     apt-get install -y build-essential libsdl2-dev wget cmake git \
-    && apt-get clean \
     && rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
 
 # Ref: https://stackoverflow.com/a/53464012
@@ -24,14 +25,7 @@ ENV CUDA_MAIN_VERSION=12.3
 ENV LD_LIBRARY_PATH /usr/local/cuda-${CUDA_MAIN_VERSION}/compat:$LD_LIBRARY_PATH
 
 COPY .. .
-# Enable cuBLAS
-RUN make base.en CMAKE_ARGS="-DGGML_CUDA=1"
-
-RUN find /app/build -name "*.o" -delete && \
-    find /app/build -name "*.a" -delete && \
-    rm -rf /app/build/CMakeFiles && \
-    rm -rf /app/build/cmake_install.cmake && \
-    rm -rf /app/build/_deps
+RUN make base.en
 
 FROM ${BASE_CUDA_RUN_CONTAINER} AS runtime
 ENV CUDA_MAIN_VERSION=12.3
@@ -40,11 +34,7 @@ WORKDIR /app
 RUN apt-get update && \
     apt-get install -y curl ffmpeg wget cmake git \
-    && apt-get clean \
     && rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
 
 COPY --from=build /app /app
-RUN du -sh /app/*
-RUN find /app -type f -size +100M
-ENV PATH=/app/build/bin:$PATH
 
 ENTRYPOINT [ "bash", "-c" ]

.devops/main-intel.Dockerfile (deleted)

@@ -1,28 +0,0 @@
-ARG ONEAPI_VERSION=2025.1.1-0-devel-ubuntu24.04
-FROM intel/oneapi-basekit:$ONEAPI_VERSION AS build
-
-WORKDIR /app
-
-RUN apt-get update && \
-    apt-get install -y build-essential libsdl2-dev wget cmake git \
-    && rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
-
-COPY .. .
-# Enable SYCL
-ARG GGML_SYCL_F16=OFF
-RUN if [ "${GGML_SYCL_F16}" = "ON" ]; then \
-        echo "GGML_SYCL_F16 is set" \
-        && export OPT_SYCL_F16="-DGGML_SYCL_F16=ON"; \
-    fi && \
-    make base.en CMAKE_ARGS="-DGGML_SYCL=1 -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx ${OPT_SYCL_F16}"
-
-FROM intel/oneapi-basekit:$ONEAPI_VERSION AS runtime
-
-WORKDIR /app
-
-RUN apt-get update && \
-    apt-get install -y curl ffmpeg libsdl2-dev wget cmake git \
-    && rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
-
-COPY --from=build /app /app
-ENV PATH=/app/build/bin:$PATH
-
-ENTRYPOINT [ "bash", "-c" ]

.devops/main-musa.Dockerfile (deleted)

@@ -1,40 +0,0 @@
-ARG UBUNTU_VERSION=22.04
-# This needs to generally match the container host's environment.
-ARG MUSA_VERSION=rc4.0.1
-# Target the MUSA build image
-ARG BASE_MUSA_DEV_CONTAINER=mthreads/musa:${MUSA_VERSION}-mudnn-devel-ubuntu${UBUNTU_VERSION}
-# Target the MUSA runtime image
-ARG BASE_MUSA_RUN_CONTAINER=mthreads/musa:${MUSA_VERSION}-mudnn-runtime-ubuntu${UBUNTU_VERSION}
-
-FROM ${BASE_MUSA_DEV_CONTAINER} AS build
-WORKDIR /app
-
-RUN apt-get update && \
-    apt-get install -y build-essential libsdl2-dev wget cmake git && \
-    apt-get clean && \
-    rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* /tmp/* /var/tmp/*
-
-COPY .. .
-# Enable muBLAS
-RUN make base.en CMAKE_ARGS="-DGGML_MUSA=1"
-
-RUN find /app/build -name "*.o" -delete && \
-    find /app/build -name "*.a" -delete && \
-    rm -rf /app/build/CMakeFiles && \
-    rm -rf /app/build/cmake_install.cmake && \
-    rm -rf /app/build/_deps
-
-FROM ${BASE_MUSA_RUN_CONTAINER} AS runtime
-WORKDIR /app
-
-RUN apt-get update && \
-    apt-get install -y curl ffmpeg wget cmake git && \
-    apt-get clean && \
-    rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* /tmp/* /var/tmp/*
-
-COPY --from=build /app/build/bin /app/build/bin
-COPY --from=build /app/samples /app/samples
-COPY --from=build /app/models /app/models
-ENV PATH=/app/build/bin:$PATH
-
-ENTRYPOINT [ "bash", "-c" ]

.devops/main.Dockerfile

@@ -16,5 +16,4 @@ RUN apt-get update && \
     && rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
 
 COPY --from=build /app /app
-ENV PATH=/app/build/bin:$PATH
 
 ENTRYPOINT [ "bash", "-c" ]

.dockerignore (deleted)

@@ -1,3 +0,0 @@
-build*/
-.github/
-.devops/

.github/workflows/bindings-ruby.yml

@@ -1,11 +1,55 @@
 name: Bindings Tests (Ruby)
 on:
   push:
-    branches:
-      - master
+    paths:
+      - bindings/ruby/**
+      - src/**/*.c
+      - src/**/*.cpp
+      - src/**/*.h
+      - src/**/*.m
+      - src/**/*.metal
+      - include/**/*.c
+      - include/**/*.cpp
+      - include/**/*.h
+      - include/**/*.m
+      - include/**/*.metal
+      - ggml/**/*.c
+      - ggml/**/*.cpp
+      - ggml/**/*.h
+      - ggml/**/*.m
+      - ggml/**/*.metal
+      - scripts/get-flags.mk
+      - examples/common.h
+      - examples/common.cpp
+      - examples/common-whisper.h
+      - examples/common-whisper.cpp
+      - examples/stb_vorbis.c
+      - examples/miniaudio.h
   pull_request:
-    types: [opened, synchronize, reopened]
+    paths:
+      - bindings/ruby/**
+      - src/**/*.c
+      - src/**/*.cpp
+      - src/**/*.h
+      - src/**/*.m
+      - src/**/*.metal
+      - include/**/*.c
+      - include/**/*.cpp
+      - include/**/*.h
+      - include/**/*.m
+      - include/**/*.metal
+      - ggml/**/*.c
+      - ggml/**/*.cpp
+      - ggml/**/*.h
+      - ggml/**/*.m
+      - ggml/**/*.metal
+      - scripts/get-flags.mk
+      - examples/common.h
+      - examples/common.cpp
+      - examples/common-whisper.h
+      - examples/common-whisper.cpp
+      - examples/stb_vorbis.c
+      - examples/miniaudio.h
 
 jobs:
   ubuntu-22:
@@ -16,6 +60,6 @@ jobs:
     steps:
       - uses: ruby/setup-ruby@v1
         with:
-          ruby-version: '3.2'
+          ruby-version: '3.1'
       - uses: actions/checkout@v4
       - run: rake test

.github/workflows/build.yml

@@ -4,8 +4,6 @@ on:
   push:
     branches:
       - master
-    tags:
-      - 'v*'
   pull_request:
     types: [opened, synchronize, reopened]
   workflow_dispatch:
@@ -43,7 +41,6 @@ jobs:
     runs-on: ubuntu-latest
     outputs:
       tag_name: ${{ steps.tag.outputs.name }}
-      should_release: ${{ steps.tag.outputs.should_release }}
 
     steps:
       - name: Checkout with full history
@@ -58,7 +55,6 @@
           BUILD_NUMBER=$(git rev-list --count HEAD)
           SHORT_HASH=$(git rev-parse --short=7 HEAD)
           CUSTOM_TAG="${{ github.event.inputs.pre_release_tag }}"
-          SHOULD_RELEASE="false"
 
           echo "Raw values:"
           echo "BUILD_NUMBER: $BUILD_NUMBER"
@@ -66,34 +62,21 @@
           echo "BRANCH_NAME: ${{ env.BRANCH_NAME }}"
           echo "CUSTOM_TAG: $CUSTOM_TAG"
 
-          if [[ "${{ github.ref_type }}" == "tag" ]]; then
-            echo "Using pushed tag name"
-            TAG_NAME="${{ github.ref_name }}"
-            SHOULD_RELEASE="true"
-          elif [[ -n "$CUSTOM_TAG" ]]; then
+          # Use custom tag if provided
+          if [[ -n "$CUSTOM_TAG" ]]; then
             echo "Using custom tag"
             TAG_NAME="${CUSTOM_TAG}"
-            SHOULD_RELEASE="true"
-          elif [[ "${{ github.event.inputs.create_release }}" == "true" ]]; then
-            echo "Manual release requested"
-            SHOULD_RELEASE="true"
-            TAG_NAME="b${BUILD_NUMBER}"
           elif [[ "${{ env.BRANCH_NAME }}" == "master" ]]; then
             echo "Using master branch format"
             TAG_NAME="b${BUILD_NUMBER}"
-            SHOULD_RELEASE="false"
           else
             echo "Using non-master branch format"
             SAFE_NAME=$(echo "${{ env.BRANCH_NAME }}" | tr '/' '-')
             TAG_NAME="${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}"
-            SHOULD_RELEASE="false"
           fi
 
           echo "Final tag name: $TAG_NAME"
-          echo "Should release: $SHOULD_RELEASE"
           echo "name=$TAG_NAME" >> $GITHUB_OUTPUT
-          echo "should_release=$SHOULD_RELEASE" >> $GITHUB_OUTPUT
 
   ubuntu-22:
     if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
@@ -118,10 +101,6 @@ jobs:
             -v ${{ github.workspace }}:/workspace \
             -w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
             set -e
-            export DEBIAN_FRONTEND=noninteractive
-            sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
-            sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
             apt update
             apt install -y build-essential libsdl2-dev cmake git
             cmake -B build
@@ -150,14 +129,6 @@
             -v ${{ github.workspace }}:/workspace \
             -w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
             set -e
-            export DEBIAN_FRONTEND=noninteractive
-            sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
-            sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
-            apt-get update
-            apt-get install -y ca-certificates
-            sed -i "s|http://ports.ubuntu.com|https://mirror.kumi.systems|g" /etc/apt/sources.list
             apt update
             apt install -y build-essential libsdl2-dev cmake git
             cmake -B build -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8-a
@@ -186,14 +157,6 @@
             -v ${{ github.workspace }}:/workspace \
             -w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
             set -e
-            export DEBIAN_FRONTEND=noninteractive
-            sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
-            sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
-            apt-get update
-            apt-get install -y ca-certificates
-            sed -i "s|http://ports.ubuntu.com|https://mirror.kumi.systems|g" /etc/apt/sources.list
             apt update
             apt install -y build-essential libsdl2-dev cmake git
             cmake -B build -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv7-a+fp
@@ -237,23 +200,23 @@
       cmake --build build --config Release -j $(sysctl -n hw.logicalcpu)
 
-  # freeBSD-latest:
-  #   runs-on: macos-13
-  #
-  #   steps:
-  #     - name: Clone
-  #       uses: actions/checkout@v4
-  #
-  #     - name: Build
-  #       uses: cross-platform-actions/action@v0.27.0
-  #       with:
-  #         operating_system: freebsd
-  #         version: '14.2'
-  #         run: |
-  #           sudo pkg update
-  #           sudo pkg install -y gmake sdl2 cmake git
-  #           cmake -B build
-  #           cmake --build build --config Release
+  freeBSD-latest:
+    runs-on: macos-13
+
+    steps:
+      - name: Clone
+        uses: actions/checkout@v4
+
+      - name: Build
+        uses: cross-platform-actions/action@v0.27.0
+        with:
+          operating_system: freebsd
+          version: '14.2'
+          run: |
+            sudo pkg update
+            sudo pkg install -y gmake sdl2 cmake git
+            cmake -B build
+            cmake --build build --config Release
 
   ubuntu-22-gcc:
     if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
@@ -279,10 +242,6 @@
             -v ${{ github.workspace }}:/workspace \
             -w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
             set -e
-            export DEBIAN_FRONTEND=noninteractive
-            sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
-            sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
             apt update
             apt install -y build-essential cmake libsdl2-dev git
             cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }}
@@ -313,14 +272,6 @@
             -v ${{ github.workspace }}:/workspace \
             -w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
             set -e
-            export DEBIAN_FRONTEND=noninteractive
-            sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
-            sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
-            apt-get update
-            apt-get install -y ca-certificates
-            sed -i "s|http://ports.ubuntu.com|https://mirror.kumi.systems|g" /etc/apt/sources.list
             apt update
             apt install -y build-essential cmake libsdl2-dev git
             cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8-a
@@ -351,14 +302,6 @@
             -v ${{ github.workspace }}:/workspace \
             -w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
             set -e
-            export DEBIAN_FRONTEND=noninteractive
-            sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
-            sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
-            apt-get update
-            apt-get install -y ca-certificates
-            sed -i "s|http://ports.ubuntu.com|https://mirror.kumi.systems|g" /etc/apt/sources.list
             apt update
             apt install -y build-essential cmake libsdl2-dev git
             cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv7-a+fp
@@ -392,14 +335,6 @@
             -v ${{ github.workspace }}:/workspace \
             -w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
             set -e
-            export DEBIAN_FRONTEND=noninteractive
-            sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
-            sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
-            apt-get update
-            apt-get install -y ca-certificates
-            sed -i "s|http://ports.ubuntu.com|https://mirror.kumi.systems|g" /etc/apt/sources.list
             apt update
             apt install -y clang build-essential cmake libsdl2-dev git
             cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_COMPILER=clang
@@ -430,10 +365,6 @@
             -v ${{ github.workspace }}:/workspace \
             -w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
             set -e
-            export DEBIAN_FRONTEND=noninteractive
-            sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
-            sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
             apt update
             apt install -y build-essential cmake git
             cmake . -DCMAKE_BUILD_TYPE=Debug \
@@ -596,7 +527,6 @@ jobs:
     if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
             github.event.inputs.run_type == 'full-ci' }}
     runs-on: windows-latest
-    needs: determine-tag
 
     strategy:
       matrix:
@@ -631,7 +561,6 @@
         run: >
           cmake -S . -B ./build -A ${{ matrix.arch }}
           -DCMAKE_BUILD_TYPE=${{ matrix.build }}
-          -DBUILD_SHARED_LIBS=ON
           -DWHISPER_SDL2=${{ matrix.sdl2 }}
 
       - name: Build
@@ -643,48 +572,18 @@
         if: matrix.sdl2 == 'ON'
         run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
 
-      - name: Upload SDL2.dll
+      - name: Upload dll
+        uses: actions/upload-artifact@v4
+        with:
+          name: ${{ matrix.jnaPath }}_whisper.dll
+          path: build/bin/${{ matrix.build }}/whisper.dll
+
+      - name: Upload binaries
         if: matrix.sdl2 == 'ON'
         uses: actions/upload-artifact@v4
         with:
-          name: ${{ matrix.s2arc }}_SDL2.dll
-          path: build/bin/${{ matrix.build }}/SDL2.dll
-
-      - name: Upload whisper dll
-        uses: actions/upload-artifact@v4
-        with:
-          name: whisper_${{ matrix.arch }}.dll
-          path: build/bin/${{ matrix.build }}/whisper.dll
-
-      - name: Upload ggml dll
-        uses: actions/upload-artifact@v4
-        with:
-          name: ggml_${{ matrix.arch }}.dll
-          path: build/bin/${{ matrix.build }}/ggml.dll
-
-      - name: Upload ggml base dll
-        uses: actions/upload-artifact@v4
-        with:
-          name: ggml_base_${{ matrix.arch }}.dll
-          path: build/bin/${{ matrix.build }}/ggml-base.dll
-
-      - name: Upload ggml cpu dll
-        uses: actions/upload-artifact@v4
-        with:
-          name: ggml_cpu_${{ matrix.arch }}.dll
-          path: build/bin/${{ matrix.build }}/ggml-cpu.dll
-
-      - name: Pack bin artifacts
-        shell: pwsh
-        run: |
-          Compress-Archive -Path "build/bin/${{ matrix.build }}" -DestinationPath "whisper-bin-${{ matrix.arch }}.zip"
-
-      - name: Upload binaries
-        if: matrix.sdl2 == 'ON' && ${{ needs.determine-tag.outputs.should_release }}
-        uses: actions/upload-artifact@v4
-        with:
-          name: whisper-bin-${{ matrix.arch }}.zip
-          path: whisper-bin-${{ matrix.arch }}.zip
+          name: whisper-bin-${{ matrix.arch }}
+          path: build/bin/${{ matrix.build }}
 
   windows-blas:
     if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
@@ -697,14 +596,11 @@
         arch: [Win32, x64]
         blas: [ON]
         sdl2: [ON]
-        blasver: [0.3.29]
         include:
           - arch: Win32
             s2arc: x86
-            blasfile: x86
           - arch: x64
             s2arc: x64
-            blasfile: x64_64
           - sdl2: ON
             s2ver: 2.28.5
@@ -725,8 +621,7 @@
       - name: Install OpenBLAS and pkgconfiglite
         if: matrix.blas == 'ON'
         run: |
-          Invoke-WebRequest "https://github.com/OpenMathLib/OpenBLAS/releases/download/v${{matrix.blasver}}/OpenBLAS-${{matrix.blasver}}_${{matrix.blasfile}}.zip" -OutFile "OpenBLAS-${{matrix.blasver}}.zip"
-          Expand-Archive "OpenBLAS-${{matrix.blasver}}.zip" -DestinationPath "OpenBLAS-${{matrix.blasver}}"
+          vcpkg install --triplet=${{ matrix.s2arc }}-windows openblas
           choco install pkgconfiglite
 
       - name: Fetch SDL2 and set SDL2_DIR
@@ -743,8 +638,6 @@
           -DCMAKE_BUILD_TYPE=${{ matrix.build }}
           -DGGML_BLAS=${{ matrix.blas }}
           -DGGML_BLAS_VENDOR=OpenBLAS
-          -DBLAS_LIBRARIES="$env:GITHUB_WORKSPACE/OpenBLAS-${{matrix.blasver}}/lib/libopenblas.lib"
-          -DBLAS_INCLUDE_DIRS="$env:GITHUB_WORKSPACE/OpenBLAS-${{matrix.blasver}}/include"
           -DWHISPER_SDL2=${{ matrix.sdl2 }}
 
       - name: Build
@@ -754,37 +647,30 @@
       - name: Copy openblas.dll
         if: matrix.blas == 'ON'
-        run: copy "$env:GITHUB_WORKSPACE/OpenBLAS-${{matrix.blasver}}/bin/libopenblas.dll" build/bin/${{ matrix.build }}
+        run: copy "C:/vcpkg/packages/openblas_${{ matrix.s2arc }}-windows/bin/openblas.dll" build/bin/${{ matrix.build }}
 
       - name: Copy SDL2.dll
         if: matrix.sdl2 == 'ON'
         run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
 
-      - name: Pack bin artifacts
-        shell: pwsh
-        run: |
-          Compress-Archive -Path "build/bin/${{ matrix.build }}" -DestinationPath "whisper-blas-bin-${{ matrix.arch }}.zip"
-
       - name: Upload binaries
-        if: matrix.blas == 'ON' && matrix.sdl2 == 'ON' && ${{ needs.determine-tag.outputs.should_release }}
+        if: matrix.blas == 'ON' && matrix.sdl2 == 'ON'
         uses: actions/upload-artifact@v4
         with:
-          name: whisper-blas-bin-${{ matrix.arch }}.zip
-          path: whisper-blas-bin-${{ matrix.arch }}.zip
+          name: whisper-blas-bin-${{ matrix.arch }}
+          path: build/bin/${{ matrix.build }}
 
   windows-cublas:
     if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
             github.event.inputs.run_type == 'full-ci' }}
-    runs-on: windows-2022
-    needs: determine-tag
+    runs-on: windows-2019
     strategy:
-      fail-fast: false
       matrix:
         build: [Release]
         arch: [x64]
         cublas: [ON]
         sdl2: [ON]
-        cuda-toolkit: [12.4.0, 11.8.0]
+        cuda-toolkit: [12.2.0, 11.8.0]
         include:
           - arch: x64
             sdl2: ON
@@ -852,7 +738,7 @@
           xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y
 
           # Visual Studio integration
-          xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\visual_studio_integration\MSBuildExtensions\*" "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\MSBuild\Microsoft\VC\v170\BuildCustomizations" /E /I /H /Y
+          xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\visual_studio_integration\MSBuildExtensions\*" "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\MSBuild\Microsoft\VC\v160\BuildCustomizations" /E /I /H /Y
 
           # Set environment variables
           echo "$CUDA_TOOLKIT_DIR\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
@@ -860,23 +746,23 @@
           echo "CUDA_PATH=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8
           echo "CUDA_PATH_V11_8=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8
 
-      - name: Install Cuda Toolkit 12.4.0
-        if: ${{ matrix.cuda-toolkit == '12.4.0' }}
+      - name: Install Cuda Toolkit 12.2.0
+        if: ${{ matrix.cuda-toolkit == '12.2.0' }}
         run: |
           $CUDA_VERSION = ${{ matrix.cuda-toolkit }}
           $CUDA_TOOLKIT_DIR = "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v$CUDA_VERSION"
           $CUDA_DOWNLOAD = "https://developer.download.nvidia.com/compute/cuda/redist"
 
           # Components versions
-          $CUDART_VER = "12.4.127"
-          $NVCC_VER = "12.4.131"
-          $NVRTC_VER = "12.4.127"
-          $CUBLAS_VER = "12.4.5.8"
-          $NVTX_VER = "12.4.127"
-          $PROFILER_VER = "12.4.127"
-          $VS_VER = "12.4.127"
-          $NVPROF_VER = "12.4.128"
-          $CCCL_VER = "12.4.127"
+          $CUDART_VER = "12.2.140"
+          $NVCC_VER = "12.2.140"
+          $NVRTC_VER = "12.2.140"
+          $CUBLAS_VER = "12.2.5.6"
+          $NVTX_VER = "12.2.140"
+          $PROFILER_VER = "12.2.140"
+          $VS_VER = "12.2.140"
+          $NVPROF_VER = "12.2.142"
+          $CCCL_VER = "12.2.140"
 
           # Create the directory where the CUDA Toolkit will be installed
           mkdir -p $CUDA_TOOLKIT_DIR
@@ -910,7 +796,7 @@
           xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y
 
           # Visual Studio integration
-          xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\visual_studio_integration\MSBuildExtensions\*" "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\MSBuild\Microsoft\VC\v170\BuildCustomizations" /E /I /H /Y
+          xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\visual_studio_integration\MSBuildExtensions\*" "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\MSBuild\Microsoft\VC\v160\BuildCustomizations" /E /I /H /Y
 
           # Set environment variables
           echo "$CUDA_TOOLKIT_DIR\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
@@ -938,21 +824,14 @@
       - name: Build Project
         shell: cmd
         run: |
-          call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvars64.bat"
+          call "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Auxiliary\Build\vcvars64.bat"
           cmake --version
           where cmake
-          if "${{ matrix.cuda-toolkit }}" == "11.8.0" (
-            set CUDA_FLAGS=-allow-unsupported-compiler -D_ALLOW_COMPILER_AND_STL_VERSION_MISMATCH -D_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR
-          ) else (
-            set CUDA_FLAGS=
-          )
           cmake -S . -B build -G "Ninja Multi-Config" ^
             -DCMAKE_BUILD_TYPE=${{ matrix.build }} ^
             -DGGML_CUDA=${{ matrix.cublas }} ^
             -DWHISPER_SDL2=${{ matrix.sdl2 }} ^
-            -DSDL2_DIR="%SDL2_DIR%" ^
-            -DCMAKE_POLICY_VERSION_MINIMUM=3.5 ^
-            -DCMAKE_CUDA_FLAGS="%CUDA_FLAGS%"
+            -DSDL2_DIR="%SDL2_DIR%"
           set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1
           cmake --build build --config ${{ matrix.build }} -j %NUMBER_OF_PROCESSORS%
@@ -969,17 +848,11 @@
         if: matrix.sdl2 == 'ON'
         run: copy "$env:SDL2_DIR/../lib/${{ matrix.arch }}/SDL2.dll" build/bin/${{ matrix.build }}
 
-      - name: Pack bin artifacts
-        shell: pwsh
-        run: |
-          Compress-Archive -Path "build/bin/${{ matrix.build }}" -DestinationPath "whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}.zip"
-
       - name: Upload binaries
-        if: ${{ needs.determine-tag.outputs.should_release }}
         uses: actions/upload-artifact@v4
         with:
-          name: whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}.zip
-          path: whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}.zip
+          name: whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}
+          path: build/bin/${{ matrix.build }}
 
   emscripten:
     if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
@@ -1052,15 +925,20 @@
       - name: Pack artifacts
         id: pack_artifacts
+        if: ${{ (github.event_name == 'push' && github.ref == 'refs/heads/master') ||
+                github.event.inputs.create_release == 'true' ||
+                github.event.inputs.pre_release_tag != '' }}
         run: |
           zip --symlinks -r whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework.zip build-apple/whisper.xcframework
 
       - name: Upload artifacts
-        if: ${{ needs.determine-tag.outputs.should_release }}
+        if: ${{ (github.event_name == 'push' && github.ref == 'refs/heads/master') ||
+                github.event.inputs.create_release == 'true' ||
+                github.event.inputs.pre_release_tag != '' }}
         uses: actions/upload-artifact@v4
         with:
           path: whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework.zip
-          name: whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework.zip
+          name: whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework
 
   android:
     if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
@@ -1118,93 +996,38 @@
           chmod +x ./gradlew
           ./gradlew assembleRelease
 
-  bindings-java:
-    if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
-            github.event.inputs.run_type == 'full-ci' }}
-    needs: ['windows']
-    runs-on: windows-latest
-    steps:
-      - uses: actions/checkout@v4
-
-      - name: Install Java
-        uses: actions/setup-java@v4
-        with:
-          distribution: zulu
-          java-version: 20
-
-      - name: Download Whisper Windows lib
-        uses: actions/download-artifact@v4
-        with:
-          name: whisper_x64.dll
-
-      - name: Download GGML Windows lib
-        uses: actions/download-artifact@v4
-        with:
-          name: ggml_x64.dll
-
-      - name: Download GGML Base Windows lib
-        uses: actions/download-artifact@v4
-        with:
-          name: ggml_base_x64.dll
-
-      - name: Download GGML CPU Windows lib
-        uses: actions/download-artifact@v4
-        with:
-          name: ggml_cpu_x64.dll
-
-      - name: Download SDL2.dll
-        uses: actions/download-artifact@v4
-        with:
-          name: x64_SDL2.dll
-
-      - name: List downloaded files
-        shell: pwsh
-        run: |
-          Get-ChildItem -Path "." -Recurse -Filter "*.dll"
-
-      - name: Move DLL to correct location
-        shell: pwsh
-        run: |
-          New-Item -Path "build\bin\Release" -ItemType Directory -Force
-          Copy-Item -Path "whisper.dll" -Destination "build\bin\Release\whisper.dll" -Force
-          Write-Host "Copied whisper.dll to build\bin\Release\whisper.dll directory"
-          Copy-Item -Path "ggml.dll" -Destination "build\bin\Release\ggml.dll" -Force
-          Write-Host "Copied ggml.dll to build\bin\Release\ggml.dll directory"
-          Copy-Item -Path "ggml-base.dll" -Destination "build\bin\Release\ggml-base.dll" -Force
-          Write-Host "Copied ggml-base.dll to build\bin\Release\ggml-base.dll directory"
-          Copy-Item -Path "ggml-cpu.dll" -Destination "build\bin\Release\ggml-cpu.dll" -Force
-          Write-Host "Copied ggml-cpu.dll to build\bin\Release\ggml-cpu.dll directory"
-          Copy-Item -Path "SDL2.dll" -Destination "build\bin\Release\SDL2.dll" -Force
-          Write-Host "Copied SDL2.dll to build\bin\Release\SDL2.dll directory"
-
-      - name: List build release files
-        shell: pwsh
-        run: |
-          Get-ChildItem -Path "build\Release" -Recurse -Filter "*.dll"
-
-      - name: Build
-        run: |
-          models\download-ggml-model.cmd tiny.en models/
-          cd bindings/java
-          chmod +x ./gradlew
-          ./gradlew build --info
-
-      - name: Pack jar artifacts
-        shell: pwsh
-        run: |
-          Compress-Archive -Path "bindings/java/build/libs/whispercpp-*.jar" -DestinationPath "whispercpp.jar.zip"
-
-      - name: Upload jar
-        uses: actions/upload-artifact@v4
-        with:
-          name: whispercpp.jar.zip
-          path: whispercpp.jar.zip
+  # TODO: disabled because of following fail: https://github.com/ggerganov/whisper.cpp/actions/runs/9686220096/job/26735899598
+  # java:
+  #   needs: [ 'windows' ]
+  #   runs-on: windows-latest
+  #   steps:
+  #     - uses: actions/checkout@v4
+  #
+  #     - name: Install Java
+  #       uses: actions/setup-java@v4
+  #       with:
+  #         distribution: zulu
+  #         java-version: 20
+  #
+  #     - name: Download Windows lib
+  #       uses: actions/download-artifact@v4
+  #       with:
+  #         name: win32-x86-64_whisper.dll
+  #         path: bindings/java/build/generated/resources/main/win32-x86-64
+  #
+  #     - name: Build
+  #       run: |
+  #         models\download-ggml-model.cmd tiny.en
+  #         cd bindings/java
+  #         chmod +x ./gradlew
+  #         ./gradlew build
+  #
+  #     - name: Upload jar
+  #       uses: actions/upload-artifact@v4
+  #       with:
+  #         name: whispercpp.jar
+  #         path: bindings/java/build/libs/whispercpp-*.jar
+  #
   #     - name: Publish package
   #       if: ${{ github.ref == 'refs/heads/master' }}
   #       uses: gradle/gradle-build-action@v2.4.2
@@ -1234,16 +1057,13 @@
           ./build/bin/quantize models/ggml-tiny.en.bin models/ggml-tiny.en-q4_0.bin q4_0
 
   release:
-    if: ${{ github.event.inputs.create_release == 'true' || github.event.inputs.pre_release_tag != '' || startsWith(github.ref, 'refs/tags/v') }}
+    if: ${{ github.event.inputs.create_release == 'true' || github.event.inputs.pre_release_tag != '' }}
     runs-on: ubuntu-latest
 
     needs:
       - determine-tag
       - ios-xcode-build
-      - windows
-      - windows-blas
-      - windows-cublas
 
     steps:
       - name: Clone
@@ -1277,7 +1097,6 @@
         with:
           tag_name: ${{ needs.determine-tag.outputs.tag_name }}
           prerelease: ${{ github.event.inputs.pre_release_tag != '' }}
-          draft: true
 
       - name: Upload release
        id: upload_release
@@ -1304,8 +1123,7 @@
   coreml-base-en:
     if: ${{ (github.event_name == 'push' && github.ref == 'refs/heads/master') ||
             github.event.inputs.create_release == 'true' ||
-            github.event.inputs.pre_release_tag != '' ||
-            startsWith(github.ref, 'refs/tags/v') }}
+            github.event.inputs.pre_release_tag != '' }}
     runs-on: macos-latest
     needs: determine-tag
@@ -1329,23 +1147,3 @@ jobs:
           source venv/bin/activate
           pip install ane_transformers openai-whisper coremltools
           ./models/generate-coreml-model.sh ${{ env.MODEL_NAME }}
-
-  vad:
-    if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
-            github.event.inputs.run_type == 'full-ci' }}
-    runs-on: ubuntu-latest
-    steps:
-      - name: Checkout
-        uses: actions/checkout@v4
-
-      - name: Build
-        shell: bash
-        run: |
-          cmake -B build
-          cmake --build build --config Release
-
-      - name: Test
-        shell: bash
-        run: |
-          ctest -R ^test-vad$ --test-dir build --output-on-failure -VV

.github/workflows/docker.yml

@@ -15,13 +15,12 @@ jobs:
     env:
       COMMIT_SHA: ${{ github.sha }}
     strategy:
-      fail-fast: false
       matrix:
         config:
           - { tag: "main", dockerfile: ".devops/main.Dockerfile", platform: "linux/amd64" }
-          - { tag: "main-musa", dockerfile: ".devops/main-musa.Dockerfile", platform: "linux/amd64" }
-          - { tag: "main-intel", dockerfile: ".devops/main-intel.Dockerfile", platform: "linux/amd64" }
-          - { tag: "main-cuda", dockerfile: ".devops/main-cuda.Dockerfile", platform: "linux/amd64" }
+          #TODO: the cuda image keeps failing - disable for now
+          #      https://github.com/ggerganov/whisper.cpp/actions/runs/11019444428/job/30602020339
+          #- { tag: "main-cuda", dockerfile: ".devops/main-cuda.Dockerfile", platform: "linux/amd64" }
 
     steps:
       - name: Check out the repo
@@ -42,35 +41,21 @@
           username: ${{ github.repository_owner }}
           password: ${{ secrets.GITHUB_TOKEN }}
 
-      - name: Free up disk space
-        run: |
-          sudo apt-get remove -y '^dotnet-.*' '^llvm-.*' '^mysql-.*' '^postgresql-.*'
-          sudo apt-get autoremove -y
-          sudo apt-get autoclean
-
-          sudo rm -rf /usr/share/dotnet
-          sudo rm -rf /usr/local/lib/android
-          sudo rm -rf /opt/ghc
-          sudo rm -rf /opt/hostedtoolcache/CodeQL
-
-          docker system prune -af
-          df -h
-
-      - name: Generate tags
-        id: tags
-        run: |
-          TAGS="ghcr.io/${{ github.repository }}:${{ matrix.config.tag }}"
-          if [ "${{ github.event_name }}" == "push" ]; then
-            TAGS="$TAGS,ghcr.io/${{ github.repository }}:${{ matrix.config.tag }}-${{ env.COMMIT_SHA }}"
-          fi
-          echo "tags=$TAGS" >> $GITHUB_OUTPUT
+      - name: Build and push Docker image (versioned)
+        if: github.event_name == 'push'
+        uses: docker/build-push-action@v5
+        with:
+          context: .
+          push: true
+          platforms: ${{ matrix.config.platform }}
+          tags: "ghcr.io/${{ github.repository }}:${{ matrix.config.tag }}-${{ env.COMMIT_SHA }}"
+          file: ${{ matrix.config.dockerfile }}
 
       - name: Build and push Docker image (tagged)
-        uses: docker/build-push-action@v5
+        uses: docker/build-push-action@v4
         with:
           context: .
           push: ${{ github.event_name == 'push' }}
           platforms: ${{ matrix.config.platform }}
-          tags: ${{ steps.tags.outputs.tags }}
+          tags: "ghcr.io/${{ github.repository }}:${{ matrix.config.tag }}"
           file: ${{ matrix.config.dockerfile }}

.gitignore

@@ -14,7 +14,6 @@
 build/
 build-*/
-build_*/
 
 # SPM
 .build/
@@ -50,8 +49,6 @@ extra/bench-gg.txt
 models/*.mlmodel
 models/*.mlmodelc
 models/*.mlpackage
-models/*-encoder-openvino.xml
-models/*-encoder-openvino-cache/
 bindings/java/.gradle/
 bindings/java/.idea/
 .idea/

.gitmodules (new empty file)

CMakeLists.txt

@@ -1,6 +1,6 @@
 cmake_minimum_required(VERSION 3.5) # for add_link_options and implicit target directories.
 project("whisper.cpp" C CXX)
-project("whisper.cpp" VERSION 1.7.6)
+project("whisper.cpp" VERSION 1.7.4)
 include(CheckIncludeFileCXX)
 
 set(SOVERSION 1)
@@ -59,6 +59,9 @@ option(BUILD_SHARED_LIBS "build shared libraries" ${BUILD_SHARED_LIBS_DEFAULT})
 # option list
 #
 
+# general
+option(WHISPER_CCACHE "whisper: use ccache if available" ON)
+
 # debug
 option(WHISPER_ALL_WARNINGS           "whisper: enable all compiler warnings"                   ON)
 option(WHISPER_ALL_WARNINGS_3RD_PARTY "whisper: enable all compiler warnings in 3rd party libs" OFF)
@@ -93,6 +96,7 @@ option(WHISPER_OPENVINO "whisper: support for OpenVINO" OFF)
 include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/build-info.cmake)
 
 # override ggml options
+set(GGML_CCACHE             ${WHISPER_CCACHE})
 set(GGML_SANITIZE_THREAD    ${WHISPER_SANITIZE_THREAD})
 set(GGML_SANITIZE_ADDRESS   ${WHISPER_SANITIZE_ADDRESS})
 set(GGML_SANITIZE_UNDEFINED ${WHISPER_SANITIZE_UNDEFINED})
@@ -117,12 +121,6 @@ whisper_option_depr(WARNING WHISPER_OPENMP GGML_OPENMP)
 whisper_option_depr(WARNING WHISPER_RPC      GGML_RPC)
 whisper_option_depr(WARNING WHISPER_SYCL     GGML_SYCL)
 whisper_option_depr(WARNING WHISPER_SYCL_F16 GGML_SYCL_F16)
-whisper_option_depr(WARNING WHISPER_CCACHE   GGML_CCACHE)
-
-if (GGML_CUDA AND NOT MSVC)
-    # GGML_CUDA enabled, add the necessary compile options -Wno-deprecated-gpu-targets
-    add_compile_options(-Wno-deprecated-gpu-targets)
-endif()
 
 #
 # build the library
@@ -137,22 +135,6 @@ if (NOT TARGET ggml)
         add_library(ggml ALIAS ggml::ggml)
     else()
         add_subdirectory(ggml)
-        if(WIN32)
-            # The following adds a _DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR macro and is a workaround for
-            # the Windows C++ standard library which does not support constexpr mutexes.
-            # From the release notes: https://github.com/microsoft/STL/wiki/Changelog
-            #   Disable constexpr mutex constructor on Windows
-            #   Fixed mutex's constructor to be constexpr. #3824 #4000 #4339
-            #   Note: Programs that aren't following the documented restrictions on binary compatibility may encounter
-            #   null dereferences in mutex machinery. You must follow this rule:
-            #   When you mix binaries built by different supported versions of the toolset, the Redistributable version
-            #   must be at least as new as the latest toolset used by any app component.
-            #   You can define _DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR as an escape hatch.
-            #
-            # Specifically to whisper.cpp this would cause a crash when using the Java bindings.
-            # resulting in a Invalid memory access error.
-            target_compile_definitions(ggml-base PRIVATE _DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR)
-        endif()
     endif()
     # ... otherwise assume ggml is added by a parent CMakeLists.txt
 endif()
@@ -178,10 +160,6 @@ get_directory_property(WHISPER_TRANSIENT_DEFINES COMPILE_DEFINITIONS)
 set_target_properties(whisper PROPERTIES PUBLIC_HEADER ${CMAKE_CURRENT_SOURCE_DIR}/include/whisper.h)
 install(TARGETS whisper LIBRARY PUBLIC_HEADER)
 
-target_compile_definitions(whisper PRIVATE
-    WHISPER_VERSION="${PROJECT_VERSION}"
-)
-
 configure_package_config_file(
     ${CMAKE_CURRENT_SOURCE_DIR}/cmake/whisper-config.cmake.in
     ${CMAKE_CURRENT_BINARY_DIR}/whisper-config.cmake
@@ -219,37 +197,3 @@ endif ()
 if (WHISPER_BUILD_EXAMPLES)
     add_subdirectory(examples)
 endif()
-
-if (MSVC)
-    set(MSVC_WARNING_FLAGS
-        /wd4101  # Unreferenced local variable
-        /wd4005  # Macro redefinition
-        /wd4065  # switch statement contains 'default' but no 'case' labels
-        /wd4267  # Conversion from 'size_t' to a smaller type, possible loss of data
-        /wd4244  # Conversion from one type to another type, possible loss of data
-        /wd4805  # Unsafe mix of type
-        /wd4305  # Truncation from 'type1' to 'type2' (often double to float)
-        /wd4996  # Function or variable may be unsafe/deprecated
-    )
-    function(disable_msvc_warnings target_name)
-        if(TARGET ${target_name})
-            target_compile_options(${target_name} PRIVATE ${MSVC_WARNING_FLAGS})
-        endif()
-    endfunction()
-
-    if (WHISPER_BUILD_EXAMPLES)
-        disable_msvc_warnings(whisper)
-        disable_msvc_warnings(common)
-        disable_msvc_warnings(common-sdl)
-        disable_msvc_warnings(lsp)
-        disable_msvc_warnings(wchess-core)
-        disable_msvc_warnings(whisper-command)
-        disable_msvc_warnings(whisper-cli)
-        disable_msvc_warnings(whisper-server)
-        disable_msvc_warnings(whisper-stream)
-        disable_msvc_warnings(whisper-talk-llama)
-        disable_msvc_warnings(whisper-bench)
-        disable_msvc_warnings(quantize)
-        disable_msvc_warnings(vad-speech-segments)
-    endif()
-endif()
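For context on the constexpr-mutex workaround deleted in the hunk above, a minimal C++ sketch of the escape hatch it applied follows. This assumes the MSVC toolset-mixing scenario the removed comment describes; the macro is the STL's documented opt-out and must be defined before any standard header:

```cpp
// Sketch of the MSVC pitfall behind the removed CMake block. Since the
// VS 2022 17.10 STL, std::mutex has a constexpr constructor; a binary built
// with a newer toolset but run against an older redistributable can hit null
// dereferences in the mutex machinery (for whisper.cpp this surfaced as an
// "Invalid memory access" crash from the Java bindings). Defining the macro
// restores the old non-constexpr construction, which is what the deleted
// target_compile_definitions(ggml-base PRIVATE ...) line did.
#define _DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR
#include <mutex>

static std::mutex g_state_mutex; // constructed at runtime with the macro defined

void touch_state() {
    // in the mismatch scenario this lock is where the crash would occur
    std::lock_guard<std::mutex> lock(g_state_mutex);
}

int main() {
    touch_state();
    return 0;
}
```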

Makefile

@@ -4,7 +4,7 @@
 .PHONY: build
 build:
-	cmake -B build $(CMAKE_ARGS)
+	cmake -B build
 	cmake --build build --config Release
 
 # download a few audio samples into folder "./samples":
@@ -41,14 +41,14 @@ samples:
 tiny.en tiny base.en base small.en small medium.en medium large-v1 large-v2 large-v3 large-v3-turbo:
 	bash ./models/download-ggml-model.sh $@
-	cmake -B build $(CMAKE_ARGS)
+	cmake -B build
 	cmake --build build --config Release
 	@echo ""
 	@echo "==============================================="
 	@echo "Running $@ on all samples in ./samples ..."
 	@echo "==============================================="
 	@echo ""
-	@for f in samples/*.{flac,mp3,ogg,wav}; do \
+	@for f in samples/*$(.flac .mp3 .ogg .wav); do \
 		echo "----------------------------------------------" ; \
 		echo "[+] Running $@ on $$f ... (run 'ffplay $$f' to listen)" ; \
 		echo "----------------------------------------------" ; \

README.md

@@ -2,12 +2,15 @@
 
 ![whisper.cpp](https://user-images.githubusercontent.com/1991296/235238348-05d0f6a4-da44-4900-a1de-d0707e75b763.jpeg)
 
-[![Actions Status](https://github.com/ggml-org/whisper.cpp/workflows/CI/badge.svg)](https://github.com/ggml-org/whisper.cpp/actions)
+[![Actions Status](https://github.com/ggerganov/whisper.cpp/workflows/CI/badge.svg)](https://github.com/ggerganov/whisper.cpp/actions)
 [![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](https://opensource.org/licenses/MIT)
 [![Conan Center](https://shields.io/conan/v/whisper-cpp)](https://conan.io/center/whisper-cpp)
 [![npm](https://img.shields.io/npm/v/whisper.cpp.svg)](https://www.npmjs.com/package/whisper.cpp/)
 
-Stable: [v1.7.6](https://github.com/ggml-org/whisper.cpp/releases/tag/v1.7.6) / [Roadmap](https://github.com/orgs/ggml-org/projects/4/)
+> [!NOTE]
+> New maintenance roadmap: https://github.com/ggerganov/whisper.cpp/discussions/2788
+
+Stable: [v1.7.4](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.7.4) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
 
 High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:
@@ -23,9 +26,7 @@ High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisp
 - [Efficient GPU support for NVIDIA](#nvidia-gpu-support)
 - [OpenVINO Support](#openvino-support)
 - [Ascend NPU Support](#ascend-npu-support)
-- [Moore Threads GPU Support](#moore-threads-gpu-support)
-- [C-style API](https://github.com/ggml-org/whisper.cpp/blob/master/include/whisper.h)
-- [Voice Activity Detection (VAD)](#voice-activity-detection-vad)
+- [C-style API](https://github.com/ggerganov/whisper.cpp/blob/master/include/whisper.h)
 
 Supported platforms:
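For reference, the C-style API mentioned in the hunk above is small enough to show in a few lines. A rough usage sketch against include/whisper.h; the model path and the silent 5-second buffer are placeholders, and real code would decode PCM from a file and check errors throughout:

```cpp
#include "whisper.h"
#include <cstdio>
#include <vector>

int main() {
    // load a ggml model (path is a placeholder)
    whisper_context_params cparams = whisper_context_default_params();
    whisper_context * ctx = whisper_init_from_file_with_params("models/ggml-base.en.bin", cparams);
    if (!ctx) return 1;

    // whisper expects 16 kHz mono f32 PCM; 5 seconds of silence stands in here
    std::vector<float> pcm(16000 * 5, 0.0f);

    whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
    if (whisper_full(ctx, wparams, pcm.data(), (int) pcm.size()) != 0) return 1;

    // print the transcribed segments
    for (int i = 0; i < whisper_full_n_segments(ctx); ++i) {
        printf("%s\n", whisper_full_get_segment_text(ctx, i));
    }

    whisper_free(ctx);
    return 0;
}
```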
@@ -33,14 +34,14 @@ Supported platforms:
 - [x] [iOS](examples/whisper.objc)
 - [x] [Android](examples/whisper.android)
 - [x] [Java](bindings/java/README.md)
-- [x] Linux / [FreeBSD](https://github.com/ggml-org/whisper.cpp/issues/56#issuecomment-1350920264)
+- [x] Linux / [FreeBSD](https://github.com/ggerganov/whisper.cpp/issues/56#issuecomment-1350920264)
 - [x] [WebAssembly](examples/whisper.wasm)
-- [x] Windows ([MSVC](https://github.com/ggml-org/whisper.cpp/blob/master/.github/workflows/build.yml#L117-L144) and [MinGW](https://github.com/ggml-org/whisper.cpp/issues/168))
-- [x] [Raspberry Pi](https://github.com/ggml-org/whisper.cpp/discussions/166)
-- [x] [Docker](https://github.com/ggml-org/whisper.cpp/pkgs/container/whisper.cpp)
+- [x] Windows ([MSVC](https://github.com/ggerganov/whisper.cpp/blob/master/.github/workflows/build.yml#L117-L144) and [MinGW](https://github.com/ggerganov/whisper.cpp/issues/168))
+- [x] [Raspberry Pi](https://github.com/ggerganov/whisper.cpp/discussions/166)
+- [x] [Docker](https://github.com/ggerganov/whisper.cpp/pkgs/container/whisper.cpp)
 
 The entire high-level implementation of the model is contained in [whisper.h](include/whisper.h) and [whisper.cpp](src/whisper.cpp).
-The rest of the code is part of the [`ggml`](https://github.com/ggml-org/ggml) machine learning library.
+The rest of the code is part of the [`ggml`](https://github.com/ggerganov/ggml) machine learning library.
 
 Having such a lightweight implementation of the model allows to easily integrate it in different platforms and applications.
 As an example, here is a video of running the model on an iPhone 13 device - fully offline, on-device: [whisper.objc](examples/whisper.objc)
@@ -53,14 +54,14 @@ https://user-images.githubusercontent.com/1991296/204038393-2f846eae-c255-4099-a
 
 On Apple Silicon, the inference runs fully on the GPU via Metal:
 
-https://github.com/ggml-org/whisper.cpp/assets/1991296/c82e8f86-60dc-49f2-b048-d2fdbd6b5225
+https://github.com/ggerganov/whisper.cpp/assets/1991296/c82e8f86-60dc-49f2-b048-d2fdbd6b5225
 
 ## Quick start
 
 First clone the repository:
 
 ```bash
-git clone https://github.com/ggml-org/whisper.cpp.git
+git clone https://github.com/ggerganov/whisper.cpp.git
 ```
 
 Navigate into the directory:
@@ -80,7 +81,7 @@ Now build the [whisper-cli](examples/cli) example and transcribe an audio file l
 ```bash
 # build the project
 cmake -B build
-cmake --build build -j --config Release
+cmake --build build --config Release
 
 # transcribe an audio file
 ./build/bin/whisper-cli -f samples/jfk.wav
@@ -149,9 +150,8 @@ standard cmake setup with:
 ```bash
 # build with GGML_BLAS defined
 cmake -B build -DGGML_BLAS=1
-cmake --build build -j --config Release
+cmake --build build --config Release
 ./build/bin/whisper-cli [ .. etc .. ]
-```
 
 ## Quantization
@@ -163,7 +163,7 @@ Here are the steps for creating and using a quantized model:
 ```bash
 # quantize a model with Q5_0 method
 cmake -B build
-cmake --build build -j --config Release
+cmake --build build --config Release
 ./build/bin/quantize models/ggml-base.en.bin models/ggml-base.en-q5_0.bin q5_0
 
 # run the examples as usual, specifying the quantized model file
The first run on a device is slow, since the ANE service compiles the Core ML model to some device-specific format. The first run on a device is slow, since the ANE service compiles the Core ML model to some device-specific format.
Next runs are faster. Next runs are faster.
For more information about the Core ML implementation please refer to PR [#566](https://github.com/ggml-org/whisper.cpp/pull/566). For more information about the Core ML implementation please refer to PR [#566](https://github.com/ggerganov/whisper.cpp/pull/566).
## OpenVINO support ## OpenVINO support
@@ -267,7 +267,7 @@ This can result in significant speedup in encoder performance. Here are the inst
 - Build `whisper.cpp` with OpenVINO support:
 
-  Download OpenVINO package from [release page](https://github.com/openvinotoolkit/openvino/releases). The recommended version to use is [2024.6.0](https://github.com/openvinotoolkit/openvino/releases/tag/2024.6.0). Ready to use Binaries of the required libraries can be found in the [OpenVino Archives](https://storage.openvinotoolkit.org/repositories/openvino/packages/2024.6/)
+  Download OpenVINO package from [release page](https://github.com/openvinotoolkit/openvino/releases). The recommended version to use is [2023.0.0](https://github.com/openvinotoolkit/openvino/releases/tag/2023.0.0).
 
 After downloading & extracting package onto your development system, set up required environment by sourcing setupvars script. For example:
@@ -310,7 +310,7 @@ This can result in significant speedup in encoder performance. Here are the inst
 The first time run on an OpenVINO device is slow, since the OpenVINO framework will compile the IR (Intermediate Representation) model to a device-specific 'blob'. This device-specific blob will get
 cached for the next run.
 
-For more information about the OpenVINO implementation please refer to PR [#1037](https://github.com/ggml-org/whisper.cpp/pull/1037).
+For more information about the OpenVINO implementation please refer to PR [#1037](https://github.com/ggerganov/whisper.cpp/pull/1037).
 
 ## NVIDIA GPU support
@@ -324,12 +324,6 @@ cmake -B build -DGGML_CUDA=1
 cmake --build build -j --config Release
 ```
 
-or for newer NVIDIA GPU's (RTX 5000 series):
-
-```
-cmake -B build -DGGML_CUDA=1 -DCMAKE_CUDA_ARCHITECTURES="86"
-cmake --build build -j --config Release
-```
-
 ## Vulkan GPU support
 Cross-vendor solution which allows you to accelerate workload on your GPU.
 First, make sure your graphics card driver provides support for Vulkan API.
@ -383,56 +377,6 @@ Run the inference examples as usual, for example:
- If you have trouble with Ascend NPU device, please create a issue with **[CANN]** prefix/tag. - If you have trouble with Ascend NPU device, please create a issue with **[CANN]** prefix/tag.
- If you run successfully with your Ascend NPU device, please help update the table `Verified devices`. - If you run successfully with your Ascend NPU device, please help update the table `Verified devices`.
## Moore Threads GPU support
With Moore Threads cards the processing of the models is done efficiently on the GPU via muBLAS and custom MUSA kernels.
First, make sure you have installed `MUSA SDK rc4.0.1`: https://developer.mthreads.com/sdk/download/musa?equipment=&os=&driverVersion=&version=4.0.1
Now build `whisper.cpp` with MUSA support:
```
cmake -B build -DGGML_MUSA=1
cmake --build build -j --config Release
```
or specify the architecture for your Moore Threads GPU. For example, if you have a MTT S80 GPU, you can specify the architecture as follows:
```
cmake -B build -DGGML_MUSA=1 -DMUSA_ARCHITECTURES="21"
cmake --build build -j --config Release
```
## FFmpeg support (Linux only)
If you want to support more audio formats (such as Opus and AAC), you can turn on the `WHISPER_FFMPEG` build flag to enable FFmpeg integration.
First, you need to install required libraries:
```bash
# Debian/Ubuntu
sudo apt install libavcodec-dev libavformat-dev libavutil-dev
# RHEL/Fedora
sudo dnf install libavcodec-free-devel libavformat-free-devel libavutil-free-devel
```
Then you can build the project as follows:
```bash
cmake -B build -D WHISPER_FFMPEG=yes
cmake --build build
```
Run the following example to confirm it's working:
```bash
# Convert an audio file to Opus format
ffmpeg -i samples/jfk.wav jfk.opus
# Transcribe the audio file
./build/bin/whisper-cli --model models/ggml-base.en.bin --file jfk.opus
```
## Docker

### Prerequisites
@ -444,9 +388,8 @@ ffmpeg -i samples/jfk.wav jfk.opus
We have the following Docker images available for this project:

1. `ghcr.io/ggml-org/whisper.cpp:main`: This image includes the main executable file as well as `curl` and `ffmpeg`. (platforms: `linux/amd64`, `linux/arm64`)
2. `ghcr.io/ggml-org/whisper.cpp:main-cuda`: Same as `main` but compiled with CUDA support. (platforms: `linux/amd64`)
3. `ghcr.io/ggml-org/whisper.cpp:main-musa`: Same as `main` but compiled with MUSA support. (platforms: `linux/amd64`)
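The images can be pulled ahead of time if desired, for example:

```bash
docker pull ghcr.io/ggml-org/whisper.cpp:main
```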
### Usage
@ -459,11 +402,11 @@ docker run -it --rm \
docker run -it --rm \
  -v path/to/models:/models \
  -v path/to/audios:/audios \
  whisper.cpp:main "whisper-cli -m /models/ggml-base.bin -f /audios/jfk.wav"

# transcribe an audio file in samples folder
docker run -it --rm \
  -v path/to/models:/models \
  whisper.cpp:main "whisper-cli -m /models/ggml-base.bin -f ./samples/jfk.wav"
```
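If you do not have a model on the host yet, one option is to run the bundled download script inside the container and let it write into the mounted folder. A sketch, assuming the image's working directory is the repository root as in the commands above:

```bash
# download the base.en model into the mounted folder, then transcribe a sample with it
docker run -it --rm \
  -v path/to/models:/models \
  whisper.cpp:main "./models/download-ggml-model.sh base.en /models && whisper-cli -m /models/ggml-base.en.bin -f ./samples/jfk.wav"
```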
## Installing with Conan
@ -484,12 +427,12 @@ For detailed instructions on how to use Conan, please refer to the [Conan docume
This is a naive example of performing real-time inference on audio from your microphone.
The [stream](examples/stream) tool samples the audio every half a second and runs the transcription continuously.
More info is available in [issue #10](https://github.com/ggml-org/whisper.cpp/issues/10).

You will need to have [sdl2](https://wiki.libsdl.org/SDL2/Installation) installed for it to work properly.
```bash
cmake -B build -DWHISPER_SDL2=ON
cmake --build build -j --config Release
./build/bin/whisper-stream -m ./models/ggml-base.en.bin -t 8 --step 500 --length 5000
```
@ -573,7 +516,7 @@ main: processing './samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 pr
## Speaker segmentation via tinydiarize (experimental)

More information about this approach is available here: https://github.com/ggml-org/whisper.cpp/pull/1058

Sample usage:
@ -600,7 +543,7 @@ main: processing './samples/a13.wav' (480000 samples, 30.0 sec), 4 threads, 1 pr
## Karaoke-style movie generation (experimental)

The [whisper-cli](examples/cli) example provides support for output of karaoke-style movies, where the
currently pronounced word is highlighted. Use the `-owts` argument and run the generated bash script.
This requires `ffmpeg` to be installed.

Here are a few _"typical"_ examples:
@ -637,7 +580,7 @@ https://user-images.githubusercontent.com/1991296/199337538-b7b0c7a3-2753-4a88-a
## Video comparison of different models

Use the [scripts/bench-wts.sh](https://github.com/ggml-org/whisper.cpp/blob/master/scripts/bench-wts.sh) script to generate a video in the following format:
```bash
./scripts/bench-wts.sh samples/jfk.wav
@ -654,7 +597,7 @@ In order to have an objective comparison of the performance of the inference acr
use the [whisper-bench](examples/bench) tool. The tool simply runs the Encoder part of the model and prints how much time it
took to execute it. The results are summarized in the following GitHub issue:

[Benchmark results](https://github.com/ggml-org/whisper.cpp/issues/89)

Additionally, a script to run whisper.cpp with different models and audio files is provided: [bench.py](scripts/bench.py).
@ -681,24 +624,25 @@ You can download the converted models using the [models/download-ggml-model.sh](
or manually from here:

- https://huggingface.co/ggerganov/whisper.cpp
- https://ggml.ggerganov.com
For more details, see the conversion script [models/convert-pt-to-ggml.py](models/convert-pt-to-ggml.py) or [models/README.md](models/README.md).

## [Bindings](https://github.com/ggml-org/whisper.cpp/discussions/categories/bindings)

- [x] Rust: [tazz4843/whisper-rs](https://github.com/tazz4843/whisper-rs) | [#310](https://github.com/ggml-org/whisper.cpp/discussions/310)
- [x] JavaScript: [bindings/javascript](bindings/javascript) | [#309](https://github.com/ggml-org/whisper.cpp/discussions/309)
  - React Native (iOS / Android): [whisper.rn](https://github.com/mybigday/whisper.rn)
- [x] Go: [bindings/go](bindings/go) | [#312](https://github.com/ggml-org/whisper.cpp/discussions/312)
- [x] Java:
  - [GiviMAD/whisper-jni](https://github.com/GiviMAD/whisper-jni)
- [x] Ruby: [bindings/ruby](bindings/ruby) | [#507](https://github.com/ggml-org/whisper.cpp/discussions/507)
- [x] Objective-C / Swift: [ggml-org/whisper.spm](https://github.com/ggml-org/whisper.spm) | [#313](https://github.com/ggml-org/whisper.cpp/discussions/313)
  - [exPHAT/SwiftWhisper](https://github.com/exPHAT/SwiftWhisper)
- [x] .NET: | [#422](https://github.com/ggml-org/whisper.cpp/discussions/422)
  - [sandrohanea/whisper.net](https://github.com/sandrohanea/whisper.net)
  - [NickDarvey/whisper](https://github.com/NickDarvey/whisper)
- [x] Python: | [#9](https://github.com/ggml-org/whisper.cpp/issues/9)
  - [stlukey/whispercpp.py](https://github.com/stlukey/whispercpp.py) (Cython)
  - [AIWintermuteAI/whispercpp](https://github.com/AIWintermuteAI/whispercpp) (Updated fork of aarnphm/whispercpp)
  - [aarnphm/whispercpp](https://github.com/aarnphm/whispercpp) (Pybind11)
@ -706,118 +650,6 @@ For more details, see the conversion script [models/convert-pt-to-ggml.py](model
- [x] R: [bnosac/audio.whisper](https://github.com/bnosac/audio.whisper)
- [x] Unity: [macoron/whisper.unity](https://github.com/Macoron/whisper.unity)
## XCFramework
The XCFramework is a precompiled version of the library for iOS, visionOS, tvOS,
and macOS. It can be used in Swift projects without the need to compile the
library from source. For example, the v1.7.5 version of the XCFramework can be
used as follows:
```swift
// swift-tools-version: 5.10
// The swift-tools-version declares the minimum version of Swift required to build this package.
import PackageDescription
let package = Package(
name: "Whisper",
targets: [
.executableTarget(
name: "Whisper",
dependencies: [
"WhisperFramework"
]),
.binaryTarget(
name: "WhisperFramework",
url: "https://github.com/ggml-org/whisper.cpp/releases/download/v1.7.5/whisper-v1.7.5-xcframework.zip",
checksum: "c7faeb328620d6012e130f3d705c51a6ea6c995605f2df50f6e1ad68c59c6c4a"
)
]
)
```
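If you pin a different release of the XCFramework, the `checksum` value must be updated to match the new archive. SwiftPM can compute it locally; a sketch, assuming the zip has already been downloaded into the current directory:

```bash
swift package compute-checksum whisper-v1.7.5-xcframework.zip
```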
## Voice Activity Detection (VAD)
Support for Voice Activity Detection (VAD) can be enabled using the `--vad`
argument to `whisper-cli`. In addition to this option, a VAD model is also
required.
The way this works is that the audio samples are first passed through
the VAD model, which detects speech segments. Using this information,
only the detected speech segments are extracted from the original audio
input and passed to whisper for processing. This reduces the amount of audio
data that needs to be processed by whisper and can significantly speed up the
transcription process.
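Concretely, the end-to-end flow looks like this (a minimal sketch using the download script and CLI flags documented below; paths assume the default build and model locations):

```bash
# 1) fetch a VAD model (see the next section)
./models/download-vad-model.sh silero-v5.1.2

# 2) transcribe with VAD enabled; only detected speech segments are passed to whisper
./build/bin/whisper-cli -m models/ggml-base.en.bin -f samples/jfk.wav \
    --vad --vad-model models/ggml-silero-v5.1.2.bin
```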
The following VAD models are currently supported:
### Silero-VAD
[Silero-vad](https://github.com/snakers4/silero-vad) is a lightweight VAD model
written in Python that is fast and accurate.
Models can be downloaded by running the following command on Linux or macOS:
```console
$ ./models/download-vad-model.sh silero-v5.1.2
Downloading ggml model silero-v5.1.2 from 'https://huggingface.co/ggml-org/whisper-vad' ...
ggml-silero-v5.1.2.bin 100%[==============================================>] 864.35K --.-KB/s in 0.04s
Done! Model 'silero-v5.1.2' saved in '/path/models/ggml-silero-v5.1.2.bin'
You can now use it like this:
$ ./build/bin/whisper-cli -vm /path/models/ggml-silero-v5.1.2.bin --vad -f samples/jfk.wav -m models/ggml-base.en.bin
```
And the following command on Windows:
```console
> .\models\download-vad-model.cmd silero-v5.1.2
Downloading vad model silero-v5.1.2...
Done! Model silero-v5.1.2 saved in C:\Users\danie\work\ai\whisper.cpp\ggml-silero-v5.1.2.bin
You can now use it like this:
C:\path\build\bin\Release\whisper-cli.exe -vm C:\path\ggml-silero-v5.1.2.bin --vad -m models/ggml-base.en.bin -f samples\jfk.wav
```
To see a list of all available models, run the above commands without any
arguments.
This model can also be converted manually to ggml using the following commands:
```console
$ python3 -m venv venv && source venv/bin/activate
(venv) $ pip install silero-vad
(venv) $ python models/convert-silero-vad-to-ggml.py --output models/silero-v5.1.2-ggml.bin
Saving GGML Silero-VAD model to models/silero-v5.1.2-ggml.bin
```
And it can then be used with whisper as follows:
```console
$ ./build/bin/whisper-cli \
--file ./samples/jfk.wav \
--model ./models/ggml-base.en.bin \
--vad \
--vad-model ./models/silero-v5.1.2-ggml.bin
```
### VAD Options
* `--vad-threshold`: Threshold probability for speech detection. A probability
  for a speech segment/frame above this threshold will be considered as speech.

* `--vad-min-speech-duration-ms`: Minimum speech duration in milliseconds. Speech
  segments shorter than this value will be discarded to filter out brief noise or
  false positives.

* `--vad-min-silence-duration-ms`: Minimum silence duration in milliseconds. Silence
  periods must be at least this long to end a speech segment. Shorter silence
  periods will be ignored and included as part of the speech.

* `--vad-max-speech-duration-s`: Maximum speech duration in seconds. Speech segments
  longer than this will be automatically split into multiple segments at silence
  points exceeding 98 ms to prevent excessively long segments.

* `--vad-speech-pad-ms`: Speech padding in milliseconds. Adds this amount of padding
  before and after each detected speech segment to avoid cutting off speech edges.

* `--vad-samples-overlap`: Amount of audio to extend from each speech segment into
  the next one, in seconds (e.g., 0.10 = 100 ms overlap). This ensures speech isn't
  cut off abruptly between segments when they're concatenated together.
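For example, a configuration that filters short blips more aggressively might look like this (a sketch; the values are illustrative, not recommendations):

```bash
./build/bin/whisper-cli -m models/ggml-base.en.bin -f samples/jfk.wav \
    --vad --vad-model models/ggml-silero-v5.1.2.bin \
    --vad-threshold 0.6 \
    --vad-min-speech-duration-ms 500 \
    --vad-min-silence-duration-ms 200 \
    --vad-speech-pad-ms 50
```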
## Examples

There are various examples of using the library for different projects in the [examples](examples) folder.
@ -836,13 +668,13 @@ Some of the examples are even ported to run in the browser using WebAssembly. Ch
| [whisper.android](examples/whisper.android) | | Android mobile application using whisper.cpp |
| [whisper.nvim](examples/whisper.nvim) | | Speech-to-text plugin for Neovim |
| [generate-karaoke.sh](examples/generate-karaoke.sh) | | Helper script to easily [generate a karaoke video](https://youtu.be/uj7hVta4blM) of raw audio capture |
| [livestream.sh](examples/livestream.sh) | | [Livestream audio transcription](https://github.com/ggml-org/whisper.cpp/issues/185) |
| [yt-wsp.sh](examples/yt-wsp.sh) | | Download + transcribe and/or translate any VOD [(original)](https://gist.github.com/DaniruKun/96f763ec1a037cc92fe1a059b643b818) |
| [wchess](examples/wchess) | [wchess.wasm](examples/wchess) | Voice-controlled chess |
## [Discussions](https://github.com/ggml-org/whisper.cpp/discussions)

If you have any kind of feedback about this project feel free to use the Discussions section and open a new topic.
You can use the [Show and tell](https://github.com/ggml-org/whisper.cpp/discussions/categories/show-and-tell) category
to share your own projects that use `whisper.cpp`. If you have a question, make sure to check the
[Frequently asked questions (#126)](https://github.com/ggml-org/whisper.cpp/discussions/126) discussion.

View File

@ -16,13 +16,13 @@
## Background ## Background
SYCL is a higher-level programming model to improve programming productivity on various hardware accelerators, such as CPUs, GPUs, and FPGAs. It is a single-source embedded domain-specific language based on pure C++17.

oneAPI is a specification that is open and standards-based, supporting multiple architecture types including but not limited to GPU, CPU, and FPGA. The spec has both direct programming and API-based programming paradigms.

Intel uses SYCL as the direct programming language to support CPUs, GPUs and FPGAs.

To avoid re-inventing the wheel, this code refers to other code paths in llama.cpp (like OpenBLAS, cuBLAS, CLBlast). We use the open-source tool [SYCLomatic](https://github.com/oneapi-src/SYCLomatic) (commercial release: [Intel® DPC++ Compatibility Tool](https://www.intel.com/content/www/us/en/developer/tools/oneapi/dpc-compatibility-tool.html)) to migrate to SYCL.

The whisper.cpp for SYCL is used to support Intel GPUs.
@ -84,10 +84,10 @@ Platform #0: Intel(R) OpenCL HD Graphics
`-- Device #0: Intel(R) Iris(R) Xe Graphics [0x9a49]
```
2. Install Intel® oneAPI Base Toolkit.

a. Please follow the procedure in [Get the Intel® oneAPI Base Toolkit](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html).

It is recommended to install to the default folder: **/opt/intel/oneapi**.

View File

@ -51,7 +51,7 @@ func main() {
In order to build, you need to have the Go compiler installed. You can get it from [here](https://golang.org/dl/). Run the tests with:

```bash
git clone https://github.com/ggml-org/whisper.cpp.git
cd whisper.cpp/bindings/go
make test
```
@ -98,7 +98,7 @@ The API Documentation:
Getting help:

* Follow the discussion for the go bindings [here](https://github.com/ggml-org/whisper.cpp/discussions/312)

## License

View File

@ -1,5 +1,5 @@
/*
github.com/ggml-org/whisper.cpp/bindings/go
provides speech-to-text service bindings for the Go programming language.
*/
package whisper

View File

@ -23,42 +23,26 @@ import io.github.ggerganov.whispercpp.WhisperCpp;
public class Example {

    public static void main(String[] args) {
        WhisperCpp whisper = new WhisperCpp();
        try {
            // By default, models are loaded from ~/.cache/whisper/ and are usually named "ggml-${name}.bin"
            // or you can provide the absolute path to the model file.
            whisper.initContext("../ggml-base.en.bin");

            WhisperFullParams.ByValue whisperParams = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
            // custom configuration if required
            //whisperParams.n_threads = 8;
            whisperParams.temperature = 0.0f;
            whisperParams.temperature_inc = 0.2f;
            //whisperParams.language = "en";

            float[] samples = readAudio(); // divide each value by 32767.0f
            List<WhisperSegment> whisperSegmentList = whisper.fullTranscribeWithTime(whisperParams, samples);
            for (WhisperSegment whisperSegment : whisperSegmentList) {
                long start = whisperSegment.getStart();
                long end = whisperSegment.getEnd();
                String text = whisperSegment.getSentence();
                System.out.println("start: " + start);
                System.out.println("end: " + end);
                System.out.println("text: " + text);
            }
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            whisper.close();
        }
    }
}
```
@ -68,7 +52,7 @@ public class Example {
In order to build, you need to have JDK 8 or higher installed. Run the tests with:

```bash
git clone https://github.com/ggml-org/whisper.cpp.git
cd whisper.cpp/bindings/java
./gradlew build

View File

@ -27,41 +27,23 @@ sourceSets {
tasks.register('copyLibwhisperDynlib', Copy) {
    from '../../build/src'
    include 'libwhisper.dylib'
    into 'build/generated/resources/main'
}

tasks.register('copyLibwhisperSo', Copy) {
    from '../../build/src'
    include 'libwhisper.so'
    into 'build/generated/resources/main'
}

tasks.register('copyWhisperDLL', Copy) {
    from '../../build/bin/Release'
    include 'whisper.dll'
    into 'build/generated/resources/main'
}

tasks.register('copyGGML_BASE_DLL', Copy) {
    from '../../build/bin/Release'
    include 'ggml-base.dll'
    into 'build/generated/resources/main'
}

tasks.register('copyGGML_DLL', Copy) {
    from '../../build/bin/Release'
    include 'ggml.dll'
    into 'build/generated/resources/main'
}

tasks.register('copyGGML_CPU_DLL', Copy) {
    from '../../build/bin/Release'
    include 'ggml-cpu.dll'
    into 'build/generated/resources/main'
}

tasks.register('copyLibs') {
    dependsOn copyLibwhisperDynlib, copyLibwhisperSo, copyWhisperDLL, copyGGML_BASE_DLL, copyGGML_DLL, copyGGML_CPU_DLL
}

test {

View File

@ -168,26 +168,23 @@ public class WhisperCpp implements AutoCloseable {
        return str.toString().trim();
    }

    /**
     * Full transcribe with time list.
     *
     * @param whisperParams the whisper params
     * @param audioData the audio data
     * @return the list
     * @throws IOException the io exception
     */
    public List<WhisperSegment> fullTranscribeWithTime(WhisperFullParams.ByValue whisperParams, float[] audioData) throws IOException {
        if (ctx == null) {
            throw new IllegalStateException("Model not initialised");
        }

        if (lib.whisper_full(ctx, whisperParams, audioData, audioData.length) != 0) {
            throw new IOException("Failed to process audio");
        }

        int nSegments = lib.whisper_full_n_segments(ctx);
        List<WhisperSegment> segments = new ArrayList<>(nSegments);
        for (int i = 0; i < nSegments; i++) {
            long t0 = lib.whisper_full_get_segment_t0(ctx, i);
            String text = lib.whisper_full_get_segment_text(ctx, i);

View File

@ -9,7 +9,6 @@ import io.github.ggerganov.whispercpp.params.WhisperContextParams;
import io.github.ggerganov.whispercpp.params.WhisperFullParams;

public interface WhisperCppJnaLibrary extends Library {
    WhisperCppJnaLibrary instance = Native.load("whisper", WhisperCppJnaLibrary.class);

    String whisper_print_system_info();

View File

@ -118,7 +118,7 @@ class WhisperCppTest {
        float[] floats = new float[b.length / 2];
        //WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
        WhisperFullParams.ByValue params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
        params.setProgressCallback((ctx, state, progress, user_data) -> System.out.println("progress: " + progress));
        params.print_progress = CBool.FALSE;
        //params.initial_prompt = "and so my fellow Americans um, like";

View File

@ -1,6 +1,6 @@
{
  "name": "whisper.cpp",
  "version": "1.7.6",
  "description": "Whisper speech recognition",
  "main": "whisper.js",
  "scripts": {

View File

@ -1,9 +1,3 @@
LICENSE
pkg/
lib/whisper.*
ext/examples/
ext/ggml/
ext/include/
ext/scripts/
ext/src/
test/fixtures/

View File

@ -16,32 +16,6 @@ If bundler is not being used to manage dependencies, install the gem by executin
$ gem install whispercpp
You can pass build options for whisper.cpp, for instance:
$ bundle config build.whispercpp --enable-ggml-cuda
or,
$ gem install whispercpp -- --enable-ggml-cuda
See whisper.cpp's [README](https://github.com/ggml-org/whisper.cpp/blob/master/README.md) for available options. You need to convert the options listed in the README to Ruby-style options, for example:
Boolean options:
* `-DGGML_BLAS=1` -> `--enable-ggml-blas`
* `-DWHISPER_COREML=OFF` -> `--disable-whisper-coreml`
Argument options:
* `-DGGML_CUDA_COMPRESSION_MODE=size` -> `--ggml-cuda-compression-mode=size`
Combination:
* `-DGGML_CUDA=1 -DCMAKE_CUDA_ARCHITECTURES="86"` -> `--enable-ggml-cuda --cmake_cuda-architectures="86"`
For boolean options like `GGML_CUDA`, the README says `-DGGML_CUDA=1`. You need to strip `-D`, prepend `--enable-` for `1` or `ON` (`--disable-` for `0` or `OFF`), and convert the name to kebab-case: `--enable-ggml-cuda`.

For options which require arguments, like `CMAKE_CUDA_ARCHITECTURES`, the README says `-DCMAKE_CUDA_ARCHITECTURES="86"`. You need to strip `-D`, prepend `--`, convert the name to kebab-case, then append `=` and the argument: `--cmake-cuda-architectures="86"`.
Usage
-----
@ -70,6 +44,17 @@ end
Some models are prepared up-front:
```ruby
base_en = Whisper::Model.pre_converted_models["base.en"]
whisper = Whisper::Context.new(base_en)
```
The first time you use a model, it is downloaded automatically. After that, the downloaded, cached file is used. To clear the cache, call `#clear_cache`:
```ruby
Whisper::Model.pre_converted_models["base"].clear_cache
```
You can also use shorthand for pre-converted models:

```ruby
@ -94,19 +79,6 @@ puts Whisper::Model.pre_converted_models.keys
# :
```
You can also retrieve each model:
```ruby
base_en = Whisper::Model.pre_converted_models["base.en"]
whisper = Whisper::Context.new(base_en)
```
The first time you use a model, it is downloaded automatically. After that, the downloaded, cached file is used. To clear the cache, call `#clear_cache`:
```ruby
Whisper::Model.pre_converted_models["base"].clear_cache
```
You can also use local model files you prepared:

```ruby
@ -127,80 +99,9 @@ See [models][] page for details.
Currently, whisper.cpp accepts only 16-bit WAV files.
### Voice Activity Detection (VAD) ###
Support for Voice Activity Detection (VAD) can be enabled by setting `Whisper::Params`'s `vad` argument to `true` and specifying a VAD model:
```ruby
Whisper::Params.new(
vad: true,
vad_model_path: "silero-v5.1.2",
# other arguments...
)
```
When you pass the model name (`"silero-v5.1.2"`) or URI (`https://huggingface.co/ggml-org/whisper-vad/resolve/main/ggml-silero-v5.1.2.bin`), it will be downloaded automatically.
Currently, "silero-v5.1.2" is registered as pre-converted model like ASR models. You also specify file path or URI of model.
If you need configure VAD behavior, pass params for that:
```ruby
Whisper::Params.new(
vad: true,
vad_model_path: "silero-v5.1.2",
vad_params: Whisper::VAD::Params.new(
threshold: 1.0, # defaults to 0.5
min_speech_duration_ms: 500, # defaults to 250
min_silence_duration_ms: 200, # defaults to 100
max_speech_duration_s: 30000, # default is FLT_MAX,
speech_pad_ms: 50, # defaults to 30
samples_overlap: 0.5 # defaults to 0.1
),
# other arguments...
)
```
For details on VAD, see [whisper.cpp's README](https://github.com/ggml-org/whisper.cpp?tab=readme-ov-file#voice-activity-detection-vad).
### Output ###
whispercpp supports SRT and WebVTT output:
```ruby
puts whisper.transcribe("path/to/audio.wav", Whisper::Params.new).to_webvtt
# =>
WEBVTT
1
00:00:00.000 --> 00:00:03.860
My thought I have nobody by a beauty and will as you poured.
2
00:00:03.860 --> 00:00:09.840
Mr. Rochester is sub in that so-don't find simplest, and devoted about, to let might in
3
00:00:09.840 --> 00:00:09.940
a
```
You may call `#to_srt`, too.
API
---
### Transcription ###
By default, `Whisper::Context#transcribe` works in a single thread. You can make it work in parallel by passing `n_processors` option:
```ruby
whisper.transcribe("path/to/audio.wav", params, n_processors: Etc.nprocessors)
```
Note that transcription might occasionally be less accurate when run in parallel.
### Segments ###

Once `Whisper::Context#transcribe` has been called, you can retrieve segments with `#each_segment`:
@ -222,7 +123,7 @@ whisper
    ed: format_time(segment.end_time),
    text: segment.text
  }
  line << " (speaker turned)" if segment.speaker_turn_next?
  puts line
end
@ -238,7 +139,7 @@ params.on_new_segment do |segment|
    ed: format_time(segment.end_time),
    text: segment.text
  }
  line << " (speaker turned)" if segment.speaker_turn_next?
  puts line
end
@ -327,7 +228,7 @@ The second argument `samples` may be an array, an object with `length` and `each
Development
-----------

% git clone https://github.com/ggml-org/whisper.cpp.git
% cd whisper.cpp/bindings/ruby
% rake test
@ -335,15 +236,10 @@ First call of `rake test` builds an extension and downloads a model for testing.
If something seems wrong on build, running `rake clean` solves some cases.
### Need help ###
* Windows support
* Refinement of C/C++ code, especially memory management
License
-------

The same as [whisper.cpp][].

[whisper.cpp]: https://github.com/ggml-org/whisper.cpp
[models]: https://github.com/ggml-org/whisper.cpp/tree/master/models

View File

@ -3,15 +3,11 @@ require "bundler/gem_tasks"
require "rake/testtask" require "rake/testtask"
require_relative "extsources" require_relative "extsources"
SOURCES_DIR = "ext/sources"
SOURCES = FileList[] SOURCES = FileList[]
EXTSOURCES.each do |src| EXTSOURCES.each do |src|
basename = src.pathmap("%f") basename = src.pathmap("%f")
dest = basename == "LICENSE" ? basename dest = basename == "LICENSE" ? basename : src.pathmap("%{../..,ext}p")
: src.pathmap("%{\\.\\./\\.\\.,#{SOURCES_DIR}}p")
.pathmap("%{\\.\\./javascript,#{SOURCES_DIR}/bindings/javascript}p")
dir = dest.pathmap("%d") dir = dest.pathmap("%d")
file src file src
directory dir directory dir
@ -22,6 +18,7 @@ EXTSOURCES.each do |src|
end

CLEAN.include SOURCES
CLEAN.include FileList["ext/**/*.o", "ext/**/*.metal", "ext/**/*.tmp", "ext/whisper.{so,bundle,dll}"]

SRC = FileList["ext/*.{c,cpp,h}"]
@ -39,20 +36,6 @@ file "ext/Makefile" => SRC + ["ext/extconf.rb"] + SOURCES do |t|
ruby "extconf.rb" ruby "extconf.rb"
end end
end end
if File.exist? "ext/Makefile"
task :make_clean do
cd "ext" do
sh "make", "clean"
end
end
task clean: :make_clean
task :make_distclean do
cd "ext" do
sh "make", "distclean"
end
end
task clobber: :make_distclean
end
file SO_FILE => "ext/Makefile" do |t|
  chdir "ext" do
@ -67,30 +50,17 @@ file LIB_FILE => [SO_FILE, "lib"] do |t|
end
CLEAN.include LIB_FILE

Rake::TestTask.new

TEST_FIXTURE_AUDIO = "test/fixtures/jfk.wav"
TEST_FIXTURE_AUDIO_SRC = File.expand_path(File.join(__dir__, "..", "..", "samples", "jfk.wav"))
TEST_FIXTURE_AUDIO_DIR = TEST_FIXTURE_AUDIO.pathmap("%d")
directory TEST_FIXTURE_AUDIO_DIR
if File.exist? TEST_FIXTURE_AUDIO_SRC
  file TEST_FIXTURE_AUDIO => [TEST_FIXTURE_AUDIO_SRC, TEST_FIXTURE_AUDIO_DIR] do |t|
    symlink t.source, t.name
  end
else
  require "open-uri"
  file TEST_FIXTURE_AUDIO => TEST_FIXTURE_AUDIO_DIR do |t|
    File.write t.name, URI("https://github.com/ggml-org/whisper.cpp/raw/refs/heads/master/samples/jfk.wav").read
  end
end

TEST_MEMORY_VIEW = "test/jfk_reader/jfk_reader.#{RbConfig::CONFIG['DLEXT']}"
file TEST_MEMORY_VIEW => "test/jfk_reader/jfk_reader.c" do |t|
  chdir "test/jfk_reader" do
    ruby "extconf.rb"
    sh "make"
  end
end
CLEAN.include TEST_MEMORY_VIEW

task test: [LIB_FILE, TEST_MEMORY_VIEW, TEST_FIXTURE_AUDIO]

View File

@ -2,8 +2,10 @@ Makefile
whisper.so
whisper.bundle
whisper.dll
*.o
*.a
sources/*
!sources/CMakeGraphVizOptions.cmake
mkmf.log

bindings/ruby/ext/cpu.mk (new file)
View File

@ -0,0 +1,13 @@
ggml/src/ggml-cpu/ggml-cpu-cpp.o: \
ggml/src/ggml-cpu/ggml-cpu.cpp \
ggml/src/ggml-cpu/unary-ops.cpp \
ggml/src/ggml-cpu/binary-ops.cpp \
ggml/src/ggml-cpu/vec.cpp \
ggml/src/ggml-cpu/ops.cpp \
ggml/include/ggml-backend.h \
ggml/include/ggml.h \
ggml/include/ggml-alloc.h \
ggml/src/ggml-backend-impl.h \
ggml/include/ggml-cpu.h \
ggml/src/ggml-impl.h
$(CXX) $(CXXFLAGS) -c $< -o $@

View File

@ -1,73 +0,0 @@
require "tsort"
class Dependencies
include TSort
def initialize(cmake, options)
@cmake = cmake
@options = options
@static_lib_shape = nil
@nodes = {}
@graph = Hash.new {|h, k| h[k] = []}
generate_dot
parse_dot
end
def libs
tsort.filter_map {|node|
label, shape = @nodes[node]
if shape == @static_lib_shape
label.gsub(/\\n\([^)]+\)/, '')
else
nil
end
}.reverse.collect {|lib| "lib#{lib}.a"}
end
def to_s
libs.join(" ")
end
private
def dot_path
File.join(__dir__, "build", "whisper.cpp.dot")
end
def generate_dot
args = ["-S", "sources", "-B", "build", "--graphviz", dot_path, "-D", "BUILD_SHARED_LIBS=OFF"]
args << @options.to_s unless @options.to_s.empty?
system @cmake, *args, exception: true
end
def parse_dot
File.open(dot_path).each_line do |line|
case line
when /\[\s*label\s*=\s*"Static Library"\s*,\s*shape\s*=\s*(?<shape>\w+)\s*\]/
@static_lib_shape = $~[:shape]
when /\A\s*"(?<node>\w+)"\s*\[\s*label\s*=\s*"(?<label>\S+)"\s*,\s*shape\s*=\s*(?<shape>\w+)\s*\]\s*;\s*\z/
node = $~[:node]
label = $~[:label]
shape = $~[:shape]
@nodes[node] = [label, shape]
when /\A\s*"(?<depender>\w+)"\s*->\s*"(?<dependee>\w+)"/
depender = $~[:depender]
dependee = $~[:dependee]
@graph[depender] << dependee
end
end
end
def tsort_each_node
@nodes.each_key do |node|
yield node
end
end
def tsort_each_child(node)
@graph[node].each do |child|
yield child
end
end
end

View File

@ -1,22 +1,212 @@
require "mkmf" require 'mkmf'
require_relative "options"
require_relative "dependencies"
cmake = find_executable("cmake") || abort # need to use c++ compiler flags
options = Options.new(cmake) $CXXFLAGS << ' -std=c++17'
have_library("gomp") rescue nil
libs = Dependencies.new(cmake, options)
$INCFLAGS << " -Isources/include -Isources/ggml/include -Isources/examples" $LDFLAGS << ' -lstdc++'
$LOCAL_LIBS << " #{libs}"
$cleanfiles << " build #{libs}"
create_makefile "whisper" do |conf| # Set to true when building binary gems
conf << <<~EOF if enable_config('static-stdlib', false)
$(TARGET_SO): #{libs} $LDFLAGS << ' -static-libgcc -static-libstdc++'
#{libs}: cmake-targets end
cmake-targets:
#{"\t"}#{cmake} -S sources -B build -D BUILD_SHARED_LIBS=OFF -D CMAKE_ARCHIVE_OUTPUT_DIRECTORY=#{__dir__} -D CMAKE_POSITION_INDEPENDENT_CODE=ON #{options} if enable_config('march-tune-native', false)
#{"\t"}#{cmake} --build build --config Release --target common whisper $CFLAGS << ' -march=native -mtune=native'
EOF $CXXFLAGS << ' -march=native -mtune=native'
end
if ENV['WHISPER_METAL']
$GGML_METAL ||= true
$DEPRECATE_WARNING ||= true
end
$UNAME_S = `uname -s`.chomp
$UNAME_P = `uname -p`.chomp
$UNAME_M = `uname -m`.chomp
if $UNAME_S == 'Darwin'
unless ENV['GGML_NO_METAL']
$GGML_METAL ||= true
end
$GGML_NO_OPENMP ||= true
end
if $GGML_METAL
$GGML_METAL_EMBED_LIBRARY = true
end
$MK_CPPFLAGS = '-Iggml/include -Iggml/src -Iggml/src/ggml-cpu -Iinclude -Isrc -Iexamples -DGGML_USE_CPU'
$MK_CFLAGS = '-std=c11 -fPIC'
$MK_CXXFLAGS = '-std=c++17 -fPIC'
$MK_NVCCFLAGS = '-std=c++17'
$MK_LDFLAGS = ''
$OBJ_GGML = []
$OBJ_WHISPER = []
$OBJ_COMMON = []
$OBJ_SDL = []
$MK_CPPFLAGS << ' -D_XOPEN_SOURCE=600'
if $UNAME_S == 'Linux'
$MK_CPPFLAGS << ' -D_GNU_SOURCE'
end
if $UNAME_S == 'Darwin'
$MK_CPPFLAGS << ' -D_DARWIN_C_SOURCE'
end
if ENV['WHISPER_DEBUG']
$MK_CFLAGS << ' -O0 -g'
$MK_CXXFLAGS << ' -O0 -g'
$MK_LDFLAGS << ' -g'
$MK_NVCCFLAGS << ' -O0 -g'
else
$MK_CPPFLAGS << ' -DNDEBUG'
$MK_CFLAGS << ' -O3'
$MK_CXXFLAGS << ' -O3'
$MK_NVCCFLAGS << ' -O3'
end
$WARN_FLAGS =
' -Wall' <<
' -Wextra' <<
' -Wpedantic' <<
' -Wcast-qual' <<
' -Wno-unused-function'
$MK_CFLAGS <<
$WARN_FLAGS <<
' -Wshadow' <<
' -Wstrict-prototypes' <<
' -Wpointer-arith' <<
' -Wmissing-prototypes' <<
' -Werror=implicit-int' <<
' -Werror=implicit-function-declaration'
$MK_CXXFLAGS <<
$WARN_FLAGS <<
' -Wmissing-declarations' <<
' -Wmissing-noreturn'
unless `#{cc_command} #{$LDFLAGS} -Wl,-v 2>&1`.chomp.include? 'dyld-1015.7'
$MK_CPPFLAGS << ' -DHAVE_BUGGY_APPLE_LINKER'
end
if %w[Linux Darwin FreeBSD NetBSD OpenBSD Haiku].include? $UNAME_S
$MK_CFLAGS << ' -pthread'
$MK_CXXFLAGS << ' -pthread'
end
unless $_WIN32
$DSO_EXT = '.so'
else
$DSO_EXT = '.dll'
end
unless ENV['RISCV']
if %w[x86_64 i686 amd64].include? $UNAME_M
$HOST_CXXFLAGS ||= ''
$MK_CFLAGS << ' -march=native -mtune=native'
$HOST_CXXFLAGS << ' -march=native -mtune=native'
end
else
$MK_CFLAGS << ' -march=rv64gcv -mabi=lp64d'
$MK_CXXFLAGS << ' -march=rv64gcv -mabi=lp64d'
end
unless ENV['GGML_NO_ACCELERATE']
if $UNAME_S == 'Darwin'
$MK_CPPFLAGS << ' -DGGML_USE_ACCELERATE -DGGML_USE_BLAS -DGGML_BLAS_USE_ACCELERATE'
$MK_CPPFLAGS << ' -DACCELERATE_NEW_LAPACK'
$MK_CPPFLAGS << ' -DACCELERATE_LAPACK_ILP64'
$MK_LDFLAGS << ' -framework Accelerate'
$OBJ_GGML << 'ggml/src/ggml-blas/ggml-blas.o'
end
end
if ENV['GGML_OPENBLAS']
$MK_CPPFLAGS << " -DGGML_USE_BLAS #{`pkg-config --cflags-only-I openblas`.chomp}"
$MK_CFLAGS << " #{`pkg-config --cflags-only-other openblas`.chomp}"
$MK_LDFLAGS << " #{`pkg-config --libs openblas`}"
$OBJ_GGML << 'ggml/src/ggml-blas/ggml-blas.o'
end
if ENV['GGML_OPENBLAS64']
$MK_CPPFLAGS << " -DGGML_USE_BLAS #{`pkg-config --cflags-only-I openblas64`.chomp}"
$MK_CFLAGS << " #{`pkg-config --cflags-only-other openblas64`.chomp}"
$MK_LDFLAGS << " #{`pkg-config --libs openblas64`}"
$OBJ_GGML << 'ggml/src/ggml-blas/ggml-blas.o'
end
if $GGML_METAL
$MK_CPPFLAGS << ' -DGGML_USE_METAL'
$MK_LDFLAGS << ' -framework Foundation -framework Metal -framework MetalKit'
$OBJ_GGML << 'ggml/src/ggml-metal/ggml-metal.o'
if ENV['GGML_METAL_NDEBUG']
$MK_CPPFLAGS << ' -DGGML_METAL_NDEBUG'
end
if $GGML_METAL_EMBED_LIBRARY
$MK_CPPFLAGS << ' -DGGML_METAL_EMBED_LIBRARY'
$OBJ_GGML << 'ggml/src/ggml-metal/ggml-metal-embed.o'
end
end
$OBJ_GGML <<
'ggml/src/ggml.o' <<
'ggml/src/ggml-alloc.o' <<
'ggml/src/ggml-backend.o' <<
'ggml/src/ggml-backend-reg.o' <<
'ggml/src/ggml-opt.o' <<
'ggml/src/ggml-quants.o' <<
'ggml/src/ggml-threading.o' <<
'ggml/src/ggml-cpu/ggml-cpu.o' <<
'ggml/src/ggml-cpu/ggml-cpu-cpp.o' <<
'ggml/src/ggml-cpu/ggml-cpu-aarch64.o' <<
'ggml/src/ggml-cpu/ggml-cpu-hbm.o' <<
'ggml/src/ggml-cpu/ggml-cpu-quants.o' <<
'ggml/src/ggml-cpu/ggml-cpu-traits.o' <<
'ggml/src/ggml-cpu/unary-ops.o' <<
'ggml/src/ggml-cpu/binary-ops.o' <<
'ggml/src/ggml-cpu/vec.o' <<
'ggml/src/ggml-cpu/ops.o'
$OBJ_WHISPER <<
'src/whisper.o' <<
'examples/common.o' <<
'examples/common-whisper.o'
$objs = $OBJ_GGML + $OBJ_WHISPER + $OBJ_COMMON + $OBJ_SDL
$objs <<
"ruby_whisper.o" <<
"ruby_whisper_context.o" <<
"ruby_whisper_transcribe.o" <<
"ruby_whisper_params.o" <<
"ruby_whisper_error.o" <<
"ruby_whisper_segment.o" <<
"ruby_whisper_model.o"
$CPPFLAGS = "#{$MK_CPPFLAGS} #{$CPPFLAGS}"
$CFLAGS = "#{$CPPFLAGS} #{$MK_CFLAGS} #{$GF_CFLAGS} #{$CFLAGS}"
$BASE_CXXFLAGS = "#{$MK_CXXFLAGS} #{$CXXFLAGS}"
$CXXFLAGS = "#{$BASE_CXXFLAGS} #{$HOST_CXXFLAGS} #{$GF_CXXFLAGS} #{$CPPFLAGS}"
$NVCCFLAGS = "#{$MK_NVCCFLAGS} #{$NVCCFLAGS}"
$LDFLAGS = "#{$MK_LDFLAGS} #{$LDFLAGS}"
create_makefile('whisper')
File.open 'Makefile', 'a' do |file|
file.puts 'include scripts/get-flags.mk'
file.puts 'include cpu.mk'
if $GGML_METAL
file.puts 'include metal.mk'
if $GGML_METAL_EMBED_LIBRARY
file.puts 'include metal-embed.mk'
end
end
end end

View File

@ -0,0 +1,17 @@
ggml/src/ggml-metal/ggml-metal-embed.o: \
ggml/src/ggml-metal/ggml-metal.metal \
ggml/src/ggml-metal/ggml-metal-impl.h \
ggml/src/ggml-common.h
@echo "Embedding Metal library"
@sed -e '/__embed_ggml-common.h__/r ggml/src/ggml-common.h' -e '/__embed_ggml-common.h__/d' < ggml/src/ggml-metal/ggml-metal.metal > ggml/src/ggml-metal/ggml-metal-embed.metal.tmp
@sed -e '/#include "ggml-metal-impl.h"/r ggml/src/ggml-metal/ggml-metal-impl.h' -e '/#include "ggml-metal-impl.h"/d' < ggml/src/ggml-metal/ggml-metal-embed.metal.tmp > ggml/src/ggml-metal/ggml-metal-embed.metal
$(eval TEMP_ASSEMBLY=$(shell mktemp -d))
@echo ".section __DATA, __ggml_metallib" > $(TEMP_ASSEMBLY)/ggml-metal-embed.s
@echo ".globl _ggml_metallib_start" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s
@echo "_ggml_metallib_start:" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s
@echo ".incbin \"ggml/src/ggml-metal/ggml-metal-embed.metal\"" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s
@echo ".globl _ggml_metallib_end" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s
@echo "_ggml_metallib_end:" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s
$(CC) $(CFLAGS) -c $(TEMP_ASSEMBLY)/ggml-metal-embed.s -o $@
@rm -f ${TEMP_ASSEMBLY}/ggml-metal-embed.s
@rmdir ${TEMP_ASSEMBLY}

View File

@ -0,0 +1,6 @@
ggml/src/ggml-metal/ggml-metal.o: \
ggml/src/ggml-metal/ggml-metal.m \
ggml/src/ggml-metal/ggml-metal-impl.h \
ggml/include/ggml-metal.h \
ggml/include/ggml.h
$(CC) $(CFLAGS) -c $< -o $@

View File

@ -1,85 +0,0 @@
class Options
def initialize(cmake="cmake")
@cmake = cmake
@options = {}
configure
end
def to_s
@options
.reject {|name, (type, value)| value.nil?}
.collect {|name, (type, value)| "-D #{name}=#{value == true ? "ON" : value == false ? "OFF" : value.shellescape}"}
.join(" ")
end
def cmake_options
return @cmake_options if @cmake_options
output = nil
Dir.chdir __dir__ do
output = `#{@cmake.shellescape} -S sources -B build -L`
end
@cmake_options = output.lines.drop_while {|line| line.chomp != "-- Cache values"}.drop(1)
.filter_map {|line|
option, value = line.chomp.split("=", 2)
name, type = option.split(":", 2)
[
name,
[
type,
type == "BOOL" ? value == "ON" : value
]
]
}.to_h
end
private
def configure
cmake_options.each_pair do |name, (type, default_value)|
option = option_name(name)
value = type == "BOOL" ? enable_config(option) : arg_config("--#{option}")
@options[name] = [type, value]
end
configure_accelerate
configure_metal
configure_coreml
end
# See ggml/src/ggml-cpu/CMakeLists.txt
def configure_accelerate
if RUBY_PLATFORM.match?(/darwin/) && enabled?("GGML_ACCELERATE")
$LDFLAGS << " -framework Accelerate"
end
end
# See ggml/src/ggml-metal/CMakeLists.txt
def configure_metal
$LDFLAGS << " -framework Foundation -framework Metal -framework MetalKit" if enabled?("GGML_METAL")
end
# See src/CmakeLists.txt
def configure_coreml
if enabled?("WHISPER_COREML")
$LDFLAGS << " -framework Foundation -framework CoreML"
$defs << "-DRUBY_WHISPER_USE_COREML"
end
end
def option_name(name)
name.downcase.gsub("_", "-")
end
def enabled?(option)
op = @options[option]
raise "Option not exist: #{option}" unless op
raise "Option not boolean: #{option}(#{op[0]})" unless op[0] == "BOOL"
if op[1].nil?
cmake_options[option][1]
else
op[1]
end
end
end

View File

@ -3,10 +3,8 @@
#include "ruby_whisper.h" #include "ruby_whisper.h"
VALUE mWhisper; VALUE mWhisper;
VALUE mVAD;
VALUE cContext; VALUE cContext;
VALUE cParams; VALUE cParams;
VALUE cVADParams;
VALUE eError; VALUE eError;
VALUE cSegment; VALUE cSegment;
@ -22,9 +20,6 @@ ID id_new;
ID id_to_path;
ID id_URI;
ID id_pre_converted_models;
ID id_coreml_compiled_models;
ID id_cache;
ID id_n_processors;
static bool is_log_callback_finalized = false;
@ -36,7 +31,6 @@ extern void init_ruby_whisper_params(VALUE *mWhisper);
extern void init_ruby_whisper_error(VALUE *mWhisper);
extern void init_ruby_whisper_segment(VALUE *mWhisper, VALUE *cSegment);
extern void init_ruby_whisper_model(VALUE *mWhisper);
extern void init_ruby_whisper_vad_params(VALUE *mVAD);
extern void register_callbacks(ruby_whisper_params *rwp, VALUE *context);

/*
@ -86,14 +80,6 @@ static VALUE ruby_whisper_s_lang_str_full(VALUE self, VALUE id) {
  return rb_str_new2(str_full);
}
/*
* call-seq:
* system_info_str -> String
*/
static VALUE ruby_whisper_s_system_info_str(VALUE self) {
return rb_str_new2(whisper_print_system_info());
}
static VALUE ruby_whisper_s_finalize_log_callback(VALUE self, VALUE id) {
  is_log_callback_finalized = true;
  return Qnil;
@ -130,6 +116,16 @@ static VALUE ruby_whisper_s_log_set(VALUE self, VALUE log_callback, VALUE user_d
  return Qnil;
}
static void rb_whisper_model_mark(ruby_whisper_model *rwm) {
rb_gc_mark(rwm->context);
}
static VALUE ruby_whisper_model_allocate(VALUE klass) {
ruby_whisper_model *rwm;
rwm = ALLOC(ruby_whisper_model);
return Data_Wrap_Struct(klass, rb_whisper_model_mark, RUBY_DEFAULT_FREE, rwm);
}
void Init_whisper() {
  id_to_s = rb_intern("to_s");
  id_call = rb_intern("call");
@ -141,14 +137,9 @@ void Init_whisper() {
  id_to_path = rb_intern("to_path");
  id_URI = rb_intern("URI");
  id_pre_converted_models = rb_intern("pre_converted_models");
id_coreml_compiled_models = rb_intern("coreml_compiled_models");
id_cache = rb_intern("cache");
id_n_processors = rb_intern("n_processors");
  mWhisper = rb_define_module("Whisper");
  mVAD = rb_define_module_under(mWhisper, "VAD");

  rb_define_const(mWhisper, "VERSION", rb_str_new2(whisper_version()));
  rb_define_const(mWhisper, "LOG_LEVEL_NONE", INT2NUM(GGML_LOG_LEVEL_NONE));
  rb_define_const(mWhisper, "LOG_LEVEL_INFO", INT2NUM(GGML_LOG_LEVEL_INFO));
  rb_define_const(mWhisper, "LOG_LEVEL_WARN", INT2NUM(GGML_LOG_LEVEL_WARN));
@ -160,7 +151,6 @@ void Init_whisper() {
rb_define_singleton_method(mWhisper, "lang_id", ruby_whisper_s_lang_id, 1); rb_define_singleton_method(mWhisper, "lang_id", ruby_whisper_s_lang_id, 1);
rb_define_singleton_method(mWhisper, "lang_str", ruby_whisper_s_lang_str, 1); rb_define_singleton_method(mWhisper, "lang_str", ruby_whisper_s_lang_str, 1);
rb_define_singleton_method(mWhisper, "lang_str_full", ruby_whisper_s_lang_str_full, 1); rb_define_singleton_method(mWhisper, "lang_str_full", ruby_whisper_s_lang_str_full, 1);
rb_define_singleton_method(mWhisper, "system_info_str", ruby_whisper_s_system_info_str, 0);
rb_define_singleton_method(mWhisper, "log_set", ruby_whisper_s_log_set, 2); rb_define_singleton_method(mWhisper, "log_set", ruby_whisper_s_log_set, 2);
rb_define_private_method(rb_singleton_class(mWhisper), "finalize_log_callback", ruby_whisper_s_finalize_log_callback, 1); rb_define_private_method(rb_singleton_class(mWhisper), "finalize_log_callback", ruby_whisper_s_finalize_log_callback, 1);
@ -169,9 +159,6 @@ void Init_whisper() {
  init_ruby_whisper_error(&mWhisper);
  init_ruby_whisper_segment(&mWhisper, &cContext);
  init_ruby_whisper_model(&mWhisper);
  init_ruby_whisper_vad_params(&mVAD);

  rb_require("whisper/context");
  rb_require("whisper/segment");
  rb_require("whisper/model/uri");
}

View File

@ -19,15 +19,9 @@ typedef struct {
  bool diarize;
  ruby_whisper_callback_container *new_segment_callback_container;
  ruby_whisper_callback_container *progress_callback_container;
  ruby_whisper_callback_container *encoder_begin_callback_container;
  ruby_whisper_callback_container *abort_callback_container;
  VALUE vad_params;
} ruby_whisper_params;

typedef struct {
  struct whisper_vad_params params;
} ruby_whisper_vad_params;

typedef struct {
  VALUE context;
  int index;

View File

@ -11,21 +11,15 @@ extern ID id_new;
extern ID id_to_path;
extern ID id_URI;
extern ID id_pre_converted_models;
extern ID id_coreml_compiled_models;
extern ID id_cache;
extern ID id_n_processors;

extern VALUE cContext;
extern VALUE eError;
extern VALUE cModel;
extern const rb_data_type_t ruby_whisper_params_type;

extern VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self);
extern VALUE rb_whisper_model_s_new(VALUE context);
extern VALUE rb_whisper_segment_s_new(VALUE context, int index);
extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context);

ID transcribe_option_names[1];

static void
ruby_whisper_free(ruby_whisper *rw)
@ -43,74 +37,19 @@ rb_whisper_mark(ruby_whisper *rw)
}

void
rb_whisper_free(void *p)
{
  ruby_whisper *rw = (ruby_whisper *)p;
  ruby_whisper_free(rw);
  free(rw);
}
static size_t
ruby_whisper_memsize(const void *p)
{
const ruby_whisper *rw = (const ruby_whisper *)p;
size_t size = sizeof(rw);
if (!rw) {
return 0;
}
if (rw->context) {
size += sizeof(rw->context);
}
return size;
}
const rb_data_type_t ruby_whisper_type = {
"ruby_whisper",
{0, rb_whisper_free, ruby_whisper_memsize,},
0, 0,
0
};
static VALUE
ruby_whisper_allocate(VALUE klass)
{
  ruby_whisper *rw;
  VALUE obj = TypedData_Make_Struct(klass, ruby_whisper, &ruby_whisper_type, rw);
  rw->context = NULL;
  return obj;
}
VALUE
ruby_whisper_normalize_model_path(VALUE model_path)
{
VALUE pre_converted_models = rb_funcall(cModel, id_pre_converted_models, 0);
VALUE pre_converted_model = rb_hash_aref(pre_converted_models, model_path);
if (!NIL_P(pre_converted_model)) {
model_path = pre_converted_model;
#ifdef RUBY_WHISPER_USE_COREML
VALUE coreml_converted_models = rb_funcall(cModel, id_coreml_compiled_models, 0);
VALUE coreml_converted_model = rb_hash_aref(coreml_converted_models, pre_converted_model);
if (!NIL_P(coreml_converted_model)) {
rb_funcall(coreml_converted_model, id_cache, 0);
}
#endif
}
else if (TYPE(model_path) == T_STRING) {
const char * model_path_str = StringValueCStr(model_path);
if (strncmp("http://", model_path_str, 7) == 0 || strncmp("https://", model_path_str, 8) == 0) {
VALUE uri_class = rb_const_get(cModel, id_URI);
model_path = rb_class_new_instance(1, &model_path, uri_class);
}
}
else if (rb_obj_is_kind_of(model_path, rb_path2class("URI::HTTP"))) {
VALUE uri_class = rb_const_get(cModel, id_URI);
model_path = rb_class_new_instance(1, &model_path, uri_class);
}
if (rb_respond_to(model_path, id_to_path)) {
model_path = rb_funcall(model_path, id_to_path, 0);
}
return model_path;
} }
/* /*
@ -127,9 +66,27 @@ ruby_whisper_initialize(int argc, VALUE *argv, VALUE self)
// TODO: we can support init from buffer here too maybe another ruby object to expose // TODO: we can support init from buffer here too maybe another ruby object to expose
rb_scan_args(argc, argv, "01", &whisper_model_file_path); rb_scan_args(argc, argv, "01", &whisper_model_file_path);
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); Data_Get_Struct(self, ruby_whisper, rw);
whisper_model_file_path = ruby_whisper_normalize_model_path(whisper_model_file_path); VALUE pre_converted_models = rb_funcall(cModel, id_pre_converted_models, 0);
VALUE pre_converted_model = rb_hash_aref(pre_converted_models, whisper_model_file_path);
if (!NIL_P(pre_converted_model)) {
whisper_model_file_path = pre_converted_model;
}
if (TYPE(whisper_model_file_path) == T_STRING) {
const char * whisper_model_file_path_str = StringValueCStr(whisper_model_file_path);
if (strncmp("http://", whisper_model_file_path_str, 7) == 0 || strncmp("https://", whisper_model_file_path_str, 8) == 0) {
VALUE uri_class = rb_const_get(cModel, id_URI);
whisper_model_file_path = rb_class_new_instance(1, &whisper_model_file_path, uri_class);
}
}
if (rb_obj_is_kind_of(whisper_model_file_path, rb_path2class("URI::HTTP"))) {
VALUE uri_class = rb_const_get(cModel, id_URI);
whisper_model_file_path = rb_class_new_instance(1, &whisper_model_file_path, uri_class);
}
if (rb_respond_to(whisper_model_file_path, id_to_path)) {
whisper_model_file_path = rb_funcall(whisper_model_file_path, id_to_path, 0);
}
if (!rb_respond_to(whisper_model_file_path, id_to_s)) { if (!rb_respond_to(whisper_model_file_path, id_to_s)) {
rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context"); rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context");
} }
@ -147,7 +104,7 @@ ruby_whisper_initialize(int argc, VALUE *argv, VALUE self)
VALUE ruby_whisper_model_n_vocab(VALUE self) VALUE ruby_whisper_model_n_vocab(VALUE self)
{ {
ruby_whisper *rw; ruby_whisper *rw;
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_model_n_vocab(rw->context)); return INT2NUM(whisper_model_n_vocab(rw->context));
} }
@ -158,7 +115,7 @@ VALUE ruby_whisper_model_n_vocab(VALUE self)
VALUE ruby_whisper_model_n_audio_ctx(VALUE self) VALUE ruby_whisper_model_n_audio_ctx(VALUE self)
{ {
ruby_whisper *rw; ruby_whisper *rw;
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_model_n_audio_ctx(rw->context)); return INT2NUM(whisper_model_n_audio_ctx(rw->context));
} }
@ -169,7 +126,7 @@ VALUE ruby_whisper_model_n_audio_ctx(VALUE self)
VALUE ruby_whisper_model_n_audio_state(VALUE self) VALUE ruby_whisper_model_n_audio_state(VALUE self)
{ {
ruby_whisper *rw; ruby_whisper *rw;
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_model_n_audio_state(rw->context)); return INT2NUM(whisper_model_n_audio_state(rw->context));
} }
@ -180,7 +137,7 @@ VALUE ruby_whisper_model_n_audio_state(VALUE self)
VALUE ruby_whisper_model_n_audio_head(VALUE self) VALUE ruby_whisper_model_n_audio_head(VALUE self)
{ {
ruby_whisper *rw; ruby_whisper *rw;
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_model_n_audio_head(rw->context)); return INT2NUM(whisper_model_n_audio_head(rw->context));
} }
@ -191,7 +148,7 @@ VALUE ruby_whisper_model_n_audio_head(VALUE self)
VALUE ruby_whisper_model_n_audio_layer(VALUE self) VALUE ruby_whisper_model_n_audio_layer(VALUE self)
{ {
ruby_whisper *rw; ruby_whisper *rw;
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_model_n_audio_layer(rw->context)); return INT2NUM(whisper_model_n_audio_layer(rw->context));
} }
@ -202,7 +159,7 @@ VALUE ruby_whisper_model_n_audio_layer(VALUE self)
VALUE ruby_whisper_model_n_text_ctx(VALUE self) VALUE ruby_whisper_model_n_text_ctx(VALUE self)
{ {
ruby_whisper *rw; ruby_whisper *rw;
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_model_n_text_ctx(rw->context)); return INT2NUM(whisper_model_n_text_ctx(rw->context));
} }
@ -213,7 +170,7 @@ VALUE ruby_whisper_model_n_text_ctx(VALUE self)
VALUE ruby_whisper_model_n_text_state(VALUE self) VALUE ruby_whisper_model_n_text_state(VALUE self)
{ {
ruby_whisper *rw; ruby_whisper *rw;
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_model_n_text_state(rw->context)); return INT2NUM(whisper_model_n_text_state(rw->context));
} }
@ -224,7 +181,7 @@ VALUE ruby_whisper_model_n_text_state(VALUE self)
VALUE ruby_whisper_model_n_text_head(VALUE self) VALUE ruby_whisper_model_n_text_head(VALUE self)
{ {
ruby_whisper *rw; ruby_whisper *rw;
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_model_n_text_head(rw->context)); return INT2NUM(whisper_model_n_text_head(rw->context));
} }
@ -235,7 +192,7 @@ VALUE ruby_whisper_model_n_text_head(VALUE self)
VALUE ruby_whisper_model_n_text_layer(VALUE self) VALUE ruby_whisper_model_n_text_layer(VALUE self)
{ {
ruby_whisper *rw; ruby_whisper *rw;
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_model_n_text_layer(rw->context)); return INT2NUM(whisper_model_n_text_layer(rw->context));
} }
@ -246,7 +203,7 @@ VALUE ruby_whisper_model_n_text_layer(VALUE self)
VALUE ruby_whisper_model_n_mels(VALUE self) VALUE ruby_whisper_model_n_mels(VALUE self)
{ {
ruby_whisper *rw; ruby_whisper *rw;
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_model_n_mels(rw->context)); return INT2NUM(whisper_model_n_mels(rw->context));
} }
@ -257,7 +214,7 @@ VALUE ruby_whisper_model_n_mels(VALUE self)
VALUE ruby_whisper_model_ftype(VALUE self) VALUE ruby_whisper_model_ftype(VALUE self)
{ {
ruby_whisper *rw; ruby_whisper *rw;
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_model_ftype(rw->context)); return INT2NUM(whisper_model_ftype(rw->context));
} }
@ -268,7 +225,7 @@ VALUE ruby_whisper_model_ftype(VALUE self)
VALUE ruby_whisper_model_type(VALUE self) VALUE ruby_whisper_model_type(VALUE self)
{ {
ruby_whisper *rw; ruby_whisper *rw;
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); Data_Get_Struct(self, ruby_whisper, rw);
return rb_str_new2(whisper_model_type_readable(rw->context)); return rb_str_new2(whisper_model_type_readable(rw->context));
} }
@ -291,9 +248,9 @@ VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self)
ruby_whisper *rw; ruby_whisper *rw;
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); Data_Get_Struct(self, ruby_whisper, rw);
VALUE params = argv[0]; VALUE params = argv[0];
TypedData_Get_Struct(params, ruby_whisper_params, &ruby_whisper_params_type, rwp); Data_Get_Struct(params, ruby_whisper_params, rwp);
VALUE samples = argv[1]; VALUE samples = argv[1];
int n_samples; int n_samples;
rb_memory_view_t view; rb_memory_view_t view;
@ -308,20 +265,13 @@ VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self)
// Should check when samples.respond_to?(:length)? // Should check when samples.respond_to?(:length)?
} else { } else {
if (TYPE(samples) == T_ARRAY) { if (TYPE(samples) == T_ARRAY) {
if (RARRAY_LEN(samples) > INT_MAX) { n_samples = RARRAY_LEN(samples);
rb_raise(rb_eArgError, "samples are too long");
}
n_samples = (int)RARRAY_LEN(samples);
} else if (memory_view_available_p) { } else if (memory_view_available_p) {
if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) { if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) {
view.obj = Qnil; view.obj = Qnil;
rb_raise(rb_eArgError, "unable to get a memory view"); rb_raise(rb_eArgError, "unable to get a memory view");
} }
ssize_t n_samples_size = view.byte_size / view.item_size; n_samples = view.byte_size / view.item_size;
if (n_samples_size > INT_MAX) {
rb_raise(rb_eArgError, "samples are too long");
}
n_samples = (int)n_samples_size;
} else if (rb_respond_to(samples, id_length)) { } else if (rb_respond_to(samples, id_length)) {
n_samples = NUM2INT(rb_funcall(samples, id_length, 0)); n_samples = NUM2INT(rb_funcall(samples, id_length, 0));
} else { } else {
@ -346,7 +296,7 @@ VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self)
} }
} }
} }
prepare_transcription(rwp, &self); register_callbacks(rwp, &self);
const int result = whisper_full(rw->context, rwp->params, c_samples, n_samples); const int result = whisper_full(rw->context, rwp->params, c_samples, n_samples);
if (0 == result) { if (0 == result) {
return self; return self;
@ -377,9 +327,9 @@ ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self)
ruby_whisper *rw; ruby_whisper *rw;
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); Data_Get_Struct(self, ruby_whisper, rw);
VALUE params = argv[0]; VALUE params = argv[0];
TypedData_Get_Struct(params, ruby_whisper_params, &ruby_whisper_params_type, rwp); Data_Get_Struct(params, ruby_whisper_params, rwp);
VALUE samples = argv[1]; VALUE samples = argv[1];
int n_samples; int n_samples;
int n_processors; int n_processors;
@ -409,17 +359,10 @@ ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self)
view.obj = Qnil; view.obj = Qnil;
rb_raise(rb_eArgError, "unable to get a memory view"); rb_raise(rb_eArgError, "unable to get a memory view");
} }
ssize_t n_samples_size = view.byte_size / view.item_size; n_samples = view.byte_size / view.item_size;
if (n_samples_size > INT_MAX) {
rb_raise(rb_eArgError, "samples are too long");
}
n_samples = (int)n_samples_size;
} else { } else {
if (TYPE(samples) == T_ARRAY) { if (TYPE(samples) == T_ARRAY) {
if (RARRAY_LEN(samples) > INT_MAX) { n_samples = RARRAY_LEN(samples);
rb_raise(rb_eArgError, "samples are too long");
}
n_samples = (int)RARRAY_LEN(samples);
} else if (rb_respond_to(samples, id_length)) { } else if (rb_respond_to(samples, id_length)) {
n_samples = NUM2INT(rb_funcall(samples, id_length, 0)); n_samples = NUM2INT(rb_funcall(samples, id_length, 0));
} else { } else {
@ -444,7 +387,7 @@ ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self)
} }
} }
} }
prepare_transcription(rwp, &self); register_callbacks(rwp, &self);
const int result = whisper_full_parallel(rw->context, rwp->params, c_samples, n_samples, n_processors); const int result = whisper_full_parallel(rw->context, rwp->params, c_samples, n_samples, n_processors);
if (0 == result) { if (0 == result) {
return self; return self;
@ -463,7 +406,7 @@ static VALUE
ruby_whisper_full_n_segments(VALUE self) ruby_whisper_full_n_segments(VALUE self)
{ {
ruby_whisper *rw; ruby_whisper *rw;
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_full_n_segments(rw->context)); return INT2NUM(whisper_full_n_segments(rw->context));
} }
@ -477,7 +420,7 @@ static VALUE
ruby_whisper_full_lang_id(VALUE self) ruby_whisper_full_lang_id(VALUE self)
{ {
ruby_whisper *rw; ruby_whisper *rw;
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_full_lang_id(rw->context)); return INT2NUM(whisper_full_lang_id(rw->context));
} }
@ -502,10 +445,10 @@ static VALUE
ruby_whisper_full_get_segment_t0(VALUE self, VALUE i_segment) ruby_whisper_full_get_segment_t0(VALUE self, VALUE i_segment)
{ {
ruby_whisper *rw; ruby_whisper *rw;
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); Data_Get_Struct(self, ruby_whisper, rw);
const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment); const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
const int64_t t0 = whisper_full_get_segment_t0(rw->context, c_i_segment); const int64_t t0 = whisper_full_get_segment_t0(rw->context, c_i_segment);
return LONG2NUM(t0); return INT2NUM(t0);
} }
/* /*
@ -520,10 +463,10 @@ static VALUE
ruby_whisper_full_get_segment_t1(VALUE self, VALUE i_segment) ruby_whisper_full_get_segment_t1(VALUE self, VALUE i_segment)
{ {
ruby_whisper *rw; ruby_whisper *rw;
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); Data_Get_Struct(self, ruby_whisper, rw);
const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment); const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
const int64_t t1 = whisper_full_get_segment_t1(rw->context, c_i_segment); const int64_t t1 = whisper_full_get_segment_t1(rw->context, c_i_segment);
return LONG2NUM(t1); return INT2NUM(t1);
} }
/* /*
@ -538,7 +481,7 @@ static VALUE
ruby_whisper_full_get_segment_speaker_turn_next(VALUE self, VALUE i_segment) ruby_whisper_full_get_segment_speaker_turn_next(VALUE self, VALUE i_segment)
{ {
ruby_whisper *rw; ruby_whisper *rw;
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); Data_Get_Struct(self, ruby_whisper, rw);
const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment); const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
const bool speaker_turn_next = whisper_full_get_segment_speaker_turn_next(rw->context, c_i_segment); const bool speaker_turn_next = whisper_full_get_segment_speaker_turn_next(rw->context, c_i_segment);
return speaker_turn_next ? Qtrue : Qfalse; return speaker_turn_next ? Qtrue : Qfalse;
@ -556,7 +499,7 @@ static VALUE
ruby_whisper_full_get_segment_text(VALUE self, VALUE i_segment) ruby_whisper_full_get_segment_text(VALUE self, VALUE i_segment)
{ {
ruby_whisper *rw; ruby_whisper *rw;
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); Data_Get_Struct(self, ruby_whisper, rw);
const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment); const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
const char * text = whisper_full_get_segment_text(rw->context, c_i_segment); const char * text = whisper_full_get_segment_text(rw->context, c_i_segment);
return rb_str_new2(text); return rb_str_new2(text);
@ -570,7 +513,7 @@ static VALUE
ruby_whisper_full_get_segment_no_speech_prob(VALUE self, VALUE i_segment) ruby_whisper_full_get_segment_no_speech_prob(VALUE self, VALUE i_segment)
{ {
ruby_whisper *rw; ruby_whisper *rw;
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); Data_Get_Struct(self, ruby_whisper, rw);
const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment); const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
const float no_speech_prob = whisper_full_get_segment_no_speech_prob(rw->context, c_i_segment); const float no_speech_prob = whisper_full_get_segment_no_speech_prob(rw->context, c_i_segment);
return DBL2NUM(no_speech_prob); return DBL2NUM(no_speech_prob);
@ -581,7 +524,7 @@ ruby_whisper_full_get_segment_no_speech_prob(VALUE self, VALUE i_segment)
static VALUE static VALUE
ruby_whisper_full_get_segment(VALUE self, VALUE i_segment) ruby_whisper_full_get_segment(VALUE self, VALUE i_segment)
{ {
return rb_whisper_segment_s_new(self, NUM2INT(i_segment)); return rb_whisper_segment_initialize(self, NUM2INT(i_segment));
} }
/* /*
@ -611,11 +554,11 @@ ruby_whisper_each_segment(VALUE self)
} }
ruby_whisper *rw; ruby_whisper *rw;
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); Data_Get_Struct(self, ruby_whisper, rw);
const int n_segments = whisper_full_n_segments(rw->context); const int n_segments = whisper_full_n_segments(rw->context);
for (int i = 0; i < n_segments; ++i) { for (int i = 0; i < n_segments; ++i) {
rb_yield(rb_whisper_segment_s_new(self, i)); rb_yield(rb_whisper_segment_initialize(self, i));
} }
return self; return self;
@ -628,7 +571,7 @@ ruby_whisper_each_segment(VALUE self)
static VALUE static VALUE
ruby_whisper_get_model(VALUE self) ruby_whisper_get_model(VALUE self)
{ {
return rb_whisper_model_s_new(self); return rb_whisper_model_initialize(self);
} }
void void
@ -636,8 +579,6 @@ init_ruby_whisper_context(VALUE *mWhisper)
{ {
cContext = rb_define_class_under(*mWhisper, "Context", rb_cObject); cContext = rb_define_class_under(*mWhisper, "Context", rb_cObject);
transcribe_option_names[0] = id_n_processors;
rb_define_alloc_func(cContext, ruby_whisper_allocate); rb_define_alloc_func(cContext, ruby_whisper_allocate);
rb_define_method(cContext, "initialize", ruby_whisper_initialize, -1); rb_define_method(cContext, "initialize", ruby_whisper_initialize, -1);
@ -664,7 +605,7 @@ init_ruby_whisper_context(VALUE *mWhisper)
rb_define_method(cContext, "full", ruby_whisper_full, -1); rb_define_method(cContext, "full", ruby_whisper_full, -1);
rb_define_method(cContext, "full_parallel", ruby_whisper_full_parallel, -1); rb_define_method(cContext, "full_parallel", ruby_whisper_full_parallel, -1);
// High level // High leve
rb_define_method(cContext, "full_get_segment", ruby_whisper_full_get_segment, 1); rb_define_method(cContext, "full_get_segment", ruby_whisper_full_get_segment, 1);
rb_define_method(cContext, "each_segment", ruby_whisper_each_segment, 0); rb_define_method(cContext, "each_segment", ruby_whisper_each_segment, 0);
@@ -1,44 +1,22 @@
 #include <ruby.h>
 #include "ruby_whisper.h"
-extern const rb_data_type_t ruby_whisper_type;
 extern VALUE cModel;
 
-static void rb_whisper_model_mark(void *p) {
-  ruby_whisper_model *rwm = (ruby_whisper_model *)p;
-  if (rwm->context) {
+static void rb_whisper_model_mark(ruby_whisper_model *rwm) {
   rb_gc_mark(rwm->context);
 }
-}
-
-static size_t
-ruby_whisper_model_memsize(const void *p)
-{
-  const ruby_whisper_model *rwm = (const ruby_whisper_model *)p;
-  size_t size = sizeof(rwm);
-  if (!rwm) {
-    return 0;
-  }
-  return size;
-}
-
-static const rb_data_type_t rb_whisper_model_type = {
-  "ruby_whisper_model",
-  {rb_whisper_model_mark, RUBY_DEFAULT_FREE, ruby_whisper_model_memsize,},
-  0, 0,
-  0
-};
 
 static VALUE ruby_whisper_model_allocate(VALUE klass) {
   ruby_whisper_model *rwm;
-  return TypedData_Make_Struct(klass, ruby_whisper_model, &rb_whisper_model_type, rwm);
+  rwm = ALLOC(ruby_whisper_model);
+  return Data_Wrap_Struct(klass, rb_whisper_model_mark, RUBY_DEFAULT_FREE, rwm);
 }
 
-VALUE rb_whisper_model_s_new(VALUE context) {
+VALUE rb_whisper_model_initialize(VALUE context) {
   ruby_whisper_model *rwm;
   const VALUE model = ruby_whisper_model_allocate(cModel);
-  TypedData_Get_Struct(model, ruby_whisper_model, &rb_whisper_model_type, rwm);
+  Data_Get_Struct(model, ruby_whisper_model, rwm);
   rwm->context = context;
   return model;
 };
 
@@ -51,9 +29,9 @@ static VALUE
 ruby_whisper_model_n_vocab(VALUE self)
 {
   ruby_whisper_model *rwm;
-  TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
+  Data_Get_Struct(self, ruby_whisper_model, rwm);
   ruby_whisper *rw;
-  TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
+  Data_Get_Struct(rwm->context, ruby_whisper, rw);
   return INT2NUM(whisper_model_n_vocab(rw->context));
 }
 
@@ -65,9 +43,9 @@ static VALUE
 ruby_whisper_model_n_audio_ctx(VALUE self)
 {
   ruby_whisper_model *rwm;
-  TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
+  Data_Get_Struct(self, ruby_whisper_model, rwm);
   ruby_whisper *rw;
-  TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
+  Data_Get_Struct(rwm->context, ruby_whisper, rw);
   return INT2NUM(whisper_model_n_audio_ctx(rw->context));
 }
 
@@ -79,9 +57,9 @@ static VALUE
 ruby_whisper_model_n_audio_state(VALUE self)
 {
   ruby_whisper_model *rwm;
-  TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
+  Data_Get_Struct(self, ruby_whisper_model, rwm);
   ruby_whisper *rw;
-  TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
+  Data_Get_Struct(rwm->context, ruby_whisper, rw);
   return INT2NUM(whisper_model_n_audio_state(rw->context));
 }
 
@@ -93,9 +71,9 @@ static VALUE
 ruby_whisper_model_n_audio_head(VALUE self)
 {
   ruby_whisper_model *rwm;
-  TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
+  Data_Get_Struct(self, ruby_whisper_model, rwm);
   ruby_whisper *rw;
-  TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
+  Data_Get_Struct(rwm->context, ruby_whisper, rw);
   return INT2NUM(whisper_model_n_audio_head(rw->context));
 }
 
@@ -107,9 +85,9 @@ static VALUE
 ruby_whisper_model_n_audio_layer(VALUE self)
 {
   ruby_whisper_model *rwm;
-  TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
+  Data_Get_Struct(self, ruby_whisper_model, rwm);
   ruby_whisper *rw;
-  TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
+  Data_Get_Struct(rwm->context, ruby_whisper, rw);
   return INT2NUM(whisper_model_n_audio_layer(rw->context));
 }
 
@@ -121,9 +99,9 @@ static VALUE
 ruby_whisper_model_n_text_ctx(VALUE self)
 {
   ruby_whisper_model *rwm;
-  TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
+  Data_Get_Struct(self, ruby_whisper_model, rwm);
   ruby_whisper *rw;
-  TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
+  Data_Get_Struct(rwm->context, ruby_whisper, rw);
   return INT2NUM(whisper_model_n_text_ctx(rw->context));
 }
 
@@ -135,9 +113,9 @@ static VALUE
 ruby_whisper_model_n_text_state(VALUE self)
 {
   ruby_whisper_model *rwm;
-  TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
+  Data_Get_Struct(self, ruby_whisper_model, rwm);
   ruby_whisper *rw;
-  TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
+  Data_Get_Struct(rwm->context, ruby_whisper, rw);
   return INT2NUM(whisper_model_n_text_state(rw->context));
 }
 
@@ -149,9 +127,9 @@ static VALUE
 ruby_whisper_model_n_text_head(VALUE self)
 {
   ruby_whisper_model *rwm;
-  TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
+  Data_Get_Struct(self, ruby_whisper_model, rwm);
   ruby_whisper *rw;
-  TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
+  Data_Get_Struct(rwm->context, ruby_whisper, rw);
   return INT2NUM(whisper_model_n_text_head(rw->context));
 }
 
@@ -163,9 +141,9 @@ static VALUE
 ruby_whisper_model_n_text_layer(VALUE self)
 {
   ruby_whisper_model *rwm;
-  TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
+  Data_Get_Struct(self, ruby_whisper_model, rwm);
   ruby_whisper *rw;
-  TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
+  Data_Get_Struct(rwm->context, ruby_whisper, rw);
   return INT2NUM(whisper_model_n_text_layer(rw->context));
 }
 
@@ -177,9 +155,9 @@ static VALUE
 ruby_whisper_model_n_mels(VALUE self)
 {
   ruby_whisper_model *rwm;
-  TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
+  Data_Get_Struct(self, ruby_whisper_model, rwm);
   ruby_whisper *rw;
-  TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
+  Data_Get_Struct(rwm->context, ruby_whisper, rw);
   return INT2NUM(whisper_model_n_mels(rw->context));
 }
 
@@ -191,9 +169,9 @@ static VALUE
 ruby_whisper_model_ftype(VALUE self)
 {
   ruby_whisper_model *rwm;
-  TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
+  Data_Get_Struct(self, ruby_whisper_model, rwm);
   ruby_whisper *rw;
-  TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
+  Data_Get_Struct(rwm->context, ruby_whisper, rw);
   return INT2NUM(whisper_model_ftype(rw->context));
 }
 
@@ -205,9 +183,9 @@ static VALUE
 ruby_whisper_model_type(VALUE self)
 {
   ruby_whisper_model *rwm;
-  TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
+  Data_Get_Struct(self, ruby_whisper_model, rwm);
   ruby_whisper *rw;
-  TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
+  Data_Get_Struct(rwm->context, ruby_whisper, rw);
   return rb_str_new2(whisper_model_type_readable(rw->context));
 }
@ -3,7 +3,7 @@
#define BOOL_PARAMS_SETTER(self, prop, value) \ #define BOOL_PARAMS_SETTER(self, prop, value) \
ruby_whisper_params *rwp; \ ruby_whisper_params *rwp; \
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); \ Data_Get_Struct(self, ruby_whisper_params, rwp); \
if (value == Qfalse || value == Qnil) { \ if (value == Qfalse || value == Qnil) { \
rwp->params.prop = false; \ rwp->params.prop = false; \
} else { \ } else { \
@ -13,7 +13,7 @@
#define BOOL_PARAMS_GETTER(self, prop) \ #define BOOL_PARAMS_GETTER(self, prop) \
ruby_whisper_params *rwp; \ ruby_whisper_params *rwp; \
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); \ Data_Get_Struct(self, ruby_whisper_params, rwp); \
if (rwp->params.prop) { \ if (rwp->params.prop) { \
return Qtrue; \ return Qtrue; \
} else { \ } else { \
@ -26,16 +26,13 @@
rb_define_method(cParams, #param_name, ruby_whisper_params_get_ ## param_name, 0); \ rb_define_method(cParams, #param_name, ruby_whisper_params_get_ ## param_name, 0); \
rb_define_method(cParams, #param_name "=", ruby_whisper_params_set_ ## param_name, 1); rb_define_method(cParams, #param_name "=", ruby_whisper_params_set_ ## param_name, 1);
#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 35 #define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 30
extern VALUE cParams; extern VALUE cParams;
extern VALUE cVADParams;
extern ID id_call; extern ID id_call;
extern VALUE ruby_whisper_normalize_model_path(VALUE model_path); extern VALUE rb_whisper_segment_initialize(VALUE context, int index);
extern VALUE rb_whisper_segment_s_new(VALUE context, int index);
extern const rb_data_type_t ruby_whisper_vad_params_type;
static ID param_names[RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT]; static ID param_names[RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT];
static ID id_language; static ID id_language;
@ -66,19 +63,12 @@ static ID id_new_segment_callback;
static ID id_new_segment_callback_user_data; static ID id_new_segment_callback_user_data;
static ID id_progress_callback; static ID id_progress_callback;
static ID id_progress_callback_user_data; static ID id_progress_callback_user_data;
static ID id_encoder_begin_callback;
static ID id_encoder_begin_callback_user_data;
static ID id_abort_callback; static ID id_abort_callback;
static ID id_abort_callback_user_data; static ID id_abort_callback_user_data;
static ID id_vad;
static ID id_vad_model_path;
static ID id_vad_params;
static void static void
rb_whisper_callbcack_container_mark(ruby_whisper_callback_container *rwc) rb_whisper_callbcack_container_mark(ruby_whisper_callback_container *rwc)
{ {
if (rwc == NULL) return;
rb_gc_mark(rwc->user_data); rb_gc_mark(rwc->user_data);
rb_gc_mark(rwc->callback); rb_gc_mark(rwc->callback);
rb_gc_mark(rwc->callbacks); rb_gc_mark(rwc->callbacks);
@ -110,7 +100,7 @@ static void new_segment_callback(struct whisper_context *ctx, struct whisper_sta
const int n_segments = whisper_full_n_segments_from_state(state); const int n_segments = whisper_full_n_segments_from_state(state);
for (int i = n_new; i > 0; i--) { for (int i = n_new; i > 0; i--) {
int i_segment = n_segments - i; int i_segment = n_segments - i;
VALUE segment = rb_whisper_segment_s_new(*container->context, i_segment); VALUE segment = rb_whisper_segment_initialize(*container->context, i_segment);
for (int j = 0; j < callbacks_len; j++) { for (int j = 0; j < callbacks_len; j++) {
VALUE cb = rb_ary_entry(container->callbacks, j); VALUE cb = rb_ary_entry(container->callbacks, j);
rb_funcall(cb, id_call, 1, segment); rb_funcall(cb, id_call, 1, segment);
@ -136,33 +126,6 @@ static void progress_callback(struct whisper_context *ctx, struct whisper_state
} }
} }
static bool encoder_begin_callback(struct whisper_context *ctx, struct whisper_state *state, void *user_data) {
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
bool is_aborted = false;
VALUE result;
// Currently, doesn't support state because
// those require to resolve GC-related problems.
if (!NIL_P(container->callback)) {
result = rb_funcall(container->callback, id_call, 3, *container->context, Qnil, container->user_data);
if (result == Qfalse) {
is_aborted = true;
}
}
const long callbacks_len = RARRAY_LEN(container->callbacks);
if (0 == callbacks_len) {
return !is_aborted;
}
for (int j = 0; j < callbacks_len; j++) {
VALUE cb = rb_ary_entry(container->callbacks, j);
result = rb_funcall(cb, id_call, 0);
if (result == Qfalse) {
is_aborted = true;
}
}
return !is_aborted;
}
static bool abort_callback(void * user_data) { static bool abort_callback(void * user_data) {
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
if (!NIL_P(container->callback)) { if (!NIL_P(container->callback)) {
@ -185,7 +148,7 @@ static bool abort_callback(void * user_data) {
return false; return false;
} }
static void register_callbacks(ruby_whisper_params * rwp, VALUE * context) { void register_callbacks(ruby_whisper_params * rwp, VALUE * context) {
if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) { if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) {
rwp->new_segment_callback_container->context = context; rwp->new_segment_callback_container->context = context;
rwp->params.new_segment_callback = new_segment_callback; rwp->params.new_segment_callback = new_segment_callback;
@ -198,12 +161,6 @@ static void register_callbacks(ruby_whisper_params * rwp, VALUE * context) {
rwp->params.progress_callback_user_data = rwp->progress_callback_container; rwp->params.progress_callback_user_data = rwp->progress_callback_container;
} }
if (!NIL_P(rwp->encoder_begin_callback_container->callback) || 0 != RARRAY_LEN(rwp->encoder_begin_callback_container->callbacks)) {
rwp->encoder_begin_callback_container->context = context;
rwp->params.encoder_begin_callback = encoder_begin_callback;
rwp->params.encoder_begin_callback_user_data = rwp->encoder_begin_callback_container;
}
if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) { if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) {
rwp->abort_callback_container->context = context; rwp->abort_callback_container->context = context;
rwp->params.abort_callback = abort_callback; rwp->params.abort_callback = abort_callback;
@ -211,29 +168,12 @@ static void register_callbacks(ruby_whisper_params * rwp, VALUE * context) {
} }
} }
static void set_vad_params(ruby_whisper_params *rwp)
{
ruby_whisper_vad_params * rwvp;
TypedData_Get_Struct(rwp->vad_params, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
rwp->params.vad_params = rwvp->params;
}
void void
prepare_transcription(ruby_whisper_params *rwp, VALUE *context) rb_whisper_params_mark(ruby_whisper_params *rwp)
{ {
register_callbacks(rwp, context);
set_vad_params(rwp);
}
void
rb_whisper_params_mark(void *p)
{
ruby_whisper_params *rwp = (ruby_whisper_params *)p;
rb_whisper_callbcack_container_mark(rwp->new_segment_callback_container); rb_whisper_callbcack_container_mark(rwp->new_segment_callback_container);
rb_whisper_callbcack_container_mark(rwp->progress_callback_container); rb_whisper_callbcack_container_mark(rwp->progress_callback_container);
rb_whisper_callbcack_container_mark(rwp->encoder_begin_callback_container);
rb_whisper_callbcack_container_mark(rwp->abort_callback_container); rb_whisper_callbcack_container_mark(rwp->abort_callback_container);
rb_gc_mark(rwp->vad_params);
} }
void void
@ -242,46 +182,24 @@ ruby_whisper_params_free(ruby_whisper_params *rwp)
} }
void void
rb_whisper_params_free(void *p) rb_whisper_params_free(ruby_whisper_params *rwp)
{ {
ruby_whisper_params *rwp = (ruby_whisper_params *)p;
// How to free user_data and callback only when not referred to by others? // How to free user_data and callback only when not referred to by others?
ruby_whisper_params_free(rwp); ruby_whisper_params_free(rwp);
free(rwp); free(rwp);
} }
static size_t
ruby_whisper_params_memsize(const void *p)
{
const ruby_whisper_params *rwp = (const ruby_whisper_params *)p;
return sizeof(ruby_whisper_params) + sizeof(rwp->params) + sizeof(rwp->vad_params);
}
const rb_data_type_t ruby_whisper_params_type = {
"ruby_whisper_params",
{
rb_whisper_params_mark,
rb_whisper_params_free,
ruby_whisper_params_memsize,
},
0, 0,
0
};
static VALUE static VALUE
ruby_whisper_params_allocate(VALUE klass) ruby_whisper_params_allocate(VALUE klass)
{ {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
VALUE obj = TypedData_Make_Struct(klass, ruby_whisper_params, &ruby_whisper_params_type, rwp); rwp = ALLOC(ruby_whisper_params);
rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
rwp->diarize = false; rwp->diarize = false;
rwp->vad_params = TypedData_Wrap_Struct(cVADParams, &ruby_whisper_vad_params_type, (void *)&rwp->params.vad_params);
rwp->new_segment_callback_container = rb_whisper_callback_container_allocate(); rwp->new_segment_callback_container = rb_whisper_callback_container_allocate();
rwp->progress_callback_container = rb_whisper_callback_container_allocate(); rwp->progress_callback_container = rb_whisper_callback_container_allocate();
rwp->encoder_begin_callback_container = rb_whisper_callback_container_allocate();
rwp->abort_callback_container = rb_whisper_callback_container_allocate(); rwp->abort_callback_container = rb_whisper_callback_container_allocate();
return obj; return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp);
} }
/* /*
@ -294,7 +212,7 @@ static VALUE
ruby_whisper_params_set_language(VALUE self, VALUE value) ruby_whisper_params_set_language(VALUE self, VALUE value)
{ {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
if (value == Qfalse || value == Qnil) { if (value == Qfalse || value == Qnil) {
rwp->params.language = "auto"; rwp->params.language = "auto";
} else { } else {
@ -310,7 +228,7 @@ static VALUE
ruby_whisper_params_get_language(VALUE self) ruby_whisper_params_get_language(VALUE self)
{ {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
if (rwp->params.language) { if (rwp->params.language) {
return rb_str_new2(rwp->params.language); return rb_str_new2(rwp->params.language);
} else { } else {
@ -547,7 +465,7 @@ static VALUE
ruby_whisper_params_get_initial_prompt(VALUE self) ruby_whisper_params_get_initial_prompt(VALUE self)
{ {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
return rwp->params.initial_prompt == NULL ? Qnil : rb_str_new2(rwp->params.initial_prompt); return rwp->params.initial_prompt == NULL ? Qnil : rb_str_new2(rwp->params.initial_prompt);
} }
/* /*
@ -558,7 +476,7 @@ static VALUE
ruby_whisper_params_set_initial_prompt(VALUE self, VALUE value) ruby_whisper_params_set_initial_prompt(VALUE self, VALUE value)
{ {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.initial_prompt = StringValueCStr(value); rwp->params.initial_prompt = StringValueCStr(value);
return value; return value;
} }
@ -572,7 +490,7 @@ static VALUE
ruby_whisper_params_get_diarize(VALUE self) ruby_whisper_params_get_diarize(VALUE self)
{ {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
if (rwp->diarize) { if (rwp->diarize) {
return Qtrue; return Qtrue;
} else { } else {
@ -587,7 +505,7 @@ static VALUE
ruby_whisper_params_set_diarize(VALUE self, VALUE value) ruby_whisper_params_set_diarize(VALUE self, VALUE value)
{ {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
if (value == Qfalse || value == Qnil) { if (value == Qfalse || value == Qnil) {
rwp->diarize = false; rwp->diarize = false;
} else { } else {
@ -606,7 +524,7 @@ static VALUE
ruby_whisper_params_get_offset(VALUE self) ruby_whisper_params_get_offset(VALUE self)
{ {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
return INT2NUM(rwp->params.offset_ms); return INT2NUM(rwp->params.offset_ms);
} }
/* /*
@ -617,7 +535,7 @@ static VALUE
ruby_whisper_params_set_offset(VALUE self, VALUE value) ruby_whisper_params_set_offset(VALUE self, VALUE value)
{ {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.offset_ms = NUM2INT(value); rwp->params.offset_ms = NUM2INT(value);
return value; return value;
} }
@ -631,7 +549,7 @@ static VALUE
ruby_whisper_params_get_duration(VALUE self) ruby_whisper_params_get_duration(VALUE self)
{ {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
return INT2NUM(rwp->params.duration_ms); return INT2NUM(rwp->params.duration_ms);
} }
/* /*
@ -642,7 +560,7 @@ static VALUE
ruby_whisper_params_set_duration(VALUE self, VALUE value) ruby_whisper_params_set_duration(VALUE self, VALUE value)
{ {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.duration_ms = NUM2INT(value); rwp->params.duration_ms = NUM2INT(value);
return value; return value;
} }
@ -657,7 +575,7 @@ static VALUE
ruby_whisper_params_get_max_text_tokens(VALUE self) ruby_whisper_params_get_max_text_tokens(VALUE self)
{ {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
return INT2NUM(rwp->params.n_max_text_ctx); return INT2NUM(rwp->params.n_max_text_ctx);
} }
/* /*
@ -668,7 +586,7 @@ static VALUE
ruby_whisper_params_set_max_text_tokens(VALUE self, VALUE value) ruby_whisper_params_set_max_text_tokens(VALUE self, VALUE value)
{ {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.n_max_text_ctx = NUM2INT(value); rwp->params.n_max_text_ctx = NUM2INT(value);
return value; return value;
} }
@ -680,7 +598,7 @@ static VALUE
ruby_whisper_params_get_temperature(VALUE self) ruby_whisper_params_get_temperature(VALUE self)
{ {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
return DBL2NUM(rwp->params.temperature); return DBL2NUM(rwp->params.temperature);
} }
/* /*
@ -691,7 +609,7 @@ static VALUE
ruby_whisper_params_set_temperature(VALUE self, VALUE value) ruby_whisper_params_set_temperature(VALUE self, VALUE value)
{ {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.temperature = RFLOAT_VALUE(value); rwp->params.temperature = RFLOAT_VALUE(value);
return value; return value;
} }
@ -705,7 +623,7 @@ static VALUE
ruby_whisper_params_get_max_initial_ts(VALUE self) ruby_whisper_params_get_max_initial_ts(VALUE self)
{ {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
return DBL2NUM(rwp->params.max_initial_ts); return DBL2NUM(rwp->params.max_initial_ts);
} }
/* /*
@ -716,7 +634,7 @@ static VALUE
ruby_whisper_params_set_max_initial_ts(VALUE self, VALUE value) ruby_whisper_params_set_max_initial_ts(VALUE self, VALUE value)
{ {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.max_initial_ts = RFLOAT_VALUE(value); rwp->params.max_initial_ts = RFLOAT_VALUE(value);
return value; return value;
} }
@ -728,7 +646,7 @@ static VALUE
ruby_whisper_params_get_length_penalty(VALUE self) ruby_whisper_params_get_length_penalty(VALUE self)
{ {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
return DBL2NUM(rwp->params.length_penalty); return DBL2NUM(rwp->params.length_penalty);
} }
/* /*
@ -739,7 +657,7 @@ static VALUE
ruby_whisper_params_set_length_penalty(VALUE self, VALUE value) ruby_whisper_params_set_length_penalty(VALUE self, VALUE value)
{ {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.length_penalty = RFLOAT_VALUE(value); rwp->params.length_penalty = RFLOAT_VALUE(value);
return value; return value;
} }
@ -751,7 +669,7 @@ static VALUE
ruby_whisper_params_get_temperature_inc(VALUE self) ruby_whisper_params_get_temperature_inc(VALUE self)
{ {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
return DBL2NUM(rwp->params.temperature_inc); return DBL2NUM(rwp->params.temperature_inc);
} }
/* /*
@ -762,7 +680,7 @@ static VALUE
ruby_whisper_params_set_temperature_inc(VALUE self, VALUE value) ruby_whisper_params_set_temperature_inc(VALUE self, VALUE value)
{ {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.temperature_inc = RFLOAT_VALUE(value); rwp->params.temperature_inc = RFLOAT_VALUE(value);
return value; return value;
} }
@ -776,7 +694,7 @@ static VALUE
ruby_whisper_params_get_entropy_thold(VALUE self) ruby_whisper_params_get_entropy_thold(VALUE self)
{ {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
return DBL2NUM(rwp->params.entropy_thold); return DBL2NUM(rwp->params.entropy_thold);
} }
/* /*
@ -787,7 +705,7 @@ static VALUE
ruby_whisper_params_set_entropy_thold(VALUE self, VALUE value) ruby_whisper_params_set_entropy_thold(VALUE self, VALUE value)
{ {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.entropy_thold = RFLOAT_VALUE(value); rwp->params.entropy_thold = RFLOAT_VALUE(value);
return value; return value;
} }
@ -799,7 +717,7 @@ static VALUE
ruby_whisper_params_get_logprob_thold(VALUE self) ruby_whisper_params_get_logprob_thold(VALUE self)
{ {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
return DBL2NUM(rwp->params.logprob_thold); return DBL2NUM(rwp->params.logprob_thold);
} }
/* /*
@ -810,7 +728,7 @@ static VALUE
ruby_whisper_params_set_logprob_thold(VALUE self, VALUE value) ruby_whisper_params_set_logprob_thold(VALUE self, VALUE value)
{ {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.logprob_thold = RFLOAT_VALUE(value); rwp->params.logprob_thold = RFLOAT_VALUE(value);
return value; return value;
} }
@ -822,7 +740,7 @@ static VALUE
ruby_whisper_params_get_no_speech_thold(VALUE self) ruby_whisper_params_get_no_speech_thold(VALUE self)
{ {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
return DBL2NUM(rwp->params.no_speech_thold); return DBL2NUM(rwp->params.no_speech_thold);
} }
/* /*
@ -833,7 +751,7 @@ static VALUE
ruby_whisper_params_set_no_speech_thold(VALUE self, VALUE value) ruby_whisper_params_set_no_speech_thold(VALUE self, VALUE value)
{ {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.no_speech_thold = RFLOAT_VALUE(value); rwp->params.no_speech_thold = RFLOAT_VALUE(value);
return value; return value;
} }
@ -841,7 +759,7 @@ static VALUE
ruby_whisper_params_get_new_segment_callback(VALUE self) ruby_whisper_params_get_new_segment_callback(VALUE self)
{ {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
return rwp->new_segment_callback_container->callback; return rwp->new_segment_callback_container->callback;
} }
/* /*
@ -858,7 +776,7 @@ static VALUE
ruby_whisper_params_set_new_segment_callback(VALUE self, VALUE value) ruby_whisper_params_set_new_segment_callback(VALUE self, VALUE value)
{ {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->new_segment_callback_container->callback = value; rwp->new_segment_callback_container->callback = value;
return value; return value;
} }
@ -866,7 +784,7 @@ static VALUE
ruby_whisper_params_get_new_segment_callback_user_data(VALUE self) ruby_whisper_params_get_new_segment_callback_user_data(VALUE self)
{ {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
return rwp->new_segment_callback_container->user_data; return rwp->new_segment_callback_container->user_data;
} }
/* /*
@ -879,7 +797,7 @@ static VALUE
ruby_whisper_params_set_new_segment_callback_user_data(VALUE self, VALUE value) ruby_whisper_params_set_new_segment_callback_user_data(VALUE self, VALUE value)
{ {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->new_segment_callback_container->user_data = value; rwp->new_segment_callback_container->user_data = value;
return value; return value;
} }
@ -887,7 +805,7 @@ static VALUE
ruby_whisper_params_get_progress_callback(VALUE self) ruby_whisper_params_get_progress_callback(VALUE self)
{ {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
return rwp->progress_callback_container->callback; return rwp->progress_callback_container->callback;
} }
/* /*
@ -906,7 +824,7 @@ static VALUE
ruby_whisper_params_set_progress_callback(VALUE self, VALUE value) ruby_whisper_params_set_progress_callback(VALUE self, VALUE value)
{ {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->progress_callback_container->callback = value; rwp->progress_callback_container->callback = value;
return value; return value;
} }
@ -914,7 +832,7 @@ static VALUE
ruby_whisper_params_get_progress_callback_user_data(VALUE self) ruby_whisper_params_get_progress_callback_user_data(VALUE self)
{ {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
return rwp->progress_callback_container->user_data; return rwp->progress_callback_container->user_data;
} }
/* /*
@ -927,66 +845,15 @@ static VALUE
ruby_whisper_params_set_progress_callback_user_data(VALUE self, VALUE value) ruby_whisper_params_set_progress_callback_user_data(VALUE self, VALUE value)
{ {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->progress_callback_container->user_data = value; rwp->progress_callback_container->user_data = value;
return value; return value;
} }
static VALUE
ruby_whisper_params_get_encoder_begin_callback(VALUE self)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
return rwp->encoder_begin_callback_container->callback;
}
/*
* Sets encoder begin callback, called when the encoder starts.
*
* params.encoder_begin_callback = ->(context, _, user_data) {
* # ...
* }
*
* call-seq:
* encoder_begin_callback = callback -> callback
*/
static VALUE
ruby_whisper_params_set_encoder_begin_callback(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
rwp->encoder_begin_callback_container->callback = value;
return value;
}
static VALUE
ruby_whisper_params_get_encoder_begin_callback_user_data(VALUE self)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
return rwp->encoder_begin_callback_container->user_data;
}
/*
* Sets user data passed to the last argument of encoder begin callback.
*
* call-seq:
* encoder_begin_callback_user_data = user_data -> use_data
*/
static VALUE
ruby_whisper_params_set_encoder_begin_callback_user_data(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
rwp->encoder_begin_callback_container->user_data = value;
return value;
}
static VALUE static VALUE
ruby_whisper_params_get_abort_callback(VALUE self) ruby_whisper_params_get_abort_callback(VALUE self)
{ {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
return rwp->abort_callback_container->callback; return rwp->abort_callback_container->callback;
} }
/* /*
@ -1003,7 +870,7 @@ static VALUE
ruby_whisper_params_set_abort_callback(VALUE self, VALUE value) ruby_whisper_params_set_abort_callback(VALUE self, VALUE value)
{ {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->abort_callback_container->callback = value; rwp->abort_callback_container->callback = value;
return value; return value;
} }
@ -1011,7 +878,7 @@ static VALUE
ruby_whisper_params_get_abort_callback_user_data(VALUE self) ruby_whisper_params_get_abort_callback_user_data(VALUE self)
{ {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
return rwp->abort_callback_container->user_data; return rwp->abort_callback_container->user_data;
} }
/* /*
@ -1024,74 +891,11 @@ static VALUE
ruby_whisper_params_set_abort_callback_user_data(VALUE self, VALUE value) ruby_whisper_params_set_abort_callback_user_data(VALUE self, VALUE value)
{ {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->abort_callback_container->user_data = value; rwp->abort_callback_container->user_data = value;
return value; return value;
} }
/*
* call-seq:
* vad = use_vad -> use_vad
*/
static VALUE
ruby_whisper_params_get_vad(VALUE self)
{
BOOL_PARAMS_GETTER(self, vad)
}
static VALUE
ruby_whisper_params_set_vad(VALUE self, VALUE value)
{
BOOL_PARAMS_SETTER(self, vad, value)
}
/*
* call-seq:
* vad_model_path = model_path -> model_path
*/
static VALUE
ruby_whisper_params_set_vad_model_path(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
if (NIL_P(value)) {
rwp->params.vad_model_path = NULL;
return value;
}
VALUE path = ruby_whisper_normalize_model_path(value);
rwp->params.vad_model_path = StringValueCStr(path);
return value;
}
static VALUE
ruby_whisper_params_get_vad_model_path(VALUE self)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
return rwp->params.vad_model_path == NULL ? Qnil : rb_str_new2(rwp->params.vad_model_path);
}
/*
* call-seq:
* vad_params = params -> params
*/
static VALUE
ruby_whisper_params_set_vad_params(VALUE self, VALUE value)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
rwp->vad_params = value;
return value;
}
static VALUE
ruby_whisper_params_get_vad_params(VALUE self)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
return rwp->vad_params;
}
#define SET_PARAM_IF_SAME(param_name) \
if (id == id_ ## param_name) { \
ruby_whisper_params_set_ ## param_name(self, value); \
@ -1101,6 +905,7 @@ ruby_whisper_params_get_vad_params(VALUE self)
static VALUE
ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self)
{
VALUE kw_hash;
VALUE values[RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT] = {Qundef};
VALUE value;
@ -1113,8 +918,8 @@ ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self)
return self;
}

-rb_get_kwargs(kw_hash, param_names, 0, RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT, values);
-TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
+rb_get_kwargs(kw_hash, &param_names, 0, RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT, &values);
+Data_Get_Struct(self, ruby_whisper_params, rwp);
for (i = 0; i < RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT; i++) {
id = param_names[i];
@ -1153,13 +958,8 @@ ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self)
SET_PARAM_IF_SAME(new_segment_callback_user_data)
SET_PARAM_IF_SAME(progress_callback)
SET_PARAM_IF_SAME(progress_callback_user_data)
-SET_PARAM_IF_SAME(encoder_begin_callback)
-SET_PARAM_IF_SAME(encoder_begin_callback_user_data)
SET_PARAM_IF_SAME(abort_callback)
SET_PARAM_IF_SAME(abort_callback_user_data)
-SET_PARAM_IF_SAME(vad)
-SET_PARAM_IF_SAME(vad_model_path)
-SET_PARAM_IF_SAME(vad_params)
}
}
@ -1181,10 +981,10 @@ ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self)
static VALUE
ruby_whisper_params_on_new_segment(VALUE self)
{
-ruby_whisper_params *rwp;
-TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
+ruby_whisper_params *rws;
+Data_Get_Struct(self, ruby_whisper_params, rws);
const VALUE blk = rb_block_proc();
-rb_ary_push(rwp->new_segment_callback_container->callbacks, blk);
+rb_ary_push(rws->new_segment_callback_container->callbacks, blk);
return Qnil;
}
@ -1201,30 +1001,10 @@ ruby_whisper_params_on_new_segment(VALUE self)
static VALUE
ruby_whisper_params_on_progress(VALUE self)
{
-ruby_whisper_params *rwp;
-TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
+ruby_whisper_params *rws;
+Data_Get_Struct(self, ruby_whisper_params, rws);
const VALUE blk = rb_block_proc();
-rb_ary_push(rwp->progress_callback_container->callbacks, blk);
+rb_ary_push(rws->progress_callback_container->callbacks, blk);
return Qnil;
}
/*
* Hook called when the encoder starts.
*
* whisper.on_encoder_begin do
* # ...
* end
*
* call-seq:
* on_encoder_begin { ... }
*/
static VALUE
ruby_whisper_params_on_encoder_begin(VALUE self)
{
ruby_whisper_params *rwp;
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
const VALUE blk = rb_block_proc();
rb_ary_push(rwp->encoder_begin_callback_container->callbacks, blk);
return Qnil;
}
@ -1245,10 +1025,10 @@ ruby_whisper_params_on_encoder_begin(VALUE self)
static VALUE
ruby_whisper_params_abort_on(VALUE self)
{
-ruby_whisper_params *rwp;
-TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
+ruby_whisper_params *rws;
+Data_Get_Struct(self, ruby_whisper_params, rws);
const VALUE blk = rb_block_proc();
-rb_ary_push(rwp->abort_callback_container->callbacks, blk);
+rb_ary_push(rws->abort_callback_container->callbacks, blk);
return Qnil;
}
@ -1288,16 +1068,10 @@ init_ruby_whisper_params(VALUE *mWhisper)
DEFINE_PARAM(new_segment_callback_user_data, 25)
DEFINE_PARAM(progress_callback, 26)
DEFINE_PARAM(progress_callback_user_data, 27)
-DEFINE_PARAM(encoder_begin_callback, 28)
-DEFINE_PARAM(encoder_begin_callback_user_data, 29)
-DEFINE_PARAM(abort_callback, 30)
-DEFINE_PARAM(abort_callback_user_data, 31)
-DEFINE_PARAM(vad, 32)
-DEFINE_PARAM(vad_model_path, 33)
-DEFINE_PARAM(vad_params, 34)
+DEFINE_PARAM(abort_callback, 28)
+DEFINE_PARAM(abort_callback_user_data, 29)
rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0); rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0);
rb_define_method(cParams, "on_progress", ruby_whisper_params_on_progress, 0); rb_define_method(cParams, "on_progress", ruby_whisper_params_on_progress, 0);
rb_define_method(cParams, "on_encoder_begin", ruby_whisper_params_on_encoder_begin, 0);
rb_define_method(cParams, "abort_on", ruby_whisper_params_abort_on, 0); rb_define_method(cParams, "abort_on", ruby_whisper_params_abort_on, 0);
} }

View File

@ -1,57 +1,28 @@
#include <ruby.h>
#include "ruby_whisper.h"
#define N_KEY_NAMES 5
static VALUE sym_start_time;
static VALUE sym_end_time;
static VALUE sym_text;
static VALUE sym_no_speech_prob;
static VALUE sym_speaker_turn_next;
static VALUE key_names;
extern const rb_data_type_t ruby_whisper_type;
extern VALUE cSegment;

static void
-rb_whisper_segment_mark(void *p)
+rb_whisper_segment_mark(ruby_whisper_segment *rws)
{
-ruby_whisper_segment *rws = (ruby_whisper_segment *)p;
rb_gc_mark(rws->context);
}
static size_t
ruby_whisper_segment_memsize(const void *p)
{
const ruby_whisper_segment *rws = (const ruby_whisper_segment *)p;
size_t size = sizeof(rws);
if (!rws) {
return 0;
}
return size;
}
static const rb_data_type_t ruby_whisper_segment_type = {
"ruby_whisper_segment",
{rb_whisper_segment_mark, RUBY_DEFAULT_FREE, ruby_whisper_segment_memsize,},
0, 0,
0
};
VALUE
ruby_whisper_segment_allocate(VALUE klass)
{
ruby_whisper_segment *rws;
-return TypedData_Make_Struct(klass, ruby_whisper_segment, &ruby_whisper_segment_type, rws);
+rws = ALLOC(ruby_whisper_segment);
+return Data_Wrap_Struct(klass, rb_whisper_segment_mark, RUBY_DEFAULT_FREE, rws);
}
VALUE
-rb_whisper_segment_s_new(VALUE context, int index)
+rb_whisper_segment_initialize(VALUE context, int index)
{
ruby_whisper_segment *rws;
const VALUE segment = ruby_whisper_segment_allocate(cSegment);
-TypedData_Get_Struct(segment, ruby_whisper_segment, &ruby_whisper_segment_type, rws);
+Data_Get_Struct(segment, ruby_whisper_segment, rws);
rws->context = context;
rws->index = index;
return segment;
@ -67,12 +38,12 @@ static VALUE
ruby_whisper_segment_get_start_time(VALUE self)
{
ruby_whisper_segment *rws;
-TypedData_Get_Struct(self, ruby_whisper_segment, &ruby_whisper_segment_type, rws);
+Data_Get_Struct(self, ruby_whisper_segment, rws);
ruby_whisper *rw;
-TypedData_Get_Struct(rws->context, ruby_whisper, &ruby_whisper_type, rw);
+Data_Get_Struct(rws->context, ruby_whisper, rw);
const int64_t t0 = whisper_full_get_segment_t0(rw->context, rws->index);
// able to multiply 10 without overflow because to_timestamp() in whisper.cpp does it
-return LONG2NUM(t0 * 10);
+return INT2NUM(t0 * 10);
}

/*
@ -85,12 +56,12 @@ static VALUE
ruby_whisper_segment_get_end_time(VALUE self)
{
ruby_whisper_segment *rws;
-TypedData_Get_Struct(self, ruby_whisper_segment, &ruby_whisper_segment_type, rws);
+Data_Get_Struct(self, ruby_whisper_segment, rws);
ruby_whisper *rw;
-TypedData_Get_Struct(rws->context, ruby_whisper, &ruby_whisper_type, rw);
+Data_Get_Struct(rws->context, ruby_whisper, rw);
const int64_t t1 = whisper_full_get_segment_t1(rw->context, rws->index);
// able to multiply 10 without overflow because to_timestamp() in whisper.cpp does it
-return LONG2NUM(t1 * 10);
+return INT2NUM(t1 * 10);
}

/*
@ -103,9 +74,9 @@ static VALUE
ruby_whisper_segment_get_speaker_turn_next(VALUE self)
{
ruby_whisper_segment *rws;
-TypedData_Get_Struct(self, ruby_whisper_segment, &ruby_whisper_segment_type, rws);
+Data_Get_Struct(self, ruby_whisper_segment, rws);
ruby_whisper *rw;
-TypedData_Get_Struct(rws->context, ruby_whisper, &ruby_whisper_type, rw);
+Data_Get_Struct(rws->context, ruby_whisper, rw);
return whisper_full_get_segment_speaker_turn_next(rw->context, rws->index) ? Qtrue : Qfalse;
}
@ -117,9 +88,9 @@ static VALUE
ruby_whisper_segment_get_text(VALUE self)
{
ruby_whisper_segment *rws;
-TypedData_Get_Struct(self, ruby_whisper_segment, &ruby_whisper_segment_type, rws);
+Data_Get_Struct(self, ruby_whisper_segment, rws);
ruby_whisper *rw;
-TypedData_Get_Struct(rws->context, ruby_whisper, &ruby_whisper_type, rw);
+Data_Get_Struct(rws->context, ruby_whisper, rw);
const char * text = whisper_full_get_segment_text(rw->context, rws->index);
return rb_str_new2(text);
}
@ -132,89 +103,21 @@ static VALUE
ruby_whisper_segment_get_no_speech_prob(VALUE self)
{
ruby_whisper_segment *rws;
-TypedData_Get_Struct(self, ruby_whisper_segment, &ruby_whisper_segment_type, rws);
+Data_Get_Struct(self, ruby_whisper_segment, rws);
ruby_whisper *rw;
-TypedData_Get_Struct(rws->context, ruby_whisper, &ruby_whisper_type, rw);
+Data_Get_Struct(rws->context, ruby_whisper, rw);
return DBL2NUM(whisper_full_get_segment_no_speech_prob(rw->context, rws->index));
}
/*
* call-seq:
* deconstruct_keys(keys) -> hash
*
* Possible keys: :start_time, :end_time, :text, :no_speech_prob, :speaker_turn_next
*
* whisper.each_segment do |segment|
* segment => {start_time:, end_time:, text:, no_speech_prob:, speaker_turn_next:}
*
* puts "[#{start_time} --> #{end_time}] #{text} (no speech prob: #{no_speech_prob}#{speaker_turn_next ? ', speaker turns next' : ''})"
* end
*/
static VALUE
ruby_whisper_segment_deconstruct_keys(VALUE self, VALUE keys)
{
ruby_whisper_segment *rws;
TypedData_Get_Struct(self, ruby_whisper_segment, &ruby_whisper_segment_type, rws);
ruby_whisper *rw;
TypedData_Get_Struct(rws->context, ruby_whisper, &ruby_whisper_type, rw);
VALUE hash = rb_hash_new();
long n_keys;
if (NIL_P(keys)) {
keys = key_names;
n_keys = N_KEY_NAMES;
} else {
n_keys = RARRAY_LEN(keys);
if (n_keys > N_KEY_NAMES) {
return hash;
}
}
for (int i = 0; i < n_keys; i++) {
VALUE key = rb_ary_entry(keys, i);
if (key == sym_start_time) {
rb_hash_aset(hash, key, ruby_whisper_segment_get_start_time(self));
}
if (key == sym_end_time) {
rb_hash_aset(hash, key, ruby_whisper_segment_get_end_time(self));
}
if (key == sym_text) {
rb_hash_aset(hash, key, ruby_whisper_segment_get_text(self));
}
if (key == sym_no_speech_prob) {
rb_hash_aset(hash, key, ruby_whisper_segment_get_no_speech_prob(self));
}
if (key == sym_speaker_turn_next) {
rb_hash_aset(hash, key, ruby_whisper_segment_get_speaker_turn_next(self));
}
}
return hash;
}
void
init_ruby_whisper_segment(VALUE *mWhisper, VALUE *cContext)
{
cSegment = rb_define_class_under(*mWhisper, "Segment", rb_cObject);
sym_start_time = ID2SYM(rb_intern("start_time"));
sym_end_time = ID2SYM(rb_intern("end_time"));
sym_text = ID2SYM(rb_intern("text"));
sym_no_speech_prob = ID2SYM(rb_intern("no_speech_prob"));
sym_speaker_turn_next = ID2SYM(rb_intern("speaker_turn_next"));
key_names = rb_ary_new3(
N_KEY_NAMES,
sym_start_time,
sym_end_time,
sym_text,
sym_no_speech_prob,
sym_speaker_turn_next
);
rb_define_alloc_func(cSegment, ruby_whisper_segment_allocate);
rb_define_method(cSegment, "start_time", ruby_whisper_segment_get_start_time, 0);
rb_define_method(cSegment, "end_time", ruby_whisper_segment_get_end_time, 0);
-rb_define_method(cSegment, "speaker_turn_next?", ruby_whisper_segment_get_speaker_turn_next, 0);
+rb_define_method(cSegment, "speaker_next_turn?", ruby_whisper_segment_get_speaker_turn_next, 0);
rb_define_method(cSegment, "text", ruby_whisper_segment_get_text, 0);
rb_define_method(cSegment, "no_speech_prob", ruby_whisper_segment_get_no_speech_prob, 0);
-rb_define_method(cSegment, "deconstruct_keys", ruby_whisper_segment_deconstruct_keys, 1);
}

View File

@ -8,15 +8,11 @@
extern "C" { extern "C" {
#endif #endif
extern const rb_data_type_t ruby_whisper_type;
extern const rb_data_type_t ruby_whisper_params_type;
extern ID id_to_s;
extern ID id_call;
extern ID transcribe_option_names[1];
extern void
-prepare_transcription(ruby_whisper_params * rwp, VALUE * self);
+register_callbacks(ruby_whisper_params * rwp, VALUE * self);
/*
 * transcribe a single file
@ -35,16 +31,11 @@ VALUE
ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
ruby_whisper *rw;
ruby_whisper_params *rwp;
-VALUE wave_file_path, blk, params, kws;
-VALUE opts[1];
-rb_scan_args_kw(RB_SCAN_ARGS_LAST_HASH_KEYWORDS, argc, argv, "2:&", &wave_file_path, &params, &kws, &blk);
-rb_get_kwargs(kws, transcribe_option_names, 0, 1, opts);
-int n_processors = opts[0] == Qundef ? 1 : NUM2INT(opts[0]);
-TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
-TypedData_Get_Struct(params, ruby_whisper_params, &ruby_whisper_params_type, rwp);
+VALUE wave_file_path, blk, params;
+rb_scan_args(argc, argv, "02&", &wave_file_path, &params, &blk);
+Data_Get_Struct(self, ruby_whisper, rw);
+Data_Get_Struct(params, ruby_whisper_params, rwp);
if (!rb_respond_to(wave_file_path, id_to_s)) {
rb_raise(rb_eRuntimeError, "Expected file path to wave file");
@ -59,24 +50,20 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str()); fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
return self; return self;
} }
-// Commented out because it is work in progress
-// {
-//     static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
-//     rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
-//         bool is_aborted = *(bool*)user_data;
-//         return !is_aborted;
-//     };
-//     rwp->params.encoder_begin_callback_user_data = &is_aborted;
-// }
-prepare_transcription(rwp, &self);
-if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), n_processors) != 0) {
-    fprintf(stderr, "failed to process audio\n");
-    return self;
-}
-if (NIL_P(blk)) {
+{
+    static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
+    rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
+        bool is_aborted = *(bool*)user_data;
+        return !is_aborted;
+    };
+    rwp->params.encoder_begin_callback_user_data = &is_aborted;
+}
+register_callbacks(rwp, &self);
+if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) {
+    fprintf(stderr, "failed to process audio\n");
return self;
}
const int n_segments = whisper_full_n_segments(rw->context);
@ -85,7 +72,10 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
const char * text = whisper_full_get_segment_text(rw->context, i);
output = rb_str_concat(output, rb_str_new2(text));
}
-rb_funcall(blk, id_call, 1, output);
+VALUE idCall = id_call;
+if (blk != Qnil) {
+rb_funcall(blk, idCall, 1, output);
+}
return self;
}
#ifdef __cplusplus
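A usage sketch of the two transcribe entry points diffed above; the n_processors keyword exists only in the variant parsed with rb_scan_args_kw, and the audio path is illustrative:

    whisper = Whisper::Context.new("base.en")
    params  = Whisper::Params.new
    whisper.transcribe("samples/jfk.wav", params, n_processors: 4) do |text|
      puts text  # full transcription, concatenated from all segments
    end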

View File

@ -1,288 +0,0 @@
#include <ruby.h>
#include "ruby_whisper.h"
#define DEFINE_PARAM(param_name, nth) \
id_ ## param_name = rb_intern(#param_name); \
param_names[nth] = id_ ## param_name; \
rb_define_method(cVADParams, #param_name, ruby_whisper_vad_params_get_ ## param_name, 0); \
rb_define_method(cVADParams, #param_name "=", ruby_whisper_vad_params_set_ ## param_name, 1);
#define NUM_PARAMS 6
extern VALUE cVADParams;
static size_t
ruby_whisper_vad_params_memsize(const void *p)
{
const struct ruby_whisper_vad_params *params = p;
size_t size = sizeof(params);
if (!params) {
return 0;
}
return size;
}
static ID param_names[NUM_PARAMS];
static ID id_threshold;
static ID id_min_speech_duration_ms;
static ID id_min_silence_duration_ms;
static ID id_max_speech_duration_s;
static ID id_speech_pad_ms;
static ID id_samples_overlap;
const rb_data_type_t ruby_whisper_vad_params_type = {
"ruby_whisper_vad_params",
{0, 0, ruby_whisper_vad_params_memsize,},
0, 0,
0
};
static VALUE
ruby_whisper_vad_params_s_allocate(VALUE klass)
{
ruby_whisper_vad_params *rwvp;
VALUE obj = TypedData_Make_Struct(klass, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
rwvp->params = whisper_vad_default_params();
return obj;
}
/*
* Probability threshold to consider as speech.
*
* call-seq:
* threshold = th -> th
*/
static VALUE
ruby_whisper_vad_params_set_threshold(VALUE self, VALUE value)
{
ruby_whisper_vad_params *rwvp;
TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
rwvp->params.threshold = RFLOAT_VALUE(value);
return value;
}
static VALUE
ruby_whisper_vad_params_get_threshold(VALUE self)
{
ruby_whisper_vad_params *rwvp;
TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
return DBL2NUM(rwvp->params.threshold);
}
/*
* Min duration for a valid speech segment.
*
* call-seq:
* min_speech_duration_ms = duration_ms -> duration_ms
*/
static VALUE
ruby_whisper_vad_params_set_min_speech_duration_ms(VALUE self, VALUE value)
{
ruby_whisper_vad_params *rwvp;
TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
rwvp->params.min_speech_duration_ms = NUM2INT(value);
return value;
}
static VALUE
ruby_whisper_vad_params_get_min_speech_duration_ms(VALUE self)
{
ruby_whisper_vad_params *rwvp;
TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
return INT2NUM(rwvp->params.min_speech_duration_ms);
}
/*
* Min silence duration to consider speech as ended.
*
* call-seq:
* min_silence_duration_ms = duration_ms -> duration_ms
*/
static VALUE
ruby_whisper_vad_params_set_min_silence_duration_ms(VALUE self, VALUE value)
{
ruby_whisper_vad_params *rwvp;
TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
rwvp->params.min_silence_duration_ms = NUM2INT(value);
return value;
}
static VALUE
ruby_whisper_vad_params_get_min_silence_duration_ms(VALUE self)
{
ruby_whisper_vad_params *rwvp;
TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
return INT2NUM(rwvp->params.min_silence_duration_ms);
}
/*
* Max duration of a speech segment before forcing a new segment.
*
* call-seq:
* max_speech_duration_s = duration_s -> duration_s
*/
static VALUE
ruby_whisper_vad_params_set_max_speech_duration_s(VALUE self, VALUE value)
{
ruby_whisper_vad_params *rwvp;
TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
rwvp->params.max_speech_duration_s = RFLOAT_VALUE(value);
return value;
}
static VALUE
ruby_whisper_vad_params_get_max_speech_duration_s(VALUE self)
{
ruby_whisper_vad_params *rwvp;
TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
return DBL2NUM(rwvp->params.max_speech_duration_s);
}
/*
* Padding added before and after speech segments.
*
* call-seq:
* speech_pad_ms = pad_ms -> pad_ms
*/
static VALUE
ruby_whisper_vad_params_set_speech_pad_ms(VALUE self, VALUE value)
{
ruby_whisper_vad_params *rwvp;
TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
rwvp->params.speech_pad_ms = NUM2INT(value);
return value;
}
static VALUE
ruby_whisper_vad_params_get_speech_pad_ms(VALUE self)
{
ruby_whisper_vad_params *rwvp;
TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
return INT2NUM(rwvp->params.speech_pad_ms);
}
/*
* Overlap in seconds when copying audio samples from speech segment.
*
* call-seq:
* samples_overlap = overlap -> overlap
*/
static VALUE
ruby_whisper_vad_params_set_samples_overlap(VALUE self, VALUE value)
{
ruby_whisper_vad_params *rwvp;
TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
rwvp->params.samples_overlap = RFLOAT_VALUE(value);
return value;
}
static VALUE
ruby_whisper_vad_params_get_samples_overlap(VALUE self)
{
ruby_whisper_vad_params *rwvp;
TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
return DBL2NUM(rwvp->params.samples_overlap);
}
static VALUE
ruby_whisper_vad_params_equal(VALUE self, VALUE other)
{
ruby_whisper_vad_params *rwvp1;
ruby_whisper_vad_params *rwvp2;
if (self == other) {
return Qtrue;
}
if (!rb_obj_is_kind_of(other, cVADParams)) {
return Qfalse;
}
TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp1);
TypedData_Get_Struct(other, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp2);
if (rwvp1->params.threshold != rwvp2->params.threshold) {
return Qfalse;
}
if (rwvp1->params.min_speech_duration_ms != rwvp2->params.min_speech_duration_ms) {
return Qfalse;
}
if (rwvp1->params.min_silence_duration_ms != rwvp2->params.min_silence_duration_ms) {
return Qfalse;
}
if (rwvp1->params.max_speech_duration_s != rwvp2->params.max_speech_duration_s) {
return Qfalse;
}
if (rwvp1->params.speech_pad_ms != rwvp2->params.speech_pad_ms) {
return Qfalse;
}
if (rwvp1->params.samples_overlap != rwvp2->params.samples_overlap) {
return Qfalse;
}
return Qtrue;
}
#define SET_PARAM_IF_SAME(param_name) \
if (id == id_ ## param_name) { \
ruby_whisper_vad_params_set_ ## param_name(self, value); \
continue; \
}
VALUE
ruby_whisper_vad_params_initialize(int argc, VALUE *argv, VALUE self)
{
VALUE kw_hash;
VALUE values[NUM_PARAMS] = {Qundef};
VALUE value;
ruby_whisper_vad_params *rwvp;
ID id;
int i;
TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
rb_scan_args_kw(RB_SCAN_ARGS_KEYWORDS, argc, argv, ":", &kw_hash);
if (NIL_P(kw_hash)) {
return self;
}
rb_get_kwargs(kw_hash, param_names, 0, NUM_PARAMS, values);
for (i = 0; i < NUM_PARAMS; i++) {
id = param_names[i];
value = values[i];
if (value == Qundef) {
continue;
}
SET_PARAM_IF_SAME(threshold)
SET_PARAM_IF_SAME(min_speech_duration_ms)
SET_PARAM_IF_SAME(min_silence_duration_ms)
SET_PARAM_IF_SAME(max_speech_duration_s)
SET_PARAM_IF_SAME(speech_pad_ms)
SET_PARAM_IF_SAME(samples_overlap)
}
return self;
}
#undef SET_PARAM_IF_SAME
void
init_ruby_whisper_vad_params(VALUE *mVAD)
{
cVADParams = rb_define_class_under(*mVAD, "Params", rb_cObject);
rb_define_alloc_func(cVADParams, ruby_whisper_vad_params_s_allocate);
rb_define_method(cVADParams, "initialize", ruby_whisper_vad_params_initialize, -1);
DEFINE_PARAM(threshold, 0)
DEFINE_PARAM(min_speech_duration_ms, 1)
DEFINE_PARAM(min_silence_duration_ms, 2)
DEFINE_PARAM(max_speech_duration_s, 3)
DEFINE_PARAM(speech_pad_ms, 4)
DEFINE_PARAM(samples_overlap, 5)
rb_define_method(cVADParams, "==", ruby_whisper_vad_params_equal, 1);
}
#undef DEFINE_PARAM
#undef NUM_PARAMS
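A sketch of the keyword initializer and the field-wise == defined above (default values are taken from the test expectations later in this changeset):

    a = Whisper::VAD::Params.new(threshold: 0.5, speech_pad_ms: 30)
    b = Whisper::VAD::Params.new  # defaults: threshold 0.5, speech_pad_ms 30
    a == b                        # => true; == compares all six fields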

View File

@ -1,8 +0,0 @@
set(GRAPHVIZ_EXECUTABLES FALSE)
set(GRAPHVIZ_STATIC_LIBS TRUE)
set(GRAPHVIZ_SHARED_LIBS FALSE)
set(GRAPHVIZ_MODULE_LIBS FALSE)
set(GRAPHVIZ_INTERFACE_LIBS FALSE)
set(GRAPHVIZ_OBJECT_LIBS FALSE)
set(GRAPHVIZ_UNKNOWN_LIBS FALSE)
set(GRAPHVIZ_GENERATE_DEPENDERS FALSE)

View File

@ -1,40 +1,6 @@
require "pathname" require "yaml"
root = Pathname("..")/".." sources = `git ls-files -z ../..`.split("\x0")
ignored_dirs = %w[ paths = YAML.load_file("../../.github/workflows/bindings-ruby.yml")[true]["push"]["paths"]
.devops paths.delete "bindings/ruby/**"
.github EXTSOURCES = (Dir.glob(paths, base: "../..").collect {|path| "../../#{path}"} << "../../LICENSE") & sources
ci
examples/wchess/wchess.wasm
examples/whisper.android
examples/whisper.android.java
examples/whisper.objc
examples/whisper.swiftui
grammars
models
samples
scripts
].collect {|dir| root/dir}
ignored_files = %w[
AUTHORS
Makefile
README.md
README_sycl.md
.gitignore
.gitmodules
.dockerignore
whisper.nvim
twitch.sh
yt-wsp.sh
close-issue.yml
]
EXTSOURCES =
`git ls-files -z #{root}`.split("\x0")
.collect {|file| Pathname(file)}
.reject {|file|
ignored_dirs.any? {|dir| file.descend.any? {|desc| desc == dir}} ||
ignored_files.include?(file.basename.to_path) ||
(file.descend.to_a[1] != root && file.descend.to_a[1] != Pathname("..")/"javascript")
}
.collect(&:to_path)

View File

@ -1,15 +0,0 @@
module Whisper
class Context
def to_srt
each_segment.with_index.reduce("") {|srt, (segment, index)|
srt << "#{index + 1}\n#{segment.to_srt_cue}\n"
}
end
def to_webvtt
each_segment.with_index.reduce("WEBVTT\n\n") {|webvtt, (segment, index)|
webvtt << "#{index + 1}\n#{segment.to_webvtt_cue}\n"
}
end
end
end
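A sketch of how these convenience methods are meant to be called after a transcription (file names are illustrative):

    whisper.transcribe("speech.wav", Whisper::Params.new)
    File.write("speech.srt", whisper.to_srt)     # numbered SRT cues
    File.write("speech.vtt", whisper.to_webvtt)  # "WEBVTT" header plus cues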

View File

@ -34,7 +34,7 @@ module Whisper
when /darwin/
Pathname(Dir.home)/"Library/Caches"
else
-ENV.key?("XDG_CACHE_HOME") ? Pathname(ENV["XDG_CACHE_HOME"]) : Pathname(Dir.home)/".cache"
+ENV.key?("XDG_CACHE_HOME") ? ENV["XDG_CACHE_HOME"] : Pathname(Dir.home)/".cache"
end
base/"whisper.cpp"
end
@ -55,8 +55,6 @@ module Whisper
when Net::HTTPNotModified
# noop
when Net::HTTPOK
-return if !response.key?("last-modified") && cache_path.exist?
download response
when Net::HTTPRedirection
request URI(response["location"]), headers
@ -130,44 +128,6 @@ module Whisper
end
end
class ZipURI < URI
def cache
zip_path = super
dest = unzipped_path
return if dest.exist? && dest.mtime >= zip_path.mtime
escaping dest do
system "unzip", "-q", "-d", zip_path.dirname.to_path, zip_path.to_path, exception: true
end
zip_path
end
def clear_cache
super
unzipped_path.rmtree if unzipped_path.exist?
end
private
def unzipped_path
cache_path.sub_ext("")
end
def escaping(path)
escaped = Pathname("#{path}.removing")
if path.exist?
escaped.rmtree if escaped.exist?
path.rename escaped
end
yield
ensure
if path.exist?
escaped.rmtree if escaped.exist?
else
escaped.rename path if escaped.exist?
end
end
end
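A hedged sketch of the Core ML flow built on ZipURI#cache above (it mirrors test_coreml_model_auto_download later in this changeset):

    zip = Whisper::Model.coreml_compiled_models[Whisper::Model.pre_converted_models["tiny"]]
    zip.cache        # fetches ggml-tiny-encoder.mlmodelc.zip, then unzips beside it
    zip.clear_cache  # removes both the archive and the unzipped .mlmodelc directory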
@pre_converted_models = %w[
tiny
tiny.en
@ -203,31 +163,8 @@ module Whisper
models[name] = URI.new("https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-#{name}.bin")
}
%w[
silero-v5.1.2
].each do |name|
@pre_converted_models[name] = URI.new("https://huggingface.co/ggml-org/whisper-vad/resolve/main/ggml-#{name}.bin")
end
@coreml_compiled_models = %w[
tiny
tiny.en
base
base.en
small
small.en
medium
medium.en
large-v1
large-v2
large-v3
large-v3-turbo
].each_with_object({}) do |name, models|
models[@pre_converted_models[name]] = ZipURI.new("https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-#{name}-encoder.mlmodelc.zip")
end
class << self
-attr_reader :pre_converted_models, :coreml_compiled_models
+attr_reader :pre_converted_models
end
end
end
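A sketch of the lookup these readers expose (model name illustrative; to_path appears to trigger the download-and-cache logic above on first use):

    uri = Whisper::Model.pre_converted_models["base.en"]
    whisper = Whisper::Context.new(uri.to_path)  # cached under <cache_dir>/whisper.cpp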

View File

@ -1,58 +0,0 @@
module Whisper
class Segment
SRT_ESCAPES = {
"&" => "&amp;",
"<" => "&lt;",
">" => "&gt;",
}
SRT_ESCAPES_RE = Regexp.union(SRT_ESCAPES.keys)
private_constant :SRT_ESCAPES, :SRT_ESCAPES_RE
def to_srt_cue
"#{srt_start_time} --> #{srt_end_time}\n#{srt_text}\n"
end
def to_webvtt_cue
"#{webvtt_start_time} --> #{webvtt_end_time}\n#{webvtt_text}\n"
end
private
def time_to_a(time)
sec, decimal_part = time.divmod(1000)
min, sec = sec.divmod(60)
hour, min = min.divmod(60)
[hour, min, sec, decimal_part]
end
def srt_time(time)
"%02d:%02d:%02d,%03d" % time_to_a(time)
end
def srt_start_time
srt_time(start_time)
end
def srt_end_time
srt_time(end_time)
end
def srt_text
text.gsub(SRT_ESCAPES_RE, SRT_ESCAPES)
end
def webvtt_time(time)
"%02d:%02d:%02d.%03d" % time_to_a(time)
end
def webvtt_start_time
webvtt_time(start_time)
end
def webvtt_end_time
webvtt_time(end_time)
end
alias webvtt_text srt_text
end
end
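To make time_to_a concrete, a worked sketch (the input value is invented for illustration):

    time = 3_723_456                              # milliseconds
    sec, decimal_part = time.divmod(1000)         # => [3723, 456]
    min, sec = sec.divmod(60)                     # => [62, 3]
    hour, min = min.divmod(60)                    # => [1, 2]
    "%02d:%02d:%02d,%03d" % [hour, min, sec, decimal_part]  # => "01:02:03,456"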

View File

@ -7,10 +7,8 @@ module Whisper
type log_callback = ^(Integer level, String message, Object user_data) -> void
type new_segment_callback = ^(Whisper::Context, void, Integer n_new, Object user_data) -> void
type progress_callback = ^(Whisper::Context, void, Integer progress, Object user_data) -> void
-type encoder_begin_callback = ^(Whisper::Context, void, Object user_data) -> void
type abort_callback = ^(Whisper::Context, void, Object user_data) -> boolish
VERSION: String
LOG_LEVEL_NONE: Integer
LOG_LEVEL_INFO: Integer
LOG_LEVEL_WARN: Integer
@ -23,23 +21,11 @@ module Whisper
def self.lang_str: (Integer id) -> String
def self.lang_str_full: (Integer id) -> String
def self.log_set: (log_callback, Object? user_data) -> log_callback
def self.system_info_str: () -> String
class Context
-def self.new: (String | path | ::URI::HTTP) -> instance
+def self.new: (string | _ToPath | ::URI::HTTP) -> instance

-# transcribe a single file
-# can emit to a block results
-#
-#   params = Whisper::Params.new
-#   params.duration = 60_000
-#   whisper.transcribe "path/to/audio.wav", params do |text|
-#     puts text
-#   end
-#
-def transcribe: (string, Params, ?n_processors: Integer) -> self
-| (string, Params, ?n_processors: Integer) { (String) -> void } -> self
+def transcribe: (string, Params) -> self
+| (string, Params) { (String) -> void } -> self
def model_n_vocab: () -> Integer
def model_n_audio_ctx: () -> Integer
def model_n_audio_state: () -> Integer
@ -48,78 +34,22 @@ module Whisper
def model_n_mels: () -> Integer
def model_ftype: () -> Integer
def model_type: () -> String
# Yields each Whisper::Segment:
#
# whisper.transcribe("path/to/audio.wav", params)
# whisper.each_segment do |segment|
# puts segment.text
# end
#
# Returns an Enumerator if no block given:
#
# whisper.transcribe("path/to/audio.wav", params)
# enum = whisper.each_segment
# enum.to_a # => [#<Whisper::Segment>, ...]
#
def each_segment: { (Segment) -> void } -> void
| () -> Enumerator[Segment]
def model: () -> Model
def full_get_segment: (Integer nth) -> Segment
def full_n_segments: () -> Integer
# Language ID, which can be converted to string by Whisper.lang_str and Whisper.lang_str_full.
#
def full_lang_id: () -> Integer
# Start time of a segment indexed by +segment_index+ in centiseconds (1 centisecond = 10 milliseconds).
#
# full_get_segment_t0(3) # => 1668 (16680 ms)
#
def full_get_segment_t0: (Integer) -> Integer
# End time of a segment indexed by +segment_index+ in centiseconds (1 centisecond = 10 milliseconds).
#
# full_get_segment_t1(3) # => 1668 (16680 ms)
#
def full_get_segment_t1: (Integer) -> Integer
# Whether the next segment indexed by +segment_index+ is predicted as a speaker turn.
#
#   full_get_segment_speaker_turn_next(3) # => true
#
def full_get_segment_speaker_turn_next: (Integer) -> (true | false)
# Text of a segment indexed by +segment_index+.
#
# full_get_segment_text(3) # => "ask not what your country can do for you, ..."
#
def full_get_segment_text: (Integer) -> String
def full_get_segment_no_speech_prob: (Integer) -> Float
# Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
# Not thread safe for same context
# Uses the specified decoding strategy to obtain the text.
#
# The second argument +samples+ must be an array of samples, respond to :length, or be a MemoryView of an array of float. It must be 32 bit float PCM audio data.
#
def full: (Params, Array[Float] samples, ?Integer n_samples) -> self
| (Params, _Samples, ?Integer n_samples) -> self
# Split the input audio in chunks and process each chunk separately using whisper_full_with_state()
# Result is stored in the default state of the context
# Not thread safe if executed in parallel on the same context.
# It seems this approach can offer some speedup in some cases.
# However, the transcription accuracy can be worse at the beginning and end of each chunk.
#
def full_parallel: (Params, Array[Float], ?Integer n_samples) -> self
| (Params, _Samples, ?Integer n_samples) -> self
| (Params, _Samples, ?Integer? n_samples, Integer n_processors) -> self
def to_srt: () -> String
def to_webvtt: () -> String
end

class Params
@ -152,246 +82,76 @@ module Whisper
?new_segment_callback_user_data: Object,
?progress_callback: progress_callback,
?progress_callback_user_data: Object,
-?encoder_begin_callback: encoder_begin_callback,
-?encoder_begin_callback_user_data: Object,
?abort_callback: abort_callback,
-?abort_callback_user_data: Object,
-?vad: boolish,
-?vad_model_path: path | URI,
-?vad_params: Whisper::VAD::Params
+?abort_callback_user_data: Object
) -> instance
# params.language = "auto" | "en", etc...
#
def language=: (String) -> String # TODO: Enumerate lang names
def language: () -> String
def translate=: (boolish) -> boolish
def translate: () -> (true | false)
def no_context=: (boolish) -> boolish
# If true, does not use past transcription (if any) as initial prompt for the decoder.
#
def no_context: () -> (true | false)
def single_segment=: (boolish) -> boolish
# If true, forces single segment output (useful for streaming).
#
def single_segment: () -> (true | false)
def print_special=: (boolish) -> boolish
# If true, prints special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.).
#
def print_special: () -> (true | false)
def print_progress=: (boolish) -> boolish
# If true, prints progress information.
#
def print_progress: () -> (true | false)
def print_realtime=: (boolish) -> boolish
# If true, prints results from within whisper.cpp. (avoid it, use callback instead)
#
def print_realtime: () -> (true | false)
# If true, prints timestamps for each text segment when printing realtime.
#
def print_timestamps=: (boolish) -> boolish
def print_timestamps: () -> (true | false)
def suppress_blank=: (boolish) -> boolish
# If true, suppresses blank outputs.
#
def suppress_blank: () -> (true | false)
def suppress_nst=: (boolish) -> boolish
# If true, suppresses non-speech-tokens.
#
def suppress_nst: () -> (true | false)
def token_timestamps=: (boolish) -> boolish
# If true, enables token-level timestamps.
#
def token_timestamps: () -> (true | false)
def split_on_word=: (boolish) -> boolish
# If true, split on word rather than on token (when used with max_len).
#
def split_on_word: () -> (true | false)
def initial_prompt=: (_ToS) -> _ToS
# Tokens to provide to the whisper decoder as initial prompt
# these are prepended to any existing text context from a previous call
# use whisper_tokenize() to convert text to tokens.
# Maximum of whisper_n_text_ctx()/2 tokens are used (typically 224).
#
def initial_prompt: () -> (String | nil)
def diarize=: (boolish) -> boolish
# If true, enables diarization.
#
def diarize: () -> (true | false)
def offset=: (Integer) -> Integer
# Start offset in ms.
#
def offset: () -> Integer
def duration=: (Integer) -> Integer
# Audio duration to process in ms.
#
def duration: () -> Integer
def max_text_tokens=: (Integer) -> Integer
# Max tokens to use from past text as prompt for the decoder.
#
def max_text_tokens: () -> Integer
def temperature=: (Float) -> Float
def temperature: () -> Float
def max_initial_ts=: (Float) -> Float
# See https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97
#
def max_initial_ts: () -> Float
def length_penalty=: (Float) -> Float
def length_penalty: () -> Float
def temperature_inc=: (Float) -> Float
def temperature_inc: () -> Float
def entropy_thold=: (Float) -> Float
# Similar to OpenAI's "compression_ratio_threshold"
#
def entropy_thold: () -> Float
def logprob_thold=: (Float) -> Float
def logprob_thold: () -> Float
def no_speech_thold=: (Float) -> Float
def no_speech_thold: () -> Float
# Sets new segment callback, called for every newly generated text segment.
#
# params.new_segment_callback = ->(context, _, n_new, user_data) {
# # ...
# }
#
def new_segment_callback=: (new_segment_callback) -> new_segment_callback
def new_segment_callback: () -> (new_segment_callback | nil)
# Sets user data passed to the last argument of new segment callback.
#
def new_segment_callback_user_data=: (Object) -> Object
def new_segment_callback_user_data: () -> Object
# Sets progress callback, called on each progress update.
#
# params.new_segment_callback = ->(context, _, progress, user_data) {
# # ...
# }
#
# +progress+ is an Integer between 0 and 100.
#
def progress_callback=: (progress_callback) -> progress_callback
def progress_callback: () -> (progress_callback | nil)
# Sets user data passed to the last argument of progress callback.
#
def progress_callback_user_data=: (Object) -> Object
def progress_callback_user_data: () -> Object
# Sets encoder begin callback, called when the encoder starts.
#
def encoder_begin_callback=: (encoder_begin_callback) -> encoder_begin_callback
def encoder_begin_callback: () -> (encoder_begin_callback | nil)
# Sets user data passed to the last argument of encoder begin callback.
#
def encoder_begin_callback_user_data=: (Object) -> Object
def encoder_begin_callback_user_data: () -> Object
# Sets abort callback, called to check if the process should be aborted.
#
# params.abort_callback = ->(user_data) {
# # ...
# }
#
#
def abort_callback=: (abort_callback) -> abort_callback
def abort_callback: () -> (abort_callback | nil)
# Sets user data passed to the last argument of abort callback.
#
def abort_callback_user_data=: (Object) -> Object
def abort_callback_user_data: () -> Object
# Enable VAD
#
def vad=: (boolish) -> boolish
def vad: () -> (true | false)
# Path to the VAD model
def vad_model_path=: (path | URI | nil) -> (path | URI | nil)
def vad_model_path: () -> (String | nil)
def vad_params=: (Whisper::VAD::Params) -> Whisper::VAD::Params
def vad_params: () -> (Whisper::VAD::Params)
# Hook called on new segment. Yields each Whisper::Segment.
#
# whisper.on_new_segment do |segment|
# # ...
# end
#
def on_new_segment: { (Segment) -> void } -> void
# Hook called on progress update. Yields each progress Integer between 0 and 100.
#
def on_progress: { (Integer progress) -> void } -> void
# Hook called on encoder starts.
#
def on_encoder_begin: { () -> void } -> void
# Call block to determine whether abort or not. Return +true+ when you want to abort.
#
# params.abort_on do
# if some_condition
# true # abort
# else
# false # continue
# end
# end
#
def abort_on: { (Object user_data) -> boolish } -> void
end

class Model
def self.pre_converted_models: () -> Hash[String, Model::URI]
def self.coreml_compiled_models: () -> Hash[Model::URI, Model::ZipURI]
def self.new: () -> instance
def n_vocab: () -> Integer
def n_audio_ctx: () -> Integer
@ -407,99 +167,18 @@ module Whisper
def type: () -> String

class URI
-def self.new: (string | ::URI::HTTP) -> instance
+def self.new: (string | ::URI::HTTP) -> self
def to_path: -> String
def clear_cache: -> void
end
class ZipURI < URI
def cache: () -> Pathname
def clear_cache: () -> void
end
end

class Segment
type deconstructed_keys = {
start_time: (Integer | nil),
end_time: (Integer | nil),
text: (String | nil),
no_speech_prob: (Float | nil),
speaker_turn_next: (true | false | nil)
}
# Start time in milliseconds.
#
def start_time: () -> Integer
# End time in milliseconds.
#
def end_time: () -> Integer
-# Whether the next segment is predicted as a speaker turn.
-def speaker_turn_next?: () -> (true | false)
+def speaker_next_turn?: () -> (true | false)
def text: () -> String
def no_speech_prob: () -> Float
def to_srt_cue: () -> String
def to_webvtt_cue: () -> String
# Possible keys: :start_time, :end_time, :text, :no_speech_prob, :speaker_turn_next
#
# whisper.each_segment do |segment|
# segment => {start_time:, end_time:, text:, no_speech_prob:, speaker_turn_next:}
#
# puts "[#{start_time} --> #{end_time}] #{text} (no speech prob: #{no_speech_prob}#{speaker_turn_next ? ', speaker turns next' : ''})"
# end
def deconstruct_keys: (Array[:start_time | :end_time | :text | :no_speech_prob | :speaker_turn_next] | nil) -> deconstructed_keys
end
module VAD
class Params
def self.new: (
?threshold: Float,
?min_speech_duration_ms: Integer,
?min_silence_duration_ms: Integer,
?max_speech_duration_s: Float,
?speech_pad_ms: Integer,
?samples_overlap: Float
) -> instance
# Probability threshold to consider as speech.
#
def threshold=: (Float) -> Float
def threshold: () -> Float
# Min duration for a valid speech segment.
#
def min_speech_duration_ms=: (Integer) -> Integer
def min_speech_duration_ms: () -> Integer
# Min silence duration to consider speech as ended.
#
def min_silence_duration_ms=: (Integer) -> Integer
def min_silence_duration_ms: () -> Integer
# Max duration of a speech segment before forcing a new segment.
def max_speech_duration_s=: (Float) -> Float
def max_speech_duration_s: () -> Float
# Padding added before and after speech segments.
#
def speech_pad_ms=: (Integer) -> Integer
def speech_pad_ms: () -> Integer
# Overlap in seconds when copying audio samples from speech segment.
#
def samples_overlap=: (Float) -> Float
def samples_overlap: () -> Float
def ==: (Params) -> (true | false)
end
end end
class Error < StandardError

View File

@ -1,51 +0,0 @@
require_relative "helper"
require 'tempfile'
require 'tmpdir'
require 'shellwords'
class TestPackage < TestBase
def test_build
Tempfile.create do |file|
assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path.shellescape, exception: true)
assert file.size > 0
assert_path_exist file.to_path
end
end
sub_test_case "Building binary on installation" do
def setup
system "rake", "build", exception: true
end
def test_install
gemspec = Gem::Specification.load("whispercpp.gemspec")
Dir.mktmpdir do |dir|
system "gem", "install", "--install-dir", dir.shellescape, "--no-document", "pkg/#{gemspec.file_name.shellescape}", exception: true
assert_installed dir, gemspec.version
end
end
def test_install_with_coreml
omit_unless RUBY_PLATFORM.match?(/darwin/) do
gemspec = Gem::Specification.load("whispercpp.gemspec")
Dir.mktmpdir do |dir|
system "gem", "install", "--install-dir", dir.shellescape, "--no-document", "pkg/#{gemspec.file_name.shellescape}", "--", "--enable-whisper-coreml", exception: true
assert_installed dir, gemspec.version
libdir = File.join(dir, "gems", "#{gemspec.name}-#{gemspec.version}", "lib")
assert_nothing_raised do
system "ruby", "-I", libdir, "-r", "whisper", "-e", "Whisper::Context.new('tiny')", exception: true
end
assert_match(/COREML = 1/, `ruby -I #{libdir.shellescape} -r whisper -e 'puts Whisper.system_info_str'`)
end
end
end
private
def assert_installed(dir, version)
assert_path_exist File.join(dir, "gems/whispercpp-#{version}/lib", "whisper.#{RbConfig::CONFIG["DLEXT"]}")
assert_path_exist File.join(dir, "gems/whispercpp-#{version}/LICENSE")
assert_path_not_exist File.join(dir, "gems/whispercpp-#{version}/ext/build")
end
end
end

View File

@ -1,146 +0,0 @@
require_relative "helper"
class TestSegment < TestBase
def test_iteration
whisper.each_segment do |segment|
assert_instance_of Whisper::Segment, segment
end
end
def test_enumerator
enum = whisper.each_segment
assert_instance_of Enumerator, enum
enum.to_a.each_with_index do |segment, index|
assert_instance_of Whisper::Segment, segment
assert_kind_of Integer, index
end
end
def test_start_time
i = 0
whisper.each_segment do |segment|
assert_equal 0, segment.start_time if i == 0
i += 1
end
end
def test_end_time
i = 0
whisper.each_segment do |segment|
assert_equal whisper.full_get_segment_t1(i) * 10, segment.end_time
i += 1
end
end
def test_no_speech_prob
no_speech_prob = nil
whisper.each_segment do |segment|
no_speech_prob = segment.no_speech_prob
end
assert no_speech_prob > 0.0
end
def test_on_new_segment
params = Whisper::Params.new
seg = nil
index = 0
params.on_new_segment do |segment|
assert_instance_of Whisper::Segment, segment
if index == 0
seg = segment
assert_equal 0, segment.start_time
assert_match(/ask not what your country can do for you, ask what you can do for your country/, segment.text)
end
index += 1
end
whisper.transcribe(AUDIO, params)
assert_equal 0, seg.start_time
assert_match(/ask not what your country can do for you, ask what you can do for your country/, seg.text)
end
def test_on_new_segment_twice
params = Whisper::Params.new
seg = nil
params.on_new_segment do |segment|
seg = segment
return
end
params.on_new_segment do |segment|
assert_same seg, segment
return
end
whisper.transcribe(AUDIO, params)
end
def test_transcription_after_segment_retrieved
params = Whisper::Params.new
segment = whisper.each_segment.first
assert_match(/ask not what your country can do for you, ask what you can do for your country/, segment.text)
whisper.transcribe(AUDIO, Whisper::Params.new(offset: 5000))
assert_not_match(/ask not what your country can do for you, ask what you can do for your country/, segment.text)
assert_match(/what you can do for your country/i, segment.text)
end
def test_pattern_matching
segment = whisper.each_segment.first
segment => {start_time:, end_time:, text:, no_speech_prob:, speaker_turn_next:}
assert_equal segment.start_time, start_time
assert_equal segment.end_time, end_time
assert_equal segment.text, text
assert_equal segment.no_speech_prob, no_speech_prob
assert_equal segment.speaker_turn_next?, speaker_turn_next
end
def test_pattern_matching_partial
segment = whisper.each_segment.first
segment => {start_time:, end_time:, text:}
assert_equal segment.start_time, start_time
assert_equal segment.end_time, end_time
assert_equal segment.text, text
end
def test_deconstruct_keys
segment = whisper.each_segment.first
expected = {
start_time: segment.start_time,
end_time: segment.end_time,
text: segment.text,
no_speech_prob: segment.no_speech_prob,
speaker_turn_next: segment.speaker_turn_next?
}
assert_equal expected, segment.deconstruct_keys([:start_time, :end_time, :text, :no_speech_prob, :speaker_turn_next])
end
def test_deconstruct_keys_non_existent
omit "Undefined behavior"
segment = whisper.each_segment.first
assert_equal({}, segment.deconstruct_keys([:non_existent]))
end
def test_deconstruct_keys_too_many_keys
omit "Undefined behavior"
segment = whisper.each_segment.first
assert_equal({}, segment.deconstruct_keys([:start_time, :end_time, :text, :no_speech_prob, :speaker_turn_next, :extra_key]))
end
def test_deconstruct_keys_includes_non_existent_keys_not_too_many
omit "Undefined behavior"
segment = whisper.each_segment.first
expected = {
start_time: segment.start_time,
end_time: segment.end_time,
text: segment.text,
no_speech_prob: segment.no_speech_prob
}
assert_equal(expected, segment.deconstruct_keys([:start_time, :end_time, :text, :no_speech_prob, :non_existent]))
end
end

View File

@ -1,19 +0,0 @@
require_relative "helper"
class TestVAD < TestBase
def setup
@whisper = Whisper::Context.new("base.en")
vad_params = Whisper::VAD::Params.new
@params = Whisper::Params.new(
vad: true,
vad_model_path: "silero-v5.1.2",
vad_params:
)
end
def test_transcribe
@whisper.transcribe(TestBase::AUDIO, @params) do |text|
assert_match(/ask not what your country can do for you[,.] ask what you can do for your country/i, text)
end
end
end

View File

@ -1,103 +0,0 @@
require_relative "helper"
class TestVADParams < TestBase
PARAM_NAMES = [
:threshold,
:min_speech_duration_ms,
:min_silence_duration_ms,
:max_speech_duration_s,
:speech_pad_ms,
:samples_overlap
]
def setup
@params = Whisper::VAD::Params.new
end
def test_new
params = Whisper::VAD::Params.new
assert_kind_of Whisper::VAD::Params, params
end
def test_threshold
assert_in_delta @params.threshold, 0.5
@params.threshold = 0.7
assert_in_delta @params.threshold, 0.7
end
def test_min_speech_duration
pend
end
def test_min_speech_duration_ms
assert_equal 250, @params.min_speech_duration_ms
@params.min_speech_duration_ms = 500
assert_equal 500, @params.min_speech_duration_ms
end
def test_min_silence_duration_ms
assert_equal 100, @params.min_silence_duration_ms
@params.min_silence_duration_ms = 200
assert_equal 200, @params.min_silence_duration_ms
end
def test_max_speech_duration
pend
end
def test_max_speech_duration_s
assert @params.max_speech_duration_s >= 10e37 # Defaults to FLT_MAX
@params.max_speech_duration_s = 60.0
assert_equal 60.0, @params.max_speech_duration_s
end
def test_speech_pad_ms
assert_equal 30, @params.speech_pad_ms
@params.speech_pad_ms = 50
assert_equal 50, @params.speech_pad_ms
end
def test_samples_overlap
assert_in_delta @params.samples_overlap, 0.1
@params.samples_overlap = 0.5
assert_in_delta @params.samples_overlap, 0.5
end
def test_equal
assert_equal @params, Whisper::VAD::Params.new
end
def test_new_with_kw_args
params = Whisper::VAD::Params.new(threshold: 0.7)
assert_in_delta params.threshold, 0.7
assert_equal 250, params.min_speech_duration_ms
end
def test_new_with_kw_args_non_existent
assert_raise ArgumentError do
Whisper::VAD::Params.new(non_existent: "value")
end
end
data(PARAM_NAMES.collect {|param| [param, param]}.to_h)
def test_new_with_kw_args_default_values(param)
default_value = @params.send(param)
value = default_value + 1
params = Whisper::VAD::Params.new(param => value)
if Float === value
assert_in_delta value, params.send(param)
else
assert_equal value, params.send(param)
end
PARAM_NAMES.reject {|name| name == param}.each do |name|
expected = @params.send(name)
actual = params.send(name)
if Float === expected
assert_in_delta expected, actual
else
assert_equal expected, actual
end
end
end
end

View File

@ -3,12 +3,12 @@ require "whisper"
require_relative "jfk_reader/jfk_reader" require_relative "jfk_reader/jfk_reader"
class TestBase < Test::Unit::TestCase class TestBase < Test::Unit::TestCase
AUDIO = File.join(__dir__, "fixtures", "jfk.wav") AUDIO = File.join(__dir__, "..", "..", "..", "samples", "jfk.wav")
class << self class << self
def whisper attr_reader :whisper
return @whisper if @whisper
def startup
@whisper = Whisper::Context.new("base.en") @whisper = Whisper::Context.new("base.en")
params = Whisper::Params.new params = Whisper::Params.new
params.print_timestamps = false params.print_timestamps = false

View File

@ -111,48 +111,6 @@ class TestCallback < TestBase
assert_equal 100, last
end
def test_encoder_begin_callback
i = 0
@params.encoder_begin_callback = ->(context, state, user_data) {
i += 1
}
@whisper.transcribe(@audio, @params)
assert i > 0
end
def test_encoder_begin_callback_abort
logs = []
Whisper.log_set -> (level, buffer, user_data) {
logs << buffer if level == Whisper::LOG_LEVEL_ERROR
}, logs
@params.encoder_begin_callback = ->(context, state, user_data) {
return false
}
@whisper.transcribe(@audio, @params)
assert_match(/encoder_begin_callback returned false - aborting/, logs.join)
Whisper.log_set ->(level, buffer, user_data) {}, nil
end
def test_encoder_begin_callback_user_data
udata = Object.new
@params.encoder_begin_callback_user_data = udata
yielded = nil
@params.encoder_begin_callback = ->(context, state, user_data) {
yielded = user_data
}
@whisper.transcribe(@audio, @params)
assert_same udata, yielded
end
def test_on_encoder_begin
i = 0
@params.on_encoder_begin do
i += 1
end
@whisper.transcribe(@audio, @params)
assert i > 0
end
def test_abort_callback
i = 0
@params.abort_callback = ->(user_data) {

View File

@ -106,13 +106,4 @@ class TestModel < TestBase
assert_equal 1, model.ftype
assert_equal "base", model.type
end
def test_coreml_model_auto_download
uri = Whisper::Model.coreml_compiled_models[Whisper::Model.pre_converted_models["tiny"]]
model_path = Pathname(uri.to_path).sub_ext("")
model_path.rmtree if model_path.exist?
uri.cache
assert_path_exist model_path
end
end
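
The removed `test_coreml_model_auto_download` above is also the only documentation of the model-cache helpers it calls; a rough sketch of that flow, using just the calls visible in the test:

```ruby
require "whisper"

# Pre-converted models are URI-like objects that can cache themselves locally;
# #cache and #to_path are the accessors exercised by the removed test.
model = Whisper::Model.pre_converted_models["tiny"]
model.cache                 # downloads ggml-tiny.bin unless it is already cached
puts model.to_path          # local path of the cached model file

# On Apple platforms, a matching Core ML bundle can be looked up the same way:
coreml_uri = Whisper::Model.coreml_compiled_models[model]
```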

View File

@ -0,0 +1,31 @@
require_relative "helper"
require 'tempfile'
require 'tmpdir'
require 'shellwords'
class TestPackage < TestBase
def test_build
Tempfile.create do |file|
assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path.shellescape, exception: true)
assert file.size > 0
assert_path_exist file.to_path
end
end
sub_test_case "Building binary on installation" do
def setup
system "rake", "build", exception: true
end
def test_install
match_data = `rake -Tbuild`.match(/(whispercpp-(.+)\.gem)/)
filename = match_data[1]
version = match_data[2]
basename = "whisper.#{RbConfig::CONFIG["DLEXT"]}"
Dir.mktmpdir do |dir|
system "gem", "install", "--install-dir", dir.shellescape, "--no-document", "pkg/#{filename.shellescape}", exception: true
assert_path_exist File.join(dir, "gems/whispercpp-#{version}/lib", basename)
end
end
end
end

View File

@ -32,9 +32,6 @@ class TestParams < TestBase
:progress_callback_user_data,
:abort_callback,
:abort_callback_user_data,
:vad,
:vad_model_path,
:vad_params,
]
def setup
@ -194,50 +191,6 @@ class TestParams < TestBase
assert_in_delta 0.2, @params.no_speech_thold
end
def test_vad
assert_false @params.vad
@params.vad = true
assert_true @params.vad
end
def test_vad_model_path
assert_nil @params.vad_model_path
@params.vad_model_path = "silero-v5.1.2"
assert_equal Whisper::Model.pre_converted_models["silero-v5.1.2"].to_path, @params.vad_model_path
end
def test_vad_model_path_with_nil
@params.vad_model_path = "silero-v5.1.2"
@params.vad_model_path = nil
assert_nil @params.vad_model_path
end
def test_vad_model_path_with_invalid
assert_raise TypeError do
@params.vad_model_path = Object.new
end
end
def test_vad_model_path_with_URI_string
@params.vad_model_path = "https://huggingface.co/ggml-org/whisper-vad/resolve/main/ggml-silero-v5.1.2.bin"
assert_equal @params.vad_model_path, Whisper::Model.pre_converted_models["silero-v5.1.2"].to_path
end
def test_vad_model_path_with_URI
@params.vad_model_path = URI("https://huggingface.co/ggml-org/whisper-vad/resolve/main/ggml-silero-v5.1.2.bin")
assert_equal @params.vad_model_path, Whisper::Model.pre_converted_models["silero-v5.1.2"].to_path
end
def test_vad_params
assert_kind_of Whisper::VAD::Params, @params.vad_params
default_params = @params.vad_params
assert_same default_params, @params.vad_params
assert_equal 0.5, default_params.threshold
new_params = Whisper::VAD::Params.new
@params.vad_params = new_params
assert_same new_params, @params.vad_params
end
def test_new_with_kw_args
params = Whisper::Params.new(language: "es")
assert_equal "es", params.language
@ -272,10 +225,6 @@ class TestParams < TestBase
proc {}
in [/_user_data\Z/, *]
Object.new
in [:vad_model_path, *]
Whisper::Model.pre_converted_models["silero-v5.1.2"].to_path
in [:vad_params, *]
Whisper::VAD::Params.new
end
params = Whisper::Params.new(param => value)
if Float === value

View File

@ -0,0 +1,74 @@
require_relative "helper"
class TestSegment < TestBase
def test_iteration
whisper.each_segment do |segment|
assert_instance_of Whisper::Segment, segment
end
end
def test_enumerator
enum = whisper.each_segment
assert_instance_of Enumerator, enum
enum.to_a.each_with_index do |segment, index|
assert_instance_of Whisper::Segment, segment
assert_kind_of Integer, index
end
end
def test_start_time
i = 0
whisper.each_segment do |segment|
assert_equal 0, segment.start_time if i == 0
i += 1
end
end
def test_end_time
i = 0
whisper.each_segment do |segment|
assert_equal whisper.full_get_segment_t1(i) * 10, segment.end_time
i += 1
end
end
def test_no_speech_prob
no_speech_prob = nil
whisper.each_segment do |segment|
no_speech_prob = segment.no_speech_prob
end
assert no_speech_prob > 0.0
end
def test_on_new_segment
params = Whisper::Params.new
seg = nil
index = 0
params.on_new_segment do |segment|
assert_instance_of Whisper::Segment, segment
if index == 0
seg = segment
assert_equal 0, segment.start_time
assert_match(/ask not what your country can do for you, ask what you can do for your country/, segment.text)
end
index += 1
end
whisper.transcribe(AUDIO, params)
assert_equal 0, seg.start_time
assert_match(/ask not what your country can do for you, ask what you can do for your country/, seg.text)
end
def test_on_new_segment_twice
params = Whisper::Params.new
seg = nil
params.on_new_segment do |segment|
seg = segment
return
end
params.on_new_segment do |segment|
assert_same seg, segment
return
end
whisper.transcribe(AUDIO, params)
end
end
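
One detail worth keeping from `test_end_time` above: the native segment timestamps returned by `full_get_segment_t1` are in 10 ms ticks, and the segment accessors multiply them by 10, so `start_time`/`end_time` are in milliseconds. A small sketch:

```ruby
whisper.each_segment do |segment|
  # end_time == whisper.full_get_segment_t1(i) * 10, i.e. milliseconds
  printf("%8d ms .. %8d ms  %s\n", segment.start_time, segment.end_time, segment.text)
end
```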

View File

@ -20,24 +20,6 @@ class TestWhisper < TestBase
}
end
def test_transcribe_non_parallel
@whisper = Whisper::Context.new("base.en")
params = Whisper::Params.new
@whisper.transcribe(AUDIO, params, n_processors: 1) {|text|
assert_match(/ask not what your country can do for you, ask what you can do for your country/, text)
}
end
def test_transcribe_n_processors
@whisper = Whisper::Context.new("base.en")
params = Whisper::Params.new
@whisper.transcribe(AUDIO, params, n_processors: 4) {|text|
assert_match(/ask not what your country can do for you[,.] ask what you can do for your country/i, text)
}
end
sub_test_case "After transcription" do sub_test_case "After transcription" do
def test_full_n_segments def test_full_n_segments
assert_equal 1, whisper.full_n_segments assert_equal 1, whisper.full_n_segments
@ -112,14 +94,6 @@ class TestWhisper < TestBase
end
end
def test_system_info_str
assert_match(/\AWHISPER : COREML = \d | OPENVINO = \d |/, Whisper.system_info_str)
end
def test_version
assert_kind_of String, Whisper::VERSION
end
def test_log_set
user_data = Object.new
logs = []
@ -249,48 +223,4 @@ class TestWhisper < TestBase
assert_match(/for your country/i, text)
end
end
def test_to_srt
whisper = Whisper::Context.new("base.en")
whisper.transcribe AUDIO, @params
lines = whisper.to_srt.lines
assert_match(/\A\d+\n/, lines[0])
assert_match(/\d{2}:\d{2}:\d{2},\d{3} --> \d{2}:\d{2}:\d{2},\d{3}\n/, lines[1])
assert_match(/ask not what your country can do for you, ask what you can do for your country/, lines[2])
end
def test_to_webvtt
whisper = Whisper::Context.new("base.en")
whisper.transcribe AUDIO, @params
lines = whisper.to_webvtt.lines
assert_equal "WEBVTT\n", lines[0]
assert_equal "\n", lines[1]
assert_match(/\A\d+\n/, lines[2])
assert_match(/\d{2}:\d{2}:\d{2}\.\d{3} --> \d{2}:\d{2}:\d{2}\.\d{3}\n/, lines[3])
assert_match(/ask not what your country can do for you, ask what you can do for your country/, lines[4])
end
sub_test_case "Format needs escape" do
def setup
@whisper = Whisper::Context.new("base.en")
@whisper.transcribe AUDIO, Whisper::Params.new
segment = @whisper.each_segment.first
segment.define_singleton_method :text do
"& so my fellow Americans --> ask not what your country can do for you <-- ask what you can do for your country."
end
@whisper.define_singleton_method :each_segment do
Enumerator.new(3) {|yielder| 3.times {yielder << segment}}
end
end
def test_to_srt_escape
assert_equal "&amp; so my fellow Americans --&gt; ask not what your country can do for you &lt;-- ask what you can do for your country.\n", @whisper.to_srt.lines[2]
end
def test_to_webvtt_escape
assert_equal "&amp; so my fellow Americans --&gt; ask not what your country can do for you &lt;-- ask what you can do for your country.\n", @whisper.to_webvtt.lines[4]
end
end
end
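
The removed `to_srt`/`to_webvtt` tests above imply a compact export workflow; a sketch using only the API those tests exercise (output file names are illustrative):

```ruby
require "whisper"

whisper = Whisper::Context.new("base.en")
whisper.transcribe("samples/jfk.wav", Whisper::Params.new)

# Both formats HTML-escape "&", "-->" and "<--" in segment text, per the
# "Format needs escape" cases above.
File.write("jfk.srt", whisper.to_srt)    # "HH:MM:SS,mmm --> HH:MM:SS,mmm" cues
File.write("jfk.vtt", whisper.to_webvtt) # begins with a "WEBVTT" header line
```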

View File

@ -3,7 +3,8 @@ require_relative "extsources"
Gem::Specification.new do |s|
  s.name = "whispercpp"
  s.authors = ["Georgi Gerganov", "Todd A. Fisher"]
-  s.version = '1.3.3'
+  s.version = '1.3.1'
+  s.date = '2024-12-19'
  s.description = %q{High-performance inference of OpenAI's Whisper automatic speech recognition (ASR) model via Ruby}
  s.email = 'todd.fisher@gmail.com'
  s.extra_rdoc_files = ['LICENSE', 'README.md']
@ -14,19 +15,18 @@ Gem::Specification.new do |s|
  if s.extra_rdoc_files.include?(basename)
    basename
  else
-    file.sub("../..", "ext/sources")
-        .sub("../javascript", "ext/sources/bindings/javascript")
+    file.sub("../..", "ext")
  end
}
  s.summary = %q{Ruby whisper.cpp bindings}
-  s.test_files = s.files.select {|file| file.start_with? "test/"}
+  s.test_files = s.files.select {|file| file.start_with? "tests/"}
  s.extensions << 'ext/extconf.rb'
  s.required_ruby_version = '>= 3.1.0'
  #### Documentation and testing.
-  s.homepage = 'https://github.com/ggml-org/whisper.cpp'
+  s.homepage = 'https://github.com/ggerganov/whisper.cpp'
  s.rdoc_options = ['--main', 'README.md']

View File

@ -41,11 +41,6 @@ COMMON_CMAKE_ARGS=(
-DGGML_OPENMP=${GGML_OPENMP}
)
XCODE_VERSION=$(xcodebuild -version 2>/dev/null | head -n1 | awk '{ print $2 }')
MAJOR_VERSION=$(echo $XCODE_VERSION | cut -d. -f1)
MINOR_VERSION=$(echo $XCODE_VERSION | cut -d. -f2)
echo "Detected Xcode version: $XCODE_VERSION"
check_required_tool() {
    local tool=$1
    local install_message=$2
@ -340,28 +335,21 @@ combine_static_libraries() {
# Platform-specific post-processing for device builds
if [[ "$is_simulator" == "false" ]]; then
-    if command -v xcrun vtool &>/dev/null; then
+    if command -v vtool &>/dev/null; then
        case "$platform" in
            "ios")
                echo "Marking binary as a framework binary for iOS..."
-                xcrun vtool -set-build-version ios ${IOS_MIN_OS_VERSION} ${IOS_MIN_OS_VERSION} -replace \
+                vtool -set-build-version ios ${IOS_MIN_OS_VERSION} ${IOS_MIN_OS_VERSION} -replace \
                    -output "${base_dir}/${output_lib}" "${base_dir}/${output_lib}"
                ;;
            "visionos")
                echo "Marking binary as a framework binary for visionOS..."
if [[ "$MAJOR_VERSION" -gt 16 ]] || [[ "$MAJOR_VERSION" -eq 16 && "$MINOR_VERSION" -gt 2 ]]; then vtool -set-build-version xros ${VISIONOS_MIN_OS_VERSION} ${VISIONOS_MIN_OS_VERSION} -replace \
echo "Xcode version greater than 16.2, using visionOS."
VISION_OS_BUILD_VERSION="visionos"
else
echo "Xcode version less than or equal to 16.2, using xros."
VISION_OS_BUILD_VERSION="xros"
fi
xcrun vtool -set-build-version ${VISION_OS_BUILD_VERSION} ${VISIONOS_MIN_OS_VERSION} ${VISIONOS_MIN_OS_VERSION} -replace \
-output "${base_dir}/${output_lib}" "${base_dir}/${output_lib}" -output "${base_dir}/${output_lib}" "${base_dir}/${output_lib}"
                ;;
            "tvos")
                echo "Marking binary as a framework binary for tvOS..."
-                xcrun vtool -set-build-version tvos ${TVOS_MIN_OS_VERSION} ${TVOS_MIN_OS_VERSION} -replace \
+                vtool -set-build-version tvos ${TVOS_MIN_OS_VERSION} ${TVOS_MIN_OS_VERSION} -replace \
                    -output "${base_dir}/${output_lib}" "${base_dir}/${output_lib}"
                ;;
        esac

View File

@ -105,7 +105,6 @@ else()
add_subdirectory(bench)
add_subdirectory(server)
add_subdirectory(quantize)
add_subdirectory(vad-speech-segments)
if (WHISPER_SDL2)
add_subdirectory(stream)
add_subdirectory(command)

View File

@ -1,10 +1,8 @@
# whisper.cpp Node.js addon
+# addon
This is an addon demo that can **perform whisper model reasoning in `node` and `electron` environments**, based on [cmake-js](https://github.com/cmake-js/cmake-js).
It can be used as a reference for using the whisper.cpp project in other node projects.
-This addon now supports **Voice Activity Detection (VAD)** for improved transcription performance.
## Install
```shell
@ -28,88 +26,12 @@ For Electron addon and cmake-js options, you can see [cmake-js](https://github.c
## Run
-### Basic Usage
```shell
cd examples/addon.node
node index.js --language='language' --model='model-path' --fname_inp='file-path'
```
-### VAD (Voice Activity Detection) Usage
-Run the VAD example with performance comparison:
+Because this is a simple Demo, only the above parameters are set in the node environment.
+Other parameters can also be specified in the node environment.
```shell
node vad-example.js
```
## Voice Activity Detection (VAD) Support
VAD can significantly improve transcription performance by only processing speech segments, which is especially beneficial for audio files with long periods of silence.
### VAD Model Setup
Before using VAD, download a VAD model:
```shell
# From the whisper.cpp root directory
./models/download-vad-model.sh silero-v5.1.2
```
### VAD Parameters
All VAD parameters are optional and have sensible defaults:
- `vad`: Enable VAD (default: false)
- `vad_model`: Path to VAD model file (required when VAD enabled)
- `vad_threshold`: Speech detection threshold 0.0-1.0 (default: 0.5)
- `vad_min_speech_duration_ms`: Min speech duration in ms (default: 250)
- `vad_min_silence_duration_ms`: Min silence duration in ms (default: 100)
- `vad_max_speech_duration_s`: Max speech duration in seconds (default: FLT_MAX)
- `vad_speech_pad_ms`: Speech padding in ms (default: 30)
- `vad_samples_overlap`: Sample overlap 0.0-1.0 (default: 0.1)
### JavaScript API Example
```javascript
const path = require("path");
const { whisper } = require(path.join(__dirname, "../../build/Release/addon.node"));
const { promisify } = require("util");
const whisperAsync = promisify(whisper);
// With VAD enabled
const vadParams = {
language: "en",
model: path.join(__dirname, "../../models/ggml-base.en.bin"),
fname_inp: path.join(__dirname, "../../samples/jfk.wav"),
vad: true,
vad_model: path.join(__dirname, "../../models/ggml-silero-v5.1.2.bin"),
vad_threshold: 0.5,
progress_callback: (progress) => console.log(`Progress: ${progress}%`)
};
whisperAsync(vadParams).then(result => console.log(result));
```
## Supported Parameters
Both traditional whisper.cpp parameters and new VAD parameters are supported:
- `language`: Language code (e.g., "en", "es", "fr")
- `model`: Path to whisper model file
- `fname_inp`: Path to input audio file
- `use_gpu`: Enable GPU acceleration (default: true)
- `flash_attn`: Enable flash attention (default: false)
- `no_prints`: Disable console output (default: false)
- `no_timestamps`: Disable timestamps (default: false)
- `detect_language`: Auto-detect language (default: false)
- `audio_ctx`: Audio context size (default: 0)
- `max_len`: Maximum segment length (default: 0)
- `max_context`: Maximum context size (default: -1)
- `prompt`: Initial prompt for decoder
- `comma_in_time`: Use comma in timestamps (default: true)
- `print_progress`: Print progress info (default: false)
- `progress_callback`: Progress callback function
- VAD parameters (see above section)

View File

@ -1,133 +1,31 @@
-const { join } = require('path');
-const { whisper } = require('../../../build/Release/addon.node');
-const { promisify } = require('util');
+const path = require("path");
+const { whisper } = require(path.join(
+  __dirname,
+  "../../../build/Release/addon.node"
+));
+const { promisify } = require("util");
const whisperAsync = promisify(whisper);
-const commonParams = {
-  language: 'en',
-  model: join(__dirname, '../../../models/ggml-base.en.bin'),
-  fname_inp: join(__dirname, '../../../samples/jfk.wav'),
+const whisperParamsMock = {
+  language: "en",
+  model: path.join(__dirname, "../../../models/ggml-base.en.bin"),
+  fname_inp: path.join(__dirname, "../../../samples/jfk.wav"),
  use_gpu: true,
  flash_attn: false,
  no_prints: true,
+  comma_in_time: false,
+  translate: true,
  no_timestamps: false,
-  detect_language: false,
  audio_ctx: 0,
-  max_len: 0
+  max_len: 0,
};
-describe('Whisper.cpp Node.js addon with VAD support', () => {
-  test('Basic whisper transcription without VAD', async () => {
-    const params = {
-      ...commonParams,
-      vad: false
-    };
-    const result = await whisperAsync(params);
+describe("Run whisper.node", () => {
+  test("it should receive a non-empty value", async () => {
+    let result = await whisperAsync(whisperParamsMock);
+    expect(result.length).toBeGreaterThan(0);
expect(typeof result).toBe('object');
expect(Array.isArray(result.transcription)).toBe(true);
expect(result.transcription.length).toBeGreaterThan(0);
// Check that we got some transcription text
const text = result.transcription.map(segment => segment[2]).join(' ');
expect(text.length).toBeGreaterThan(0);
expect(text.toLowerCase()).toContain('ask not');
}, 30000);
test('VAD parameters validation', async () => {
// Test with invalid VAD model - should return empty transcription
const invalidParams = {
...commonParams,
vad: true,
vad_model: 'non-existent-model.bin',
vad_threshold: 0.5
};
// This should handle the error gracefully and return empty transcription
const result = await whisperAsync(invalidParams);
expect(typeof result).toBe('object');
expect(Array.isArray(result.transcription)).toBe(true);
// When VAD model doesn't exist, it should return empty transcription
expect(result.transcription.length).toBe(0);
}, 10000);
test('VAD parameter parsing', async () => {
// Test that VAD parameters are properly parsed (even if VAD model doesn't exist)
const vadParams = {
...commonParams,
vad: false, // Disabled so no model required
vad_threshold: 0.7,
vad_min_speech_duration_ms: 300,
vad_min_silence_duration_ms: 150,
vad_max_speech_duration_s: 45.0,
vad_speech_pad_ms: 50,
vad_samples_overlap: 0.15
};
const result = await whisperAsync(vadParams);
expect(typeof result).toBe('object');
expect(Array.isArray(result.transcription)).toBe(true);
}, 30000);
test('Progress callback with VAD disabled', async () => {
let progressCalled = false;
let lastProgress = 0;
const params = {
...commonParams,
vad: false,
progress_callback: (progress) => {
progressCalled = true;
lastProgress = progress;
expect(progress).toBeGreaterThanOrEqual(0);
expect(progress).toBeLessThanOrEqual(100);
}
};
const result = await whisperAsync(params);
expect(progressCalled).toBe(true);
expect(lastProgress).toBe(100);
expect(typeof result).toBe('object');
}, 30000);
test('Language detection without VAD', async () => {
const params = {
...commonParams,
vad: false,
detect_language: true,
language: 'auto'
};
const result = await whisperAsync(params);
expect(typeof result).toBe('object');
expect(typeof result.language).toBe('string');
expect(result.language.length).toBeGreaterThan(0);
}, 30000);
test('Basic transcription with all VAD parameters set', async () => {
// Test with VAD disabled but all parameters set to ensure no crashes
const params = {
...commonParams,
vad: false, // Disabled so it works without VAD model
vad_model: '', // Empty model path
vad_threshold: 0.6,
vad_min_speech_duration_ms: 200,
vad_min_silence_duration_ms: 80,
vad_max_speech_duration_s: 25.0,
vad_speech_pad_ms: 40,
vad_samples_overlap: 0.08
};
const result = await whisperAsync(params);
expect(typeof result).toBe('object');
expect(Array.isArray(result.transcription)).toBe(true);
expect(result.transcription.length).toBeGreaterThan(0);
}, 30000);
});

View File

@ -9,7 +9,6 @@
#include <vector>
#include <cmath>
#include <cstdint>
#include <cfloat>
struct whisper_params {
    int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
@ -39,7 +38,6 @@ struct whisper_params {
bool print_progress = false;
bool no_timestamps = false;
bool no_prints = false;
bool detect_language= false;
bool use_gpu = true;
bool flash_attn = false;
bool comma_in_time = true;
@ -52,16 +50,6 @@ struct whisper_params {
std::vector<std::string> fname_out = {};
std::vector<float> pcmf32 = {}; // mono-channel F32 PCM
// Voice Activity Detection (VAD) parameters
bool vad = false;
std::string vad_model = "";
float vad_threshold = 0.5f;
int vad_min_speech_duration_ms = 250;
int vad_min_silence_duration_ms = 100;
float vad_max_speech_duration_s = FLT_MAX;
int vad_speech_pad_ms = 30;
float vad_samples_overlap = 0.1f;
};
struct whisper_print_user_data {
@ -94,7 +82,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
        t1 = whisper_full_get_segment_t1(ctx, i);
    }
-    if (!params.no_timestamps && !params.no_prints) {
+    if (!params.no_timestamps) {
        printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
    }
@ -125,14 +113,12 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
    // colorful print bug
    //
-    if (!params.no_prints) {
        const char * text = whisper_full_get_segment_text(ctx, i);
        printf("%s%s", speaker.c_str(), text);
-    }
    // with timestamps or speakers: each segment on new line
-    if ((!params.no_timestamps || params.diarize) && !params.no_prints) {
+    if (!params.no_timestamps || params.diarize) {
        printf("\n");
    }
@ -142,11 +128,6 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
void cb_log_disable(enum ggml_log_level, const char *, void *) {}
struct whisper_result {
std::vector<std::vector<std::string>> segments;
std::string language;
};
class ProgressWorker : public Napi::AsyncWorker {
    public:
        ProgressWorker(Napi::Function& callback, whisper_params params, Napi::Function progress_callback, Napi::Env env)
@ -177,27 +158,15 @@ class ProgressWorker : public Napi::AsyncWorker {
    void OnOK() override {
        Napi::HandleScope scope(Env());
-        if (params.detect_language) {
-            Napi::Object resultObj = Napi::Object::New(Env());
-            resultObj.Set("language", Napi::String::New(Env(), result.language));
-            Callback().Call({Env().Null(), resultObj});
-        }
-        Napi::Object returnObj = Napi::Object::New(Env());
-        if (!result.language.empty()) {
-            returnObj.Set("language", Napi::String::New(Env(), result.language));
-        }
-        Napi::Array transcriptionArray = Napi::Array::New(Env(), result.segments.size());
-        for (uint64_t i = 0; i < result.segments.size(); ++i) {
+        Napi::Object res = Napi::Array::New(Env(), result.size());
+        for (uint64_t i = 0; i < result.size(); ++i) {
            Napi::Object tmp = Napi::Array::New(Env(), 3);
            for (uint64_t j = 0; j < 3; ++j) {
-                tmp[j] = Napi::String::New(Env(), result.segments[i][j]);
+                tmp[j] = Napi::String::New(Env(), result[i][j]);
            }
-            transcriptionArray[i] = tmp;
+            res[i] = tmp;
        }
-        returnObj.Set("transcription", transcriptionArray);
-        Callback().Call({Env().Null(), returnObj});
+        Callback().Call({Env().Null(), res});
    }
// Progress callback function - using thread-safe function
@ -214,12 +183,12 @@ class ProgressWorker : public Napi::AsyncWorker {
    private:
        whisper_params params;
-        whisper_result result;
+        std::vector<std::vector<std::string>> result;
        Napi::Env env;
        Napi::ThreadSafeFunction tsfn;
        // Custom run function with progress callback support
-        int run_with_progress(whisper_params &params, whisper_result & result) {
+        int run_with_progress(whisper_params &params, std::vector<std::vector<std::string>> &result) {
            if (params.no_prints) {
                whisper_log_set(cb_log_disable, NULL);
            }
@ -308,8 +277,7 @@ class ProgressWorker : public Napi::AsyncWorker {
            wparams.print_timestamps = !params.no_timestamps;
            wparams.print_special    = params.print_special;
            wparams.translate        = params.translate;
-            wparams.language         = params.detect_language ? "auto" : params.language.c_str();
-            wparams.detect_language  = params.detect_language;
+            wparams.language         = params.language.c_str();
            wparams.n_threads        = params.n_threads;
            wparams.n_max_text_ctx   = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
            wparams.offset_ms        = params.offset_t_ms;
@ -344,16 +312,16 @@ class ProgressWorker : public Napi::AsyncWorker {
            };
            wparams.progress_callback_user_data = this;
-            // Set VAD parameters
-            wparams.vad                                = params.vad;
-            wparams.vad_model_path                     = params.vad_model.c_str();
-            wparams.vad_params.threshold               = params.vad_threshold;
-            wparams.vad_params.min_speech_duration_ms  = params.vad_min_speech_duration_ms;
-            wparams.vad_params.min_silence_duration_ms = params.vad_min_silence_duration_ms;
-            wparams.vad_params.max_speech_duration_s   = params.vad_max_speech_duration_s;
-            wparams.vad_params.speech_pad_ms           = params.vad_speech_pad_ms;
-            wparams.vad_params.samples_overlap         = params.vad_samples_overlap;
+            // Abort mechanism example
+            {
+                static bool is_aborted = false; // Note: this should be atomic to avoid data races
+                wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
+                    bool is_aborted = *(bool*)user_data;
+                    return !is_aborted;
+                };
+                wparams.encoder_begin_callback_user_data = &is_aborted;
+            }
            if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
                fprintf(stderr, "failed to process audio\n");
@ -362,20 +330,16 @@ class ProgressWorker : public Napi::AsyncWorker {
                }
            }
-            if (params.detect_language || params.language == "auto") {
-                result.language = whisper_lang_str(whisper_full_lang_id(ctx));
-            }
            const int n_segments = whisper_full_n_segments(ctx);
-            result.segments.resize(n_segments);
+            result.resize(n_segments);
            for (int i = 0; i < n_segments; ++i) {
                const char * text = whisper_full_get_segment_text(ctx, i);
                const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
                const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
-                result.segments[i].emplace_back(to_timestamp(t0, params.comma_in_time));
-                result.segments[i].emplace_back(to_timestamp(t1, params.comma_in_time));
-                result.segments[i].emplace_back(text);
+                result[i].emplace_back(to_timestamp(t0, params.comma_in_time));
+                result[i].emplace_back(to_timestamp(t1, params.comma_in_time));
+                result[i].emplace_back(text);
            }
            whisper_print_timings(ctx);
@ -396,52 +360,13 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
    std::string language = whisper_params.Get("language").As<Napi::String>();
    std::string model    = whisper_params.Get("model").As<Napi::String>();
    std::string input    = whisper_params.Get("fname_inp").As<Napi::String>();
-    bool use_gpu = true;
-    if (whisper_params.Has("use_gpu") && whisper_params.Get("use_gpu").IsBoolean()) {
-        use_gpu = whisper_params.Get("use_gpu").As<Napi::Boolean>();
-    }
-    bool flash_attn = false;
-    if (whisper_params.Has("flash_attn") && whisper_params.Get("flash_attn").IsBoolean()) {
-        flash_attn = whisper_params.Get("flash_attn").As<Napi::Boolean>();
-    }
-    bool no_prints = false;
-    if (whisper_params.Has("no_prints") && whisper_params.Get("no_prints").IsBoolean()) {
-        no_prints = whisper_params.Get("no_prints").As<Napi::Boolean>();
-    }
-    bool no_timestamps = false;
-    if (whisper_params.Has("no_timestamps") && whisper_params.Get("no_timestamps").IsBoolean()) {
-        no_timestamps = whisper_params.Get("no_timestamps").As<Napi::Boolean>();
-    }
-    bool detect_language = false;
-    if (whisper_params.Has("detect_language") && whisper_params.Get("detect_language").IsBoolean()) {
-        detect_language = whisper_params.Get("detect_language").As<Napi::Boolean>();
-    }
-    int32_t audio_ctx = 0;
-    if (whisper_params.Has("audio_ctx") && whisper_params.Get("audio_ctx").IsNumber()) {
-        audio_ctx = whisper_params.Get("audio_ctx").As<Napi::Number>();
-    }
-    bool comma_in_time = true;
-    if (whisper_params.Has("comma_in_time") && whisper_params.Get("comma_in_time").IsBoolean()) {
-        comma_in_time = whisper_params.Get("comma_in_time").As<Napi::Boolean>();
-    }
-    int32_t max_len = 0;
-    if (whisper_params.Has("max_len") && whisper_params.Get("max_len").IsNumber()) {
-        max_len = whisper_params.Get("max_len").As<Napi::Number>();
-    }
-    // Add support for max_context
-    int32_t max_context = -1;
-    if (whisper_params.Has("max_context") && whisper_params.Get("max_context").IsNumber()) {
-        max_context = whisper_params.Get("max_context").As<Napi::Number>();
-    }
+    bool use_gpu       = whisper_params.Get("use_gpu").As<Napi::Boolean>();
+    bool flash_attn    = whisper_params.Get("flash_attn").As<Napi::Boolean>();
+    bool no_prints     = whisper_params.Get("no_prints").As<Napi::Boolean>();
+    bool no_timestamps = whisper_params.Get("no_timestamps").As<Napi::Boolean>();
+    int32_t audio_ctx  = whisper_params.Get("audio_ctx").As<Napi::Number>();
+    bool comma_in_time = whisper_params.Get("comma_in_time").As<Napi::Boolean>();
+    int32_t max_len    = whisper_params.Get("max_len").As<Napi::Number>();
    // support prompt
    std::string prompt = "";
@ -451,7 +376,7 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
    // Add support for print_progress
    bool print_progress = false;
-    if (whisper_params.Has("print_progress") && whisper_params.Get("print_progress").IsBoolean()) {
+    if (whisper_params.Has("print_progress")) {
        print_progress = whisper_params.Get("print_progress").As<Napi::Boolean>();
    }
    // Add support for progress_callback
@ -460,47 +385,6 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
        progress_callback = whisper_params.Get("progress_callback").As<Napi::Function>();
    }
// Add support for VAD parameters
bool vad = false;
if (whisper_params.Has("vad") && whisper_params.Get("vad").IsBoolean()) {
vad = whisper_params.Get("vad").As<Napi::Boolean>();
}
std::string vad_model = "";
if (whisper_params.Has("vad_model") && whisper_params.Get("vad_model").IsString()) {
vad_model = whisper_params.Get("vad_model").As<Napi::String>();
}
float vad_threshold = 0.5f;
if (whisper_params.Has("vad_threshold") && whisper_params.Get("vad_threshold").IsNumber()) {
vad_threshold = whisper_params.Get("vad_threshold").As<Napi::Number>();
}
int vad_min_speech_duration_ms = 250;
if (whisper_params.Has("vad_min_speech_duration_ms") && whisper_params.Get("vad_min_speech_duration_ms").IsNumber()) {
vad_min_speech_duration_ms = whisper_params.Get("vad_min_speech_duration_ms").As<Napi::Number>();
}
int vad_min_silence_duration_ms = 100;
if (whisper_params.Has("vad_min_silence_duration_ms") && whisper_params.Get("vad_min_silence_duration_ms").IsNumber()) {
vad_min_silence_duration_ms = whisper_params.Get("vad_min_silence_duration_ms").As<Napi::Number>();
}
float vad_max_speech_duration_s = FLT_MAX;
if (whisper_params.Has("vad_max_speech_duration_s") && whisper_params.Get("vad_max_speech_duration_s").IsNumber()) {
vad_max_speech_duration_s = whisper_params.Get("vad_max_speech_duration_s").As<Napi::Number>();
}
int vad_speech_pad_ms = 30;
if (whisper_params.Has("vad_speech_pad_ms") && whisper_params.Get("vad_speech_pad_ms").IsNumber()) {
vad_speech_pad_ms = whisper_params.Get("vad_speech_pad_ms").As<Napi::Number>();
}
float vad_samples_overlap = 0.1f;
if (whisper_params.Has("vad_samples_overlap") && whisper_params.Get("vad_samples_overlap").IsNumber()) {
vad_samples_overlap = whisper_params.Get("vad_samples_overlap").As<Napi::Number>();
}
    Napi::Value pcmf32Value = whisper_params.Get("pcmf32");
    std::vector<float> pcmf32_vec;
    if (pcmf32Value.IsTypedArray()) {
@ -523,20 +407,8 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
    params.pcmf32         = pcmf32_vec;
    params.comma_in_time  = comma_in_time;
    params.max_len        = max_len;
    params.max_context    = max_context;
    params.print_progress = print_progress;
    params.prompt         = prompt;
params.detect_language = detect_language;
// Set VAD parameters
params.vad = vad;
params.vad_model = vad_model;
params.vad_threshold = vad_threshold;
params.vad_min_speech_duration_ms = vad_min_speech_duration_ms;
params.vad_min_silence_duration_ms = vad_min_silence_duration_ms;
params.vad_max_speech_duration_s = vad_max_speech_duration_s;
params.vad_speech_pad_ms = vad_speech_pad_ms;
params.vad_samples_overlap = vad_samples_overlap;
    Napi::Function callback = info[1].As<Napi::Function>();
    // Create a new Worker class with progress callback support

View File

@ -17,7 +17,6 @@ const whisperParams = {
  comma_in_time: false,
  translate: true,
  no_timestamps: false,
  detect_language: false,
  audio_ctx: 0,
  max_len: 0,
  progress_callback: (progress) => {
@ -32,8 +31,6 @@ const params = Object.fromEntries(
    const [key, value] = item.slice(2).split("=");
    if (key === "audio_ctx") {
      whisperParams[key] = parseInt(value);
    } else if (key === "detect_language") {
      whisperParams[key] = value === "true";
    } else {
      whisperParams[key] = value;
    }

View File

@ -1,132 +0,0 @@
const path = require("path");
const { whisper } = require(path.join(
__dirname,
"../../build/Release/addon.node"
));
const { promisify } = require("util");
const whisperAsync = promisify(whisper);
// Example with VAD enabled
const vadParams = {
language: "en",
model: path.join(__dirname, "../../models/ggml-base.en.bin"),
fname_inp: path.join(__dirname, "../../samples/jfk.wav"),
use_gpu: true,
flash_attn: false,
no_prints: false,
comma_in_time: true,
translate: false,
no_timestamps: false,
detect_language: false,
audio_ctx: 0,
max_len: 0,
// VAD parameters
vad: true,
vad_model: path.join(__dirname, "../../models/ggml-silero-v5.1.2.bin"), // You need to download this model
vad_threshold: 0.5,
vad_min_speech_duration_ms: 250,
vad_min_silence_duration_ms: 100,
vad_max_speech_duration_s: 30.0,
vad_speech_pad_ms: 30,
vad_samples_overlap: 0.1,
progress_callback: (progress) => {
console.log(`VAD Transcription progress: ${progress}%`);
}
};
// Example without VAD (traditional approach)
const traditionalParams = {
language: "en",
model: path.join(__dirname, "../../models/ggml-base.en.bin"),
fname_inp: path.join(__dirname, "../../samples/jfk.wav"),
use_gpu: true,
flash_attn: false,
no_prints: false,
comma_in_time: true,
translate: false,
no_timestamps: false,
detect_language: false,
audio_ctx: 0,
max_len: 0,
vad: false, // Explicitly disable VAD
progress_callback: (progress) => {
console.log(`Traditional transcription progress: ${progress}%`);
}
};
async function runVADExample() {
try {
console.log("=== Whisper.cpp Node.js VAD Example ===\n");
// Check if VAD model exists
const fs = require('fs');
if (!fs.existsSync(vadParams.vad_model)) {
console.log("⚠️ VAD model not found. Please download the VAD model first:");
console.log(" ./models/download-vad-model.sh silero-v5.1.2");
console.log(" Or run: python models/convert-silero-vad-to-ggml.py");
console.log("\n Falling back to traditional transcription without VAD...\n");
// Run without VAD
console.log("🎵 Running traditional transcription...");
const traditionalResult = await whisperAsync(traditionalParams);
console.log("\n📝 Traditional transcription result:");
console.log(traditionalResult);
return;
}
console.log("🎵 Running transcription with VAD enabled...");
console.log("VAD Parameters:");
console.log(` - Threshold: ${vadParams.vad_threshold}`);
console.log(` - Min speech duration: ${vadParams.vad_min_speech_duration_ms}ms`);
console.log(` - Min silence duration: ${vadParams.vad_min_silence_duration_ms}ms`);
console.log(` - Max speech duration: ${vadParams.vad_max_speech_duration_s}s`);
console.log(` - Speech padding: ${vadParams.vad_speech_pad_ms}ms`);
console.log(` - Samples overlap: ${vadParams.vad_samples_overlap}\n`);
const startTime = Date.now();
const vadResult = await whisperAsync(vadParams);
const vadDuration = Date.now() - startTime;
console.log("\n✅ VAD transcription completed!");
console.log(`⏱️ Processing time: ${vadDuration}ms`);
console.log("\n📝 VAD transcription result:");
console.log(vadResult);
// Compare with traditional approach
console.log("\n🔄 Running traditional transcription for comparison...");
const traditionalStartTime = Date.now();
const traditionalResult = await whisperAsync(traditionalParams);
const traditionalDuration = Date.now() - traditionalStartTime;
console.log("\n✅ Traditional transcription completed!");
console.log(`⏱️ Processing time: ${traditionalDuration}ms`);
console.log("\n📝 Traditional transcription result:");
console.log(traditionalResult);
// Performance comparison
console.log("\n📊 Performance Comparison:");
console.log(`VAD: ${vadDuration}ms`);
console.log(`Traditional: ${traditionalDuration}ms`);
const speedup = traditionalDuration / vadDuration;
if (speedup > 1) {
console.log(`🚀 VAD is ${speedup.toFixed(2)}x faster!`);
} else {
console.log(` Traditional approach was ${(1/speedup).toFixed(2)}x faster in this case.`);
}
} catch (error) {
console.error("❌ Error during transcription:", error);
}
}
// Run the example
if (require.main === module) {
runVADExample();
}
module.exports = {
runVADExample,
vadParams,
traditionalParams
};

View File

@ -35,7 +35,7 @@ set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
    -s INITIAL_MEMORY=2000MB \
    -s TOTAL_MEMORY=2000MB \
    -s FORCE_FILESYSTEM=1 \
-    -s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap', 'HEAPU8']\" \
+    -s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \
    ${EXTRA_FLAGS} \
    ")

View File

@ -28,10 +28,5 @@ to the server's HTTP path:
```
# copy the produced page to your HTTP path
cp bin/bench.wasm/* /path/to/html/
cp bin/libbench.js /path/to/html/
cp bin/libbench.worker.js /path/to/html/
```
> 📝 **Note:** As of Emscripten 3.1.58 (April 2024), separate worker.js files are no
> longer generated and the worker is embedded in the main JS file. So the worker
> file will not be generated for versions later than `3.1.58`.

View File

@ -4,7 +4,7 @@ A very basic tool for benchmarking the inference performance on your device. The
the transformer on some random audio data and records the execution time. This way we can have an objective comparison
of the performance of the model for various setups.
-Benchmark results are tracked in the following Github issue: https://github.com/ggml-org/whisper.cpp/issues/89
+Benchmark results are tracked in the following Github issue: https://github.com/ggerganov/whisper.cpp/issues/89
```bash
# run the bench tool on the small.en model using 4 threads
@ -40,7 +40,7 @@ system_info: n_threads = 4 | AVX2 = 0 | AVX512 = 0 | NEON = 1 | FP16_VA = 1 | WA
If you wish, you can submit these results here:
-https://github.com/ggml-org/whisper.cpp/issues/89
+https://github.com/ggerganov/whisper.cpp/issues/89
Please include the following information:

View File

@ -66,12 +66,13 @@ static int whisper_bench_full(const whisper_params & params) {
    cparams.use_gpu    = params.use_gpu;
    cparams.flash_attn = params.flash_attn;
+    struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
    {
        fprintf(stderr, "\n");
        fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", params.n_threads, std::thread::hardware_concurrency(), whisper_print_system_info());
    }
-    struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
    if (ctx == nullptr) {
        fprintf(stderr, "error: failed to initialize whisper context\n");
        return 2;
@ -155,8 +156,6 @@ static int whisper_bench_full(const whisper_params & params) {
}
int main(int argc, char ** argv) {
-    ggml_backend_load_all();
    whisper_params params;
    if (whisper_params_parse(argc, argv, params) == false) {

View File

@ -6,8 +6,7 @@ It can be used as a reference for using the `whisper.cpp` library in other proje
```
./build/bin/whisper-cli -h
-usage: ./build/bin/whisper-cli [options] file0 file1 ...
+usage: ./build-pkg/bin/whisper-cli [options] file0.wav file1.wav ...
-supported audio formats: flac, mp3, ogg, wav
options:
-h, --help [default] show this help message and exit
@ -25,7 +24,6 @@ options:
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
-nth N, --no-speech-thold N [0.60 ] no speech threshold
-tp, --temperature N [0.00 ] The sampling temperature, between 0 and 1
-tpi, --temperature-inc N [0.20 ] The increment of temperature, between 0 and 1
-debug, --debug-mode [false ] enable debug mode (eg. dump log_mel)
@ -52,13 +50,12 @@ options:
-dl, --detect-language [false ] exit after automatically detecting language
--prompt PROMPT [ ] initial prompt (max n_text_ctx/2 tokens)
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
-f FNAME, --file FNAME [ ] input audio file path
-f FNAME, --file FNAME [ ] input WAV file path
-oved D, --ov-e-device DNAME [CPU ] the OpenVINO device used for encode inference
-dtw MODEL --dtw MODEL [ ] compute token-level timestamps
-ls, --log-score [false ] log best decoder scores of tokens
-ng, --no-gpu [false ] disable GPU
-fa, --flash-attn [false ] flash attention
-sns, --suppress-nst [false ] suppress non-speech tokens
--suppress-regex REGEX [ ] regular expression matching tokens to suppress
--grammar GRAMMAR [ ] GBNF grammar to guide decoding
--grammar-rule RULE [ ] top-level GBNF grammar rule name

View File

@ -11,7 +11,6 @@
#include <thread>
#include <vector>
#include <cstring>
#include <cfloat>
#if defined(_WIN32)
#ifndef NOMINMAX
@ -20,6 +19,10 @@
#include <windows.h>
#endif
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
// helper function to replace substrings
static void replace_all(std::string & s, const std::string & search, const std::string & replace) {
    for (size_t pos = 0; ; pos += replace.length()) {
@ -70,7 +73,6 @@ struct whisper_params {
bool no_prints = false;
bool print_special = false;
bool print_colors = false;
bool print_confidence= false;
bool print_progress = false;
bool no_timestamps = false;
bool log_score = false;
@ -99,16 +101,6 @@ struct whisper_params {
std::vector<std::string> fname_out = {};
grammar_parser::parse_state grammar_parsed;
// Voice Activity Detection (VAD) parameters
bool vad = false;
std::string vad_model = "";
float vad_threshold = 0.5f;
int vad_min_speech_duration_ms = 250;
int vad_min_silence_duration_ms = 100;
float vad_max_speech_duration_s = FLT_MAX;
int vad_speech_pad_ms = 30;
float vad_samples_overlap = 0.1f;
};
static void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@ -180,7 +172,6 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
else if (arg == "-np" || arg == "--no-prints") { params.no_prints = true; } else if (arg == "-np" || arg == "--no-prints") { params.no_prints = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; } else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
else if ( arg == "--print-confidence"){ params.print_confidence= true; }
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; } else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; } else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
else if (arg == "-l" || arg == "--language") { params.language = whisper_param_turn_lowercase(ARGV_NEXT); } else if (arg == "-l" || arg == "--language") { params.language = whisper_param_turn_lowercase(ARGV_NEXT); }
@ -198,15 +189,6 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
else if ( arg == "--grammar") { params.grammar = ARGV_NEXT; } else if ( arg == "--grammar") { params.grammar = ARGV_NEXT; }
else if ( arg == "--grammar-rule") { params.grammar_rule = ARGV_NEXT; } else if ( arg == "--grammar-rule") { params.grammar_rule = ARGV_NEXT; }
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(ARGV_NEXT); } else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(ARGV_NEXT); }
// Voice Activity Detection (VAD)
else if ( arg == "--vad") { params.vad = true; }
else if (arg == "-vm" || arg == "--vad-model") { params.vad_model = ARGV_NEXT; }
else if (arg == "-vt" || arg == "--vad-threshold") { params.vad_threshold = std::stof(ARGV_NEXT); }
else if (arg == "-vspd" || arg == "--vad-min-speech-duration-ms") { params.vad_min_speech_duration_ms = std::stoi(ARGV_NEXT); }
else if (arg == "-vsd" || arg == "--vad-min-silence-duration-ms") { params.vad_min_speech_duration_ms = std::stoi(ARGV_NEXT); }
else if (arg == "-vmsd" || arg == "--vad-max-speech-duration-s") { params.vad_max_speech_duration_s = std::stof(ARGV_NEXT); }
else if (arg == "-vp" || arg == "--vad-speech-pad-ms") { params.vad_speech_pad_ms = std::stoi(ARGV_NEXT); }
else if (arg == "-vo" || arg == "--vad-samples-overlap") { params.vad_samples_overlap = std::stof(ARGV_NEXT); }
else {
    fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
    whisper_print_usage(argc, argv, params);
@ -259,7 +241,6 @@ static void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params
fprintf(stderr, " -np, --no-prints [%-7s] do not print anything other than the results\n", params.no_prints ? "true" : "false"); fprintf(stderr, " -np, --no-prints [%-7s] do not print anything other than the results\n", params.no_prints ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false"); fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
fprintf(stderr, " --print-confidence [%-7s] print confidence\n", params.print_confidence ? "true" : "false");
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false"); fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false"); fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str()); fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
@ -277,18 +258,6 @@ static void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params
fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str()); fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
fprintf(stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str()); fprintf(stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str());
fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty); fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
// Voice Activity Detection (VAD) parameters
fprintf(stderr, "\nVoice Activity Detection (VAD) options:\n");
fprintf(stderr, " --vad [%-7s] enable Voice Activity Detection (VAD)\n", params.vad ? "true" : "false");
fprintf(stderr, " -vm FNAME, --vad-model FNAME [%-7s] VAD model path\n", params.vad_model.c_str());
fprintf(stderr, " -vt N, --vad-threshold N [%-7.2f] VAD threshold for speech recognition\n", params.vad_threshold);
fprintf(stderr, " -vspd N, --vad-min-speech-duration-ms N [%-7d] VAD min speech duration (0.0-1.0)\n", params.vad_min_speech_duration_ms);
fprintf(stderr, " -vsd N, --vad-min-silence-duration-ms N [%-7d] VAD min silence duration (to split segments)\n", params.vad_min_silence_duration_ms);
fprintf(stderr, " -vmsd N, --vad-max-speech-duration-s N [%-7s] VAD max speech duration (auto-split longer)\n", params.vad_max_speech_duration_s == FLT_MAX ?
std::string("FLT_MAX").c_str() :
std::to_string(params.vad_max_speech_duration_s).c_str());
fprintf(stderr, " -vp N, --vad-speech-pad-ms N [%-7d] VAD speech padding (extend segments)\n", params.vad_speech_pad_ms);
fprintf(stderr, " -vo N, --vad-samples-overlap N [%-7.2f] VAD samples overlap (seconds between segments)\n", params.vad_samples_overlap);
fprintf(stderr, "\n"); fprintf(stderr, "\n");
} }
@ -389,26 +358,6 @@ static void whisper_print_segment_callback(struct whisper_context * ctx, struct
printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m"); printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
} }
} else if (params.print_confidence) {
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
if (params.print_special == false) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
if (id >= whisper_token_eot(ctx)) {
continue;
}
}
const char * text = whisper_full_get_token_text(ctx, i, j);
const float p = whisper_full_get_token_p (ctx, i, j);
int style_idx = 2; // High confidence - dim
if (p < 0.33) {
style_idx = 0; // Low confidence - inverse (highlighted)
} else if (p < 0.66) {
style_idx = 1; // Medium confidence - underlined
}
printf("%s%s%s%s", speaker.c_str(), k_styles[style_idx].c_str(), text, "\033[0m");
}
} else { } else {
const char * text = whisper_full_get_segment_text(ctx, i); const char * text = whisper_full_get_segment_text(ctx, i);
@ -430,7 +379,15 @@ static void whisper_print_segment_callback(struct whisper_context * ctx, struct
} }
} }
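The confidence branch above maps each token probability to one of the three `k_styles` entries. A minimal sketch of that mapping, using the 0.33/0.66 thresholds from this diff (the helper name is ours; the example code above inlines this logic):

```
// Map a token probability to a k_styles index (thresholds from the diff).
static int confidence_style_index(float p) {
    if (p < 0.33f) { return 0; } // low confidence    -> inverse (highlighted)
    if (p < 0.66f) { return 1; } // medium confidence -> underlined
    return 2;                    // high confidence   -> dim
}
```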
static void output_txt(struct whisper_context * ctx, std::ofstream & fout, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
static bool output_txt(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
return false;
}
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
const int n_segments = whisper_full_n_segments(ctx); const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) { for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i); const char * text = whisper_full_get_segment_text(ctx, i);
@ -445,9 +402,19 @@ static void output_txt(struct whisper_context * ctx, std::ofstream & fout, const
fout << speaker << text << "\n"; fout << speaker << text << "\n";
} }
return true;
} }
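Each rewritten `output_*` helper now opens its own stream and reports failure instead of receiving an already-open `std::ofstream`. A sketch of that shared open/report/bail sequence, factored into a hypothetical helper that is not part of the diff:

```
#include <cstdio>
#include <fstream>

// Captures the pattern repeated in output_txt/vtt/srt/csv/lrc above.
static bool open_output(std::ofstream & fout, const char * fname, const char * func) {
    fout.open(fname);
    if (!fout.is_open()) {
        fprintf(stderr, "%s: failed to open '%s' for writing\n", func, fname);
        return false;
    }
    fprintf(stderr, "%s: saving output to '%s'\n", func, fname);
    return true;
}
```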
static void output_vtt(struct whisper_context * ctx, std::ofstream & fout, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
static bool output_vtt(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
return false;
}
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
fout << "WEBVTT\n\n"; fout << "WEBVTT\n\n";
const int n_segments = whisper_full_n_segments(ctx); const int n_segments = whisper_full_n_segments(ctx);
@ -467,9 +434,19 @@ static void output_vtt(struct whisper_context * ctx, std::ofstream & fout, const
fout << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n"; fout << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n";
fout << speaker << text << "\n\n"; fout << speaker << text << "\n\n";
} }
return true;
} }
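Both the VTT and SRT writers rely on `to_timestamp`. A sketch of what it plausibly does, assuming whisper.cpp segment times count 10 ms ticks and that `comma` selects the SRT millisecond separator:

```
#include <cstdint>
#include <cstdio>
#include <string>

// "HH:MM:SS.mmm" (VTT) or "HH:MM:SS,mmm" (SRT) from a 10 ms tick count.
static std::string to_timestamp(int64_t t, bool comma = false) {
    int64_t msec = t * 10;
    const int64_t hr  = msec / (1000 * 60 * 60); msec -= hr  * (1000 * 60 * 60);
    const int64_t min = msec / (1000 * 60);      msec -= min * (1000 * 60);
    const int64_t sec = msec / 1000;             msec -= sec * 1000;

    char buf[32];
    snprintf(buf, sizeof(buf), "%02d:%02d:%02d%s%03d",
             (int) hr, (int) min, (int) sec, comma ? "," : ".", (int) msec);
    return std::string(buf);
}
```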
static void output_srt(struct whisper_context * ctx, std::ofstream & fout, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
static bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
return false;
}
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
const int n_segments = whisper_full_n_segments(ctx); const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) { for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i); const char * text = whisper_full_get_segment_text(ctx, i);
@ -486,6 +463,8 @@ static void output_srt(struct whisper_context * ctx, std::ofstream & fout, const
fout << to_timestamp(t0, true) << " --> " << to_timestamp(t1, true) << "\n"; fout << to_timestamp(t0, true) << " --> " << to_timestamp(t1, true) << "\n";
fout << speaker << text << "\n\n"; fout << speaker << text << "\n\n";
} }
return true;
} }
static char * escape_double_quotes_and_backslashes(const char * str) { static char * escape_double_quotes_and_backslashes(const char * str) {
@ -551,7 +530,15 @@ static char * escape_double_quotes_in_csv(const char * str) {
return escaped; return escaped;
} }
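`escape_double_quotes_in_csv` follows the CSV convention of doubling embedded quotes, as opposed to the backslash escaping used for the JSON and script outputs. A `std::string` sketch of the same idea (the function above works on C strings):

```
#include <string>

// CSV escaping: an embedded double quote becomes "" inside a quoted field.
static std::string escape_csv(const std::string & s) {
    std::string out;
    out.reserve(s.size());
    for (const char c : s) {
        if (c == '"') {
            out += '"'; // double it
        }
        out += c;
    }
    return out;
}
```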
static void output_csv(struct whisper_context * ctx, std::ofstream & fout, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
static bool output_csv(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
return false;
}
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
const int n_segments = whisper_full_n_segments(ctx); const int n_segments = whisper_full_n_segments(ctx);
fout << "start,end,"; fout << "start,end,";
if (params.diarize && pcmf32s.size() == 2) if (params.diarize && pcmf32s.size() == 2)
@ -574,9 +561,14 @@ static void output_csv(struct whisper_context * ctx, std::ofstream & fout, const
} }
fout << "\"" << text_escaped << "\"\n"; fout << "\"" << text_escaped << "\"\n";
} }
return true;
} }
static void output_score(struct whisper_context * ctx, std::ofstream & fout, const whisper_params & /*params*/, std::vector<std::vector<float>> /*pcmf32s*/) {
static bool output_score(struct whisper_context * ctx, const char * fname, const whisper_params & /*params*/, std::vector<std::vector<float>> /*pcmf32s*/) {
std::ofstream fout(fname);
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
const int n_segments = whisper_full_n_segments(ctx); const int n_segments = whisper_full_n_segments(ctx);
// fprintf(stderr,"segments: %d\n",n_segments); // fprintf(stderr,"segments: %d\n",n_segments);
for (int i = 0; i < n_segments; ++i) { for (int i = 0; i < n_segments; ++i) {
@ -589,14 +581,16 @@ static void output_score(struct whisper_context * ctx, std::ofstream & fout, con
// fprintf(stderr,"token: %s %f\n",token,probability); // fprintf(stderr,"token: %s %f\n",token,probability);
} }
} }
return true;
} }
static void output_json(
        struct whisper_context * ctx,
        std::ofstream & fout,
        const whisper_params & params,
        std::vector<std::vector<float>> pcmf32s) {
    const bool full = params.output_jsn_full;
static bool output_json(
        struct whisper_context * ctx,
        const char * fname,
        const whisper_params & params,
        std::vector<std::vector<float>> pcmf32s,
        bool full) {
    std::ofstream fout(fname);
int indent = 0; int indent = 0;
auto doindent = [&]() { auto doindent = [&]() {
@ -676,6 +670,12 @@ static void output_json(
end_obj(end); end_obj(end);
}; };
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
return false;
}
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
start_obj(nullptr); start_obj(nullptr);
value_s("systeminfo", whisper_print_system_info(), false); value_s("systeminfo", whisper_print_system_info(), false);
start_obj("model"); start_obj("model");
@ -749,12 +749,17 @@ static void output_json(
end_arr(true); end_arr(true);
end_obj(true); end_obj(true);
return true;
} }
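`output_json` builds the document by hand with small lambdas; `doindent`, `start_obj` and `end_obj` are called in the diff but their bodies are not shown in these hunks. A guess at the indentation bookkeeping, consistent with the `start_obj(nullptr)`/`end_obj(true)` calls above:

```
#include <fstream>

// Sketch only: lambda bodies are assumptions, not the diff's code.
static void write_json_skeleton(std::ofstream & fout) {
    int indent = 0;
    auto doindent = [&]() { for (int i = 0; i < indent; i++) { fout << "\t"; } };
    auto start_obj = [&](const char * name) {
        doindent();
        if (name) { fout << "\"" << name << "\": {\n"; } else { fout << "{\n"; }
        indent++;
    };
    auto end_obj = [&](bool end) {
        indent--;
        doindent();
        fout << (end ? "}\n" : "},\n");
    };

    start_obj(nullptr); // root object
    start_obj("model"); // nested object, as in the diff
    end_obj(true);      // last member: no trailing comma
    end_obj(true);      // close root
}
```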
// karaoke video generation // karaoke video generation
// outputs a bash script that uses ffmpeg to generate a video with the subtitles // outputs a bash script that uses ffmpeg to generate a video with the subtitles
// TODO: font parameter adjustments // TODO: font parameter adjustments
static bool output_wts(struct whisper_context * ctx, std::ofstream & fout, const whisper_params & params, std::vector<std::vector<float>> pcmf32s, const char * fname_inp, float t_sec, const char * fname_out) {
static bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec, std::vector<std::vector<float>> pcmf32s) {
std::ofstream fout(fname);
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
static const char * font = params.font_path.c_str(); static const char * font = params.font_path.c_str();
std::ifstream fin(font); std::ifstream fin(font);
@ -870,12 +875,20 @@ static bool output_wts(struct whisper_context * ctx, std::ofstream & fout, const
fout.close(); fout.close();
fprintf(stderr, "# %s: run 'source %s' to generate karaoke video\n", __func__, fname_out); fprintf(stderr, "%s: run 'source %s' to generate karaoke video\n", __func__, fname);
return true; return true;
} }
static void output_lrc(struct whisper_context * ctx, std::ofstream & fout, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
static bool output_lrc(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
return false;
}
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
fout << "[by:whisper.cpp]\n"; fout << "[by:whisper.cpp]\n";
const int n_segments = whisper_full_n_segments(ctx); const int n_segments = whisper_full_n_segments(ctx);
@ -903,14 +916,14 @@ static void output_lrc(struct whisper_context * ctx, std::ofstream & fout, const
fout << '[' << timestamp_lrc << ']' << speaker << text << "\n"; fout << '[' << timestamp_lrc << ']' << speaker << text << "\n";
} }
return true;
} }
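The `[mm:ss.xx]` stamps written above come from a `timestamp_lrc` string. A sketch of the conversion, again assuming 10 ms segment ticks; the helper name is ours:

```
#include <cstdint>
#include <cstdio>
#include <string>

// "mm:ss.xx" LRC stamp from a 10 ms tick count; xx is hundredths of a second.
static std::string to_timestamp_lrc(int64_t t) {
    int64_t msec = t * 10;
    const int64_t min = msec / (1000 * 60); msec -= min * (1000 * 60);
    const int64_t sec = msec / 1000;        msec -= sec * 1000;

    char buf[16];
    snprintf(buf, sizeof(buf), "%02d:%02d.%02d", (int) min, (int) sec, (int) (msec / 10));
    return std::string(buf);
}
```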
static void cb_log_disable(enum ggml_log_level , const char * , void * ) { } static void cb_log_disable(enum ggml_log_level , const char * , void * ) { }
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
ggml_backend_load_all();
#if defined(_WIN32) #if defined(_WIN32)
// Set the console output code page to UTF-8, while command line arguments // Set the console output code page to UTF-8, while command line arguments
// are still encoded in the system's code page. In this way, we can print // are still encoded in the system's code page. In this way, we can print
@ -990,6 +1003,7 @@ int main(int argc, char ** argv) {
} }
// whisper init // whisper init
struct whisper_context_params cparams = whisper_context_default_params(); struct whisper_context_params cparams = whisper_context_default_params();
cparams.use_gpu = params.use_gpu; cparams.use_gpu = params.use_gpu;
@ -1052,55 +1066,8 @@ int main(int argc, char ** argv) {
} }
for (int f = 0; f < (int) params.fname_inp.size(); ++f) { for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
const auto & fname_inp = params.fname_inp[f]; const auto fname_inp = params.fname_inp[f];
struct fout_factory { const auto fname_out = f < (int) params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
std::string fname_out;
const size_t basename_length;
const bool is_stdout;
bool used_stdout;
decltype(whisper_print_segment_callback) * const print_segment_callback;
std::ofstream fout;
fout_factory (const std::string & fname_out_, const std::string & fname_inp, whisper_params & params) :
fname_out{!fname_out_.empty() ? fname_out_ : fname_inp},
basename_length{fname_out.size()},
is_stdout{fname_out == "-"},
used_stdout{},
print_segment_callback{is_stdout ? nullptr : whisper_print_segment_callback} {
if (!print_segment_callback) {
params.print_progress = false;
}
}
bool open(const char * ext, const char * function) {
if (is_stdout) {
if (used_stdout) {
fprintf(stderr, "warning: Not appending multiple file formats to stdout\n");
return false;
}
used_stdout = true;
#ifdef _WIN32
fout = std::ofstream{"CON"};
#else
fout = std::ofstream{"/dev/stdout"};
#endif
// Not using fprintf stderr here because it might equal stdout
// Also assuming /dev is mounted
return true;
}
fname_out.resize(basename_length);
fname_out += ext;
fout = std::ofstream{fname_out};
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_out.c_str());
return false;
}
fprintf(stderr, "%s: saving output to '%s'\n", function, fname_out.c_str());
return true;
}
} fout_factory{f < (int) params.fname_out.size() ? params.fname_out[f] : "", fname_inp, params};
std::vector<float> pcmf32; // mono-channel F32 PCM std::vector<float> pcmf32; // mono-channel F32 PCM
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
@ -1137,11 +1104,6 @@ int main(int argc, char ** argv) {
params.tinydiarize ? "tdrz = 1, " : "", params.tinydiarize ? "tdrz = 1, " : "",
params.no_timestamps ? 0 : 1); params.no_timestamps ? 0 : 1);
if (params.print_colors) {
fprintf(stderr, "%s: color scheme: red (low confidence), yellow (medium), green (high confidence)\n", __func__);
} else if (params.print_confidence) {
fprintf(stderr, "%s: confidence: highlighted (low confidence), underlined (medium), dim (high confidence)\n", __func__);
}
fprintf(stderr, "\n"); fprintf(stderr, "\n");
} }
@ -1192,16 +1154,6 @@ int main(int argc, char ** argv) {
wparams.suppress_nst = params.suppress_nst; wparams.suppress_nst = params.suppress_nst;
wparams.vad = params.vad;
wparams.vad_model_path = params.vad_model.c_str();
wparams.vad_params.threshold = params.vad_threshold;
wparams.vad_params.min_speech_duration_ms = params.vad_min_speech_duration_ms;
wparams.vad_params.min_silence_duration_ms = params.vad_min_silence_duration_ms;
wparams.vad_params.max_speech_duration_s = params.vad_max_speech_duration_s;
wparams.vad_params.speech_pad_ms = params.vad_speech_pad_ms;
wparams.vad_params.samples_overlap = params.vad_samples_overlap;
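The block above copies the CLI's VAD settings into `whisper_full_params`. A minimal sketch of enabling VAD programmatically with the same fields (the model path is a placeholder, and the values mirror the defaults printed by the help text in this diff):

```
struct whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

wparams.vad            = true;
wparams.vad_model_path = "models/ggml-vad.bin"; // placeholder path
wparams.vad_params.threshold               = 0.5f;
wparams.vad_params.min_speech_duration_ms  = 250;
wparams.vad_params.min_silence_duration_ms = 100;
wparams.vad_params.speech_pad_ms           = 30;
wparams.vad_params.samples_overlap         = 0.1f;
```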
whisper_print_user_data user_data = { &params, &pcmf32s, 0 }; whisper_print_user_data user_data = { &params, &pcmf32s, 0 };
const auto & grammar_parsed = params.grammar_parsed; const auto & grammar_parsed = params.grammar_parsed;
@ -1220,7 +1172,7 @@ int main(int argc, char ** argv) {
// this callback is called on each new segment // this callback is called on each new segment
if (!wparams.print_realtime) { if (!wparams.print_realtime) {
wparams.new_segment_callback = fout_factory.print_segment_callback;
wparams.new_segment_callback = whisper_print_segment_callback;
wparams.new_segment_callback_user_data = &user_data; wparams.new_segment_callback_user_data = &user_data;
} }
@ -1262,26 +1214,54 @@ int main(int argc, char ** argv) {
// output stuff // output stuff
{ {
// macros to stringify function name
#define output_func(func, ext, param, ...) if (param && fout_factory.open(ext, #func)) {\
    func(ctx, fout_factory.fout, params, __VA_ARGS__); \
}
#define output_ext(ext, ...) output_func(output_##ext, "." #ext, params.output_##ext, __VA_ARGS__)

output_ext(txt, pcmf32s);
output_ext(vtt, pcmf32s);
output_ext(srt, pcmf32s);
output_ext(wts, pcmf32s, fname_inp.c_str(), float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE, fout_factory.fname_out.c_str());
output_ext(csv, pcmf32s);
output_func(output_json, ".json", params.output_jsn, pcmf32s);
output_ext(lrc, pcmf32s);
output_func(output_score, ".score.txt", params.log_score, pcmf32s);
#undef output_ext
#undef output_func

if (fout_factory.is_stdout && !fout_factory.used_stdout) {
    fprintf(stderr, "warning: '--output-file -' used without any other '--output-*'");
}

printf("\n");

// output to text file
if (params.output_txt) {
    const auto fname_txt = fname_out + ".txt";
    output_txt(ctx, fname_txt.c_str(), params, pcmf32s);
}

// output to VTT file
if (params.output_vtt) {
    const auto fname_vtt = fname_out + ".vtt";
    output_vtt(ctx, fname_vtt.c_str(), params, pcmf32s);
}

// output to SRT file
if (params.output_srt) {
    const auto fname_srt = fname_out + ".srt";
    output_srt(ctx, fname_srt.c_str(), params, pcmf32s);
}

// output to WTS file
if (params.output_wts) {
    const auto fname_wts = fname_out + ".wts";
    output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE, pcmf32s);
}

// output to CSV file
if (params.output_csv) {
    const auto fname_csv = fname_out + ".csv";
    output_csv(ctx, fname_csv.c_str(), params, pcmf32s);
}

// output to JSON file
if (params.output_jsn) {
    const auto fname_jsn = fname_out + ".json";
    output_json(ctx, fname_jsn.c_str(), params, pcmf32s, params.output_jsn_full);
}

// output to LRC file
if (params.output_lrc) {
    const auto fname_lrc = fname_out + ".lrc";
    output_lrc(ctx, fname_lrc.c_str(), params, pcmf32s);
}

// output to score file
if (params.log_score) {
    const auto fname_score = fname_out + ".score.txt";
    output_score(ctx, fname_score.c_str(), params, pcmf32s);
}
} }
} }
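For reference, one expansion of the macro-based variant above: `output_ext(vtt, pcmf32s)` stringifies the function name for logging and only calls the writer when the flag is set and the stream opens. Roughly:

```
// output_ext(vtt, pcmf32s) expands (via output_func) to:
if (params.output_vtt && fout_factory.open(".vtt", "output_vtt")) {
    output_vtt(ctx, fout_factory.fout, params, pcmf32s);
}
```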

View File

@ -36,7 +36,7 @@ set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
-s INITIAL_MEMORY=1024MB \ -s INITIAL_MEMORY=1024MB \
-s TOTAL_MEMORY=1024MB \ -s TOTAL_MEMORY=1024MB \
-s FORCE_FILESYSTEM=1 \ -s FORCE_FILESYSTEM=1 \
-s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap', 'HEAPU8']\" \
-s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \
${EXTRA_FLAGS} \ ${EXTRA_FLAGS} \
") ")

View File

@ -28,10 +28,5 @@ To run the example in a different server, you need to copy the following files
to the server's HTTP path: to the server's HTTP path:
``` ```
cp bin/command.wasm/* /path/to/html/ cp bin/command.wasm/* /path/to/html/
cp bin/libcommand.js /path/to/html/
cp bin/libcommand.worker.js /path/to/html/ cp bin/libcommand.worker.js /path/to/html/
``` ```
> 📝 **Note:** As of Emscripten 3.1.58 (April 2024), separate worker.js files are no
> longer generated and the worker is embedded in the main JS file. So the worker
> file will not be geneated for versions later than `3.1.58`.

View File

@ -3,7 +3,7 @@
// Speak short text commands to the microphone. // Speak short text commands to the microphone.
// This program will detect your voice command and convert them to text. // This program will detect your voice command and convert them to text.
// //
// ref: https://github.com/ggml-org/whisper.cpp/issues/171
// ref: https://github.com/ggerganov/whisper.cpp/issues/171
// //
#include "common-sdl.h" #include "common-sdl.h"
@ -251,7 +251,7 @@ static std::vector<std::string> get_words(const std::string &txt) {
// command-list mode // command-list mode
// guide the transcription to match the most likely command from a provided list // guide the transcription to match the most likely command from a provided list
static int process_command_list(struct whisper_context * ctx, audio_async &audio, const whisper_params &params, std::ofstream &fout) {
static int process_command_list(struct whisper_context * ctx, audio_async &audio, const whisper_params &params) {
fprintf(stderr, "\n"); fprintf(stderr, "\n");
fprintf(stderr, "%s: guided mode\n", __func__); fprintf(stderr, "%s: guided mode\n", __func__);
@ -444,16 +444,12 @@ static int process_command_list(struct whisper_context * ctx, audio_async &audio
const float prob = probs_id[0].first; const float prob = probs_id[0].first;
const int index = probs_id[0].second; const int index = probs_id[0].second;
const char * best_command = allowed_commands[index].c_str();
fprintf(stdout, "\n"); fprintf(stdout, "\n");
fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__, fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__,
"\033[1m", best_command, "\033[0m", prob, "\033[1m", allowed_commands[index].c_str(), "\033[0m", prob,
(int) std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count()); (int) std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count());
fprintf(stdout, "\n"); fprintf(stdout, "\n");
if (fout.is_open()) {
fout << best_command << std::endl;
}
} }
} }
@ -466,7 +462,7 @@ static int process_command_list(struct whisper_context * ctx, audio_async &audio
// always-prompt mode // always-prompt mode
// transcribe the voice into text after valid prompt // transcribe the voice into text after valid prompt
static int always_prompt_transcription(struct whisper_context * ctx, audio_async & audio, const whisper_params & params, std::ofstream & fout) {
static int always_prompt_transcription(struct whisper_context * ctx, audio_async & audio, const whisper_params & params) {
bool is_running = true; bool is_running = true;
bool ask_prompt = true; bool ask_prompt = true;
@ -532,9 +528,6 @@ static int always_prompt_transcription(struct whisper_context * ctx, audio_async
if ((sim > 0.7f) && (command.size() > 0)) { if ((sim > 0.7f) && (command.size() > 0)) {
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms); fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
if (fout.is_open()) {
fout << command << std::endl;
}
} }
fprintf(stdout, "\n"); fprintf(stdout, "\n");
@ -549,7 +542,7 @@ static int always_prompt_transcription(struct whisper_context * ctx, audio_async
// general-purpose mode // general-purpose mode
// freely transcribe the voice into text // freely transcribe the voice into text
static int process_general_transcription(struct whisper_context * ctx, audio_async & audio, const whisper_params & params, std::ofstream & fout) {
static int process_general_transcription(struct whisper_context * ctx, audio_async & audio, const whisper_params & params) {
bool is_running = true; bool is_running = true;
bool have_prompt = false; bool have_prompt = false;
bool ask_prompt = true; bool ask_prompt = true;
@ -669,10 +662,8 @@ static int process_general_transcription(struct whisper_context * ctx, audio_asy
} else { } else {
// cut the prompt from the decoded text // cut the prompt from the decoded text
const std::string command = ::trim(txt.substr(best_len)); const std::string command = ::trim(txt.substr(best_len));
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms); fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
if (fout.is_open()) {
fout << command << std::endl;
}
} }
fprintf(stdout, "\n"); fprintf(stdout, "\n");
@ -687,8 +678,6 @@ static int process_general_transcription(struct whisper_context * ctx, audio_asy
} }
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
ggml_backend_load_all();
whisper_params params; whisper_params params;
if (whisper_params_parse(argc, argv, params) == false) { if (whisper_params_parse(argc, argv, params) == false) {
@ -709,10 +698,6 @@ int main(int argc, char ** argv) {
cparams.flash_attn = params.flash_attn; cparams.flash_attn = params.flash_attn;
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams); struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
if (ctx == nullptr) {
fprintf(stderr, "error: failed to initialize whisper context\n");
return 2;
}
// print some info about the processing // print some info about the processing
{ {
@ -772,22 +757,13 @@ int main(int argc, char ** argv) {
} }
} }
std::ofstream fout;
if (params.fname_out.length() > 0) {
fout.open(params.fname_out);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open output file '%s'!\n", __func__, params.fname_out.c_str());
return 1;
}
}
if (ret_val == 0) { if (ret_val == 0) {
if (!params.commands.empty()) { if (!params.commands.empty()) {
ret_val = process_command_list(ctx, audio, params, fout);
ret_val = process_command_list(ctx, audio, params);
} else if (!params.prompt.empty() && params.grammar_parsed.rules.empty()) { } else if (!params.prompt.empty() && params.grammar_parsed.rules.empty()) {
ret_val = always_prompt_transcription(ctx, audio, params, fout);
ret_val = always_prompt_transcription(ctx, audio, params);
} else { } else {
ret_val = process_general_transcription(ctx, audio, params, fout);
ret_val = process_general_transcription(ctx, audio, params);
} }
} }

View File

@ -26,6 +26,10 @@
#define MINIAUDIO_IMPLEMENTATION #define MINIAUDIO_IMPLEMENTATION
#include "miniaudio.h" #include "miniaudio.h"
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
#ifdef _WIN32 #ifdef _WIN32
#include <fcntl.h> #include <fcntl.h>
#include <io.h> #include <io.h>
@ -112,19 +116,12 @@ bool read_audio_data(const std::string & fname, std::vector<float>& pcmf32, std:
} }
if (stereo) { if (stereo) {
std::vector<float> stereo_data = pcmf32;
pcmf32.resize(frame_count);
for (uint64_t i = 0; i < frame_count; i++) {
pcmf32[i] = (stereo_data[2*i] + stereo_data[2*i + 1]);
}
pcmf32s.resize(2); pcmf32s.resize(2);
pcmf32s[0].resize(frame_count); pcmf32s[0].resize(frame_count);
pcmf32s[1].resize(frame_count); pcmf32s[1].resize(frame_count);
for (uint64_t i = 0; i < frame_count; i++) { for (uint64_t i = 0; i < frame_count; i++) {
pcmf32s[0][i] = stereo_data[2*i];
pcmf32s[0][i] = pcmf32[2*i];
pcmf32s[1][i] = stereo_data[2*i + 1];
pcmf32s[1][i] = pcmf32[2*i + 1];
} }
} }
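The added lines above also downmix the interleaved stereo buffer into the mono `pcmf32` channel before splitting it. A self-contained sketch of that step, exactly as the diff writes it (note the channels are summed; many decoders average instead, i.e. multiply by 0.5f):

```
#include <cstdint>
#include <vector>

// Mono downmix of an interleaved L/R buffer, as in the diff.
static void downmix_to_mono(std::vector<float> & pcmf32, uint64_t frame_count) {
    std::vector<float> stereo_data = pcmf32; // interleaved L/R samples
    pcmf32.resize(frame_count);
    for (uint64_t i = 0; i < frame_count; i++) {
        pcmf32[i] = stereo_data[2*i] + stereo_data[2*i + 1];
    }
}
```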

View File

@ -10,6 +10,10 @@
#include <regex> #include <regex>
#include <sstream> #include <sstream>
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
// Function to check if the next argument exists // Function to check if the next argument exists
static std::string get_next_arg(int& i, int argc, char** argv, const std::string& flag, gpt_params& params) { static std::string get_next_arg(int& i, int argc, char** argv, const std::string& flag, gpt_params& params) {
if (i + 1 < argc && argv[i + 1][0] != '-') { if (i + 1 < argc && argv[i + 1][0] != '-') {

View File

@ -283,7 +283,7 @@ static std::string set_xterm256_foreground(int r, int g, int b) {
} }
// Lowest is red, middle is yellow, highest is green. Color scheme from // Lowest is red, middle is yellow, highest is green. Color scheme from
// Paul Tol; it is colorblind friendly https://sronpersonalpages.nl/~pault
// Paul Tol; it is colorblind friendly https://personal.sron.nl/~pault/
const std::vector<std::string> k_colors = { const std::vector<std::string> k_colors = {
set_xterm256_foreground(220, 5, 12), set_xterm256_foreground(220, 5, 12),
set_xterm256_foreground(232, 96, 28), set_xterm256_foreground(232, 96, 28),
@ -294,26 +294,6 @@ const std::vector<std::string> k_colors = {
set_xterm256_foreground( 78, 178, 101), set_xterm256_foreground( 78, 178, 101),
}; };
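`set_xterm256_foreground` (signature shown in the hunk header above) presumably assembles an ANSI foreground escape from the RGB triple. One plausible body, using 24-bit color codes; this is an assumption, not the diff's implementation:

```
#include <string>

// Emit a 24-bit ANSI foreground escape: ESC[38;2;R;G;Bm
static std::string set_xterm256_foreground(int r, int g, int b) {
    return "\033[38;2;" + std::to_string(r) + ";" +
           std::to_string(g) + ";" + std::to_string(b) + "m";
}
```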
// ANSI formatting codes
static std::string set_inverse() {
return "\033[7m";
}
static std::string set_underline() {
return "\033[4m";
}
static std::string set_dim() {
return "\033[2m";
}
// Style scheme for different confidence levels
const std::vector<std::string> k_styles = {
set_inverse(), // Low confidence - inverse (highlighted)
set_underline(), // Medium confidence - underlined
set_dim(), // High confidence - dim
};
// //
// Other utils // Other utils
// //

View File

@ -194,7 +194,7 @@ static int decode_audio(struct audio_buffer *audio_buf, s16 **data, int *size)
AVIOContext *avio_ctx; AVIOContext *avio_ctx;
AVStream *stream; AVStream *stream;
AVCodecContext *codec; AVCodecContext *codec;
AVPacket *packet;
AVPacket packet;
AVFrame *frame; AVFrame *frame;
struct SwrContext *swr; struct SwrContext *swr;
u8 *avio_ctx_buffer; u8 *avio_ctx_buffer;
@ -249,20 +249,6 @@ static int decode_audio(struct audio_buffer *audio_buf, s16 **data, int *size)
/* prepare resampler */ /* prepare resampler */
swr = swr_alloc(); swr = swr_alloc();
#if LIBAVCODEC_VERSION_MAJOR > 60
AVChannelLayout in_ch_layout = codec->ch_layout;
AVChannelLayout out_ch_layout = AV_CHANNEL_LAYOUT_MONO;
/* Set the source audio layout as-is */
av_opt_set_chlayout(swr, "in_chlayout", &in_ch_layout, 0);
av_opt_set_int(swr, "in_sample_rate", codec->sample_rate, 0);
av_opt_set_sample_fmt(swr, "in_sample_fmt", codec->sample_fmt, 0);
/* Convert it into 16khz Mono */
av_opt_set_chlayout(swr, "out_chlayout", &out_ch_layout, 0);
av_opt_set_int(swr, "out_sample_rate", WAVE_SAMPLE_RATE, 0);
av_opt_set_sample_fmt(swr, "out_sample_fmt", AV_SAMPLE_FMT_S16, 0);
#else
av_opt_set_int(swr, "in_channel_count", codec->channels, 0); av_opt_set_int(swr, "in_channel_count", codec->channels, 0);
av_opt_set_int(swr, "out_channel_count", 1, 0); av_opt_set_int(swr, "out_channel_count", 1, 0);
av_opt_set_int(swr, "in_channel_layout", codec->channel_layout, 0); av_opt_set_int(swr, "in_channel_layout", codec->channel_layout, 0);
@ -271,7 +257,6 @@ static int decode_audio(struct audio_buffer *audio_buf, s16 **data, int *size)
av_opt_set_int(swr, "out_sample_rate", WAVE_SAMPLE_RATE, 0); av_opt_set_int(swr, "out_sample_rate", WAVE_SAMPLE_RATE, 0);
av_opt_set_sample_fmt(swr, "in_sample_fmt", codec->sample_fmt, 0); av_opt_set_sample_fmt(swr, "in_sample_fmt", codec->sample_fmt, 0);
av_opt_set_sample_fmt(swr, "out_sample_fmt", AV_SAMPLE_FMT_S16, 0); av_opt_set_sample_fmt(swr, "out_sample_fmt", AV_SAMPLE_FMT_S16, 0);
#endif
swr_init(swr); swr_init(swr);
if (!swr_is_initialized(swr)) { if (!swr_is_initialized(swr)) {
@ -279,11 +264,7 @@ static int decode_audio(struct audio_buffer *audio_buf, s16 **data, int *size)
return -1; return -1;
} }
packet=av_packet_alloc();
av_init_packet(&packet);
if (!packet) {
LOG("Error allocating the packet\n");
return -1;
}
frame = av_frame_alloc(); frame = av_frame_alloc();
if (!frame) { if (!frame) {
LOG("Error allocating the frame\n"); LOG("Error allocating the frame\n");
@ -293,8 +274,8 @@ static int decode_audio(struct audio_buffer *audio_buf, s16 **data, int *size)
/* iterate through frames */ /* iterate through frames */
*data = NULL; *data = NULL;
*size = 0; *size = 0;
while (av_read_frame(fmt_ctx, packet) >= 0) {
while (av_read_frame(fmt_ctx, &packet) >= 0) {
avcodec_send_packet(codec, packet);
avcodec_send_packet(codec, &packet);
err = avcodec_receive_frame(codec, frame); err = avcodec_receive_frame(codec, frame);
if (err == AVERROR(EAGAIN)) if (err == AVERROR(EAGAIN))
@ -305,11 +286,10 @@ static int decode_audio(struct audio_buffer *audio_buf, s16 **data, int *size)
/* Flush any remaining conversion buffers... */ /* Flush any remaining conversion buffers... */
convert_frame(swr, codec, frame, data, size, true); convert_frame(swr, codec, frame, data, size, true);
av_packet_free(&packet);
av_frame_free(&frame); av_frame_free(&frame);
swr_free(&swr); swr_free(&swr);
//avio_context_free(); // todo? //avio_context_free(); // todo?
avcodec_free_context(&codec); avcodec_close(codec);
avformat_close_input(&fmt_ctx); avformat_close_input(&fmt_ctx);
avformat_free_context(fmt_ctx); avformat_free_context(fmt_ctx);
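The packet changes in this file track FFmpeg's API: `av_init_packet` on a stack `AVPacket` is deprecated in favor of heap allocation with `av_packet_alloc`/`av_packet_free`. A condensed sketch of the read loop with the newer lifecycle; `fmt_ctx`, `codec` and `LOG` come from the surrounding function, and the `av_packet_unref` between reads is our addition, not part of the diff:

```
AVPacket * packet = av_packet_alloc();
if (!packet) {
    LOG("Error allocating the packet\n");
    return -1;
}
while (av_read_frame(fmt_ctx, packet) >= 0) {
    avcodec_send_packet(codec, packet);
    // ... avcodec_receive_frame() / convert_frame() as above ...
    av_packet_unref(packet); // release the payload before the next read
}
av_packet_free(&packet);
```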

View File

@ -2,7 +2,7 @@
# #
# Transcribe audio livestream by feeding ffmpeg output to whisper.cpp at regular intervals # Transcribe audio livestream by feeding ffmpeg output to whisper.cpp at regular intervals
# Idea by @semiformal-net # Idea by @semiformal-net
# ref: https://github.com/ggml-org/whisper.cpp/issues/185
# ref: https://github.com/ggerganov/whisper.cpp/issues/185
# #
set -eo pipefail set -eo pipefail

View File

@ -424,8 +424,6 @@ static void process_loop(struct whisper_context * ctx, audio_async &audio, const
} }
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
ggml_backend_load_all();
whisper_params params; whisper_params params;
if (whisper_params_parse(argc, argv, params) == false) { if (whisper_params_parse(argc, argv, params) == false) {
return 1; return 1;

View File

@ -1,5 +1,4 @@
#include "ggml.h" #include "ggml.h"
#include "ggml-backend.h"
#include "common.h" #include "common.h"
#include "common-ggml.h" #include "common-ggml.h"
@ -177,8 +176,6 @@ static bool whisper_model_quantize(const std::string & fname_inp, const std::str
} }
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
ggml_backend_load_all();
if (argc != 4) { if (argc != 4) {
fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type\n", argv[0]); fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type\n", argv[0]);
ggml_print_ftypes(stderr); ggml_print_ftypes(stderr);

View File

@ -1,115 +1,39 @@
import http.server import http.server
import socketserver import socketserver
import os import os
import sys
from pathlib import Path from pathlib import Path
import urllib.parse
SCRIPT_DIR = Path(__file__).parent.absolute() SCRIPT_DIR = Path(__file__).parent.absolute()
DIRECTORY = os.path.join(SCRIPT_DIR, "../build-em/bin") DIRECTORY = os.path.join(SCRIPT_DIR, "../build-em/bin")
DIRECTORY = os.path.abspath(DIRECTORY) DIRECTORY = os.path.abspath(DIRECTORY)
# The context root we want for all applications
CONTEXT_ROOT = "/whisper.cpp"
class CustomHTTPRequestHandler(http.server.SimpleHTTPRequestHandler): class CustomHTTPRequestHandler(http.server.SimpleHTTPRequestHandler):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, directory=DIRECTORY, **kwargs) super().__init__(*args, directory=DIRECTORY, **kwargs)
def do_GET(self): def do_GET(self):
# Redirect root to the context root
# If requesting a worker file from any subdirectory
if self.path == '/':
if '.worker.js' in self.path:
self.send_response(302)
self.send_header('Location', CONTEXT_ROOT + '/')
self.end_headers()
return
# Handle requests under the context root
if self.path.startswith(CONTEXT_ROOT):
# Remove the context root prefix to get the actual path
actual_path = self.path[len(CONTEXT_ROOT):]
if not actual_path:
self.send_response(302)
self.send_header('Location', CONTEXT_ROOT + '/')
self.end_headers()
return
if '.worker.js' in actual_path:
worker_file = os.path.basename(actual_path)
worker_path = os.path.join(DIRECTORY, worker_file)
if os.path.exists(worker_path):
print(f"Found worker file: {worker_path}")
self.path = '/' + worker_file
else:
print(f"Worker file not found: {worker_path}")
elif actual_path == '/':
self.path = '/whisper.wasm/index.html'
elif actual_path.startswith('/bench.wasm/') or actual_path.startswith('/command.wasm/') or actual_path.startswith('/stream.wasm/'):
# Keep the path as is, just remove the context root
self.path = actual_path
# For all other paths under the context root
else:
# Check if this is a request to a file in whisper.wasm
potential_file = os.path.join(DIRECTORY, 'whisper.wasm', actual_path.lstrip('/'))
if os.path.exists(potential_file) and not os.path.isdir(potential_file):
self.path = '/whisper.wasm' + actual_path
else:
# Try to resolve the file from the base directory
potential_file = os.path.join(DIRECTORY, actual_path.lstrip('/'))
if os.path.exists(potential_file):
self.path = actual_path
# For direct requests to worker files (without context root as these
# are in the build-em/bin directory
elif '.worker.js' in self.path:
worker_file = os.path.basename(self.path) worker_file = os.path.basename(self.path)
worker_path = os.path.join(DIRECTORY, worker_file) worker_path = os.path.join(DIRECTORY, worker_file)
if os.path.exists(worker_path): if os.path.exists(worker_path):
self.path = '/' + worker_file self.path = '/' + worker_file
# Handle coi-serviceworker.js separately
if 'coi-serviceworker.js' in self.path:
worker_file = "coi-serviceworker.js"
worker_path = os.path.join(SCRIPT_DIR, worker_file)
if os.path.exists(worker_path):
self.send_response(200)
self.send_header('Content-type', 'application/javascript')
self.end_headers()
with open(worker_path, 'rb') as file:
self.wfile.write(file.read())
return
else:
print(f"Warning: Could not find {worker_path}")
return super().do_GET() return super().do_GET()
def end_headers(self): def end_headers(self):
# Add required headers for SharedArrayBuffer # Add required headers for SharedArrayBuffer
self.send_header("Cross-Origin-Opener-Policy", "same-origin") self.send_header("Cross-Origin-Opener-Policy", "same-origin")
self.send_header("Cross-Origin-Embedder-Policy", "require-corp") self.send_header("Cross-Origin-Embedder-Policy", "require-corp")
self.send_header("Access-Control-Allow-Origin", "*") self.send_header("Access-Control-Allow-Origin", "*");
super().end_headers() super().end_headers()
PORT = 8000 PORT = 8000
# Enable address reuse
with socketserver.TCPServer(("", PORT), CustomHTTPRequestHandler) as httpd:
class CustomServer(socketserver.TCPServer):
allow_reuse_address = True
try:
with CustomServer(("", PORT), CustomHTTPRequestHandler) as httpd:
print(f"Serving directory '{DIRECTORY}' at http://localhost:{PORT}") print(f"Serving directory '{DIRECTORY}' at http://localhost:{PORT}")
print(f"Application context root: http://localhost:{PORT}{CONTEXT_ROOT}/")
try: try:
httpd.serve_forever() httpd.serve_forever()
except KeyboardInterrupt: except KeyboardInterrupt:
print("\nServer stopped.") print("\nServer stopped.")
# Force complete exit
sys.exit(0)
except OSError as e:
print(f"Error: {e}")
sys.exit(1)

View File

@ -1,6 +1,3 @@
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(TARGET whisper-server) set(TARGET whisper-server)
add_executable(${TARGET} server.cpp httplib.h) add_executable(${TARGET} server.cpp httplib.h)

View File

@ -23,7 +23,6 @@ options:
-sow, --split-on-word [false ] split on word rather than on token -sow, --split-on-word [false ] split on word rather than on token
-bo N, --best-of N [2 ] number of best candidates to keep -bo N, --best-of N [2 ] number of best candidates to keep
-bs N, --beam-size N [-1 ] beam size for beam search -bs N, --beam-size N [-1 ] beam size for beam search
-ac N, --audio-ctx N [0 ] audio context size (0 - all)
-wt N, --word-thold N [0.01 ] word timestamp probability threshold -wt N, --word-thold N [0.01 ] word timestamp probability threshold
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail -et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail -lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
@ -42,28 +41,9 @@ options:
--prompt PROMPT [ ] initial prompt --prompt PROMPT [ ] initial prompt
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path -m FNAME, --model FNAME [models/ggml-base.en.bin] model path
-oved D, --ov-e-device DNAME [CPU ] the OpenVINO device used for encode inference -oved D, --ov-e-device DNAME [CPU ] the OpenVINO device used for encode inference
-dtw MODEL --dtw MODEL [ ] compute token-level timestamps
--host HOST, [127.0.0.1] Hostname/IP address for the server --host HOST, [127.0.0.1] Hostname/IP address for the server
--port PORT, [8080 ] Port number for the server --port PORT, [8080 ] Port number for the server
--public PATH, [examples/server/public] Path to the public folder
--request-path PATH, [ ] Request path for all requests
--inference-path PATH, [/inference] Inference path for all requests
--convert, [false ] Convert audio to WAV, requires ffmpeg on the server --convert, [false ] Convert audio to WAV, requires ffmpeg on the server
-sns, --suppress-nst [false ] suppress non-speech tokens
-nth N, --no-speech-thold N [0.60 ] no speech threshold
-nc, --no-context [false ] do not use previous audio context
-ng, --no-gpu [false ] do not use gpu
-fa, --flash-attn [false ] flash attention
Voice Activity Detection (VAD) options:
--vad [false ] enable Voice Activity Detection (VAD)
-vm FNAME, --vad-model FNAME [ ] VAD model path
-vt N, --vad-threshold N [0.50 ] VAD threshold for speech recognition
-vspd N, --vad-min-speech-duration-ms N [250 ] VAD min speech duration (0.0-1.0)
-vsd N, --vad-min-silence-duration-ms N [100 ] VAD min silence duration (to split segments)
-vmsd N, --vad-max-speech-duration-s N [FLT_MAX] VAD max speech duration (auto-split longer)
-vp N, --vad-speech-pad-ms N [30 ] VAD speech padding (extend segments)
-vo N, --vad-samples-overlap N [0.10 ] VAD samples overlap (seconds between segments)
``` ```
> [!WARNING] > [!WARNING]
@ -87,35 +67,3 @@ curl 127.0.0.1:8080/load \
-H "Content-Type: multipart/form-data" \ -H "Content-Type: multipart/form-data" \
-F model="<path-to-model-file>" -F model="<path-to-model-file>"
``` ```
## Load testing with k6
> **Note:** Install [k6](https://k6.io/docs/get-started/installation/) before running the benchmark script.
You can benchmark the Whisper server using the provided bench.js script with [k6](https://k6.io/). This script sends concurrent multipart requests to the /inference endpoint and is fully configurable via environment variables.
**Example usage:**
```
k6 run bench.js \
--env FILE_PATH=/absolute/path/to/samples/jfk.wav \
--env BASE_URL=http://127.0.0.1:8080 \
--env ENDPOINT=/inference \
--env CONCURRENCY=4 \
--env TEMPERATURE=0.0 \
--env TEMPERATURE_INC=0.2 \
--env RESPONSE_FORMAT=json
```
**Environment variables:**
- `FILE_PATH`: Path to the audio file to send (must be absolute or relative to the k6 working directory)
- `BASE_URL`: Server base URL (default: `http://127.0.0.1:8080`)
- `ENDPOINT`: API endpoint (default: `/inference`)
- `CONCURRENCY`: Number of concurrent requests (default: 4)
- `TEMPERATURE`: Decoding temperature (default: 0.0)
- `TEMPERATURE_INC`: Temperature increment (default: 0.2)
- `RESPONSE_FORMAT`: Response format (default: `json`)
**Note:**
- The server must be running and accessible at the specified `BASE_URL` and `ENDPOINT`.
- The script is located in the same directory as this README: `bench.js`.

View File

@ -1,29 +0,0 @@
import http from 'k6/http'
import { check } from 'k6'
export let options = {
vus: parseInt(__ENV.CONCURRENCY) || 4,
iterations: parseInt(__ENV.CONCURRENCY) || 4,
}
const filePath = __ENV.FILE_PATH
const baseURL = __ENV.BASE_URL || 'http://127.0.0.1:8080'
const endpoint = __ENV.ENDPOINT || '/inference'
const temperature = __ENV.TEMPERATURE || '0.0'
const temperatureInc = __ENV.TEMPERATURE_INC || '0.2'
const responseFormat = __ENV.RESPONSE_FORMAT || 'json'
// Read the file ONCE at init time
const fileBin = open(filePath, 'b')
export default function () {
const payload = {
file: http.file(fileBin, filePath),
temperature: temperature,
temperature_inc: temperatureInc,
response_format: responseFormat,
}
const res = http.post(`${baseURL}${endpoint}`, payload)
check(res, { 'status is 200': r => r.status === 200 })
}

File diff suppressed because it is too large

View File

@ -5,7 +5,6 @@
#include "httplib.h" #include "httplib.h"
#include "json.hpp" #include "json.hpp"
#include <cfloat>
#include <chrono> #include <chrono>
#include <cmath> #include <cmath>
#include <cstdio> #include <cstdio>
@ -14,23 +13,14 @@
#include <string> #include <string>
#include <thread> #include <thread>
#include <vector> #include <vector>
#include <memory>
#include <csignal> #if defined(_MSC_VER)
#include <atomic> #pragma warning(disable: 4244 4267) // possible loss of data
#include <functional>
#include <cstdlib>
#if defined (_WIN32)
#include <windows.h>
#endif #endif
using namespace httplib; using namespace httplib;
using json = nlohmann::ordered_json; using json = nlohmann::ordered_json;
enum server_state {
SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet
SERVER_STATE_READY, // Server is ready and model is loaded
};
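The new `server_state` enum is kept in a `std::atomic` and flipped to READY once the model loads (see the `state.store` calls later in this diff). A sketch of gating requests on it; the `/health` route itself is our invention, not part of the diff:

```
// Hypothetical readiness gate on top of the atomic state used below.
std::atomic<server_state> state{SERVER_STATE_LOADING_MODEL};

svr->Get("/health", [&](const httplib::Request &, httplib::Response & res) {
    if (state.load() == SERVER_STATE_READY) {
        res.set_content("{\"status\":\"ok\"}", "application/json");
    } else {
        res.status = 503; // still loading the model
        res.set_content("{\"status\":\"loading model\"}", "application/json");
    }
});
```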
namespace { namespace {
// output formats // output formats
@ -40,20 +30,6 @@ const std::string srt_format = "srt";
const std::string vjson_format = "verbose_json"; const std::string vjson_format = "verbose_json";
const std::string vtt_format = "vtt"; const std::string vtt_format = "vtt";
std::function<void(int)> shutdown_handler;
std::atomic_flag is_terminating = ATOMIC_FLAG_INIT;
inline void signal_handler(int signal) {
if (is_terminating.test_and_set()) {
// in case it hangs, we can force terminate the server by hitting Ctrl+C twice
// this is for better developer experience, we can remove when the server is stable enough
fprintf(stderr, "Received second interrupt, terminating immediately.\n");
exit(1);
}
shutdown_handler(signal);
}
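`signal_handler` forwards the first interrupt to the `shutdown_handler` `std::function` and force-exits on the second. Wiring it up might look like the following; the `svr->stop()` call assumes the httplib server object used elsewhere in this file:

```
// Hypothetical wiring (not shown in the diff): stop the server on SIGINT.
shutdown_handler = [&](int /*signal*/) {
    svr->stop();
};
std::signal(SIGINT,  signal_handler);
std::signal(SIGTERM, signal_handler);
```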
struct server_params struct server_params
{ {
std::string hostname = "127.0.0.1"; std::string hostname = "127.0.0.1";
@ -103,7 +79,6 @@ struct whisper_params {
bool use_gpu = true; bool use_gpu = true;
bool flash_attn = false; bool flash_attn = false;
bool suppress_nst = false; bool suppress_nst = false;
bool no_context = false;
std::string language = "en"; std::string language = "en";
std::string prompt = ""; std::string prompt = "";
@ -118,16 +93,6 @@ struct whisper_params {
std::string openvino_encode_device = "CPU"; std::string openvino_encode_device = "CPU";
std::string dtw = ""; std::string dtw = "";
// Voice Activity Detection (VAD) parameters
bool vad = false;
std::string vad_model = "";
float vad_threshold = 0.5f;
int vad_min_speech_duration_ms = 250;
int vad_min_silence_duration_ms = 100;
float vad_max_speech_duration_s = FLT_MAX;
int vad_speech_pad_ms = 30;
float vad_samples_overlap = 0.1f;
}; };
void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params, const server_params& sparams) { void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params, const server_params& sparams) {
@ -175,21 +140,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " --convert, [%-7s] Convert audio to WAV, requires ffmpeg on the server\n", sparams.ffmpeg_converter ? "true" : "false"); fprintf(stderr, " --convert, [%-7s] Convert audio to WAV, requires ffmpeg on the server\n", sparams.ffmpeg_converter ? "true" : "false");
fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false"); fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false");
fprintf(stderr, " -nth N, --no-speech-thold N [%-7.2f] no speech threshold\n", params.no_speech_thold); fprintf(stderr, " -nth N, --no-speech-thold N [%-7.2f] no speech threshold\n", params.no_speech_thold);
fprintf(stderr, " -nc, --no-context [%-7s] do not use previous audio context\n", params.no_context ? "true" : "false");
fprintf(stderr, " -ng, --no-gpu [%-7s] do not use gpu\n", params.use_gpu ? "false" : "true");
fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false");
// Voice Activity Detection (VAD) parameters
fprintf(stderr, "\nVoice Activity Detection (VAD) options:\n");
fprintf(stderr, " --vad [%-7s] enable Voice Activity Detection (VAD)\n", params.vad ? "true" : "false");
fprintf(stderr, " -vm FNAME, --vad-model FNAME [%-7s] VAD model path\n", params.vad_model.c_str());
fprintf(stderr, " -vt N, --vad-threshold N [%-7.2f] VAD threshold for speech recognition\n", params.vad_threshold);
fprintf(stderr, " -vspd N, --vad-min-speech-duration-ms N [%-7d] VAD min speech duration (0.0-1.0)\n", params.vad_min_speech_duration_ms);
fprintf(stderr, " -vsd N, --vad-min-silence-duration-ms N [%-7d] VAD min silence duration (to split segments)\n", params.vad_min_silence_duration_ms);
fprintf(stderr, " -vmsd N, --vad-max-speech-duration-s N [%-7s] VAD max speech duration (auto-split longer)\n", params.vad_max_speech_duration_s == FLT_MAX ?
std::string("FLT_MAX").c_str() :
std::to_string(params.vad_max_speech_duration_s).c_str());
fprintf(stderr, " -vp N, --vad-speech-pad-ms N [%-7d] VAD speech padding (extend segments)\n", params.vad_speech_pad_ms);
fprintf(stderr, " -vo N, --vad-samples-overlap N [%-7.2f] VAD samples overlap (seconds between segments)\n", params.vad_samples_overlap);
fprintf(stderr, "\n"); fprintf(stderr, "\n");
} }
@ -236,7 +186,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; } else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; }
else if (arg == "-nth" || arg == "--no-speech-thold") { params.no_speech_thold = std::stof(argv[++i]); } else if (arg == "-nth" || arg == "--no-speech-thold") { params.no_speech_thold = std::stof(argv[++i]); }
else if (arg == "-nc" || arg == "--no-context") { params.no_context = true; }
// server params // server params
else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); } else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); }
@ -245,16 +194,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve
else if ( arg == "--request-path") { sparams.request_path = argv[++i]; } else if ( arg == "--request-path") { sparams.request_path = argv[++i]; }
else if ( arg == "--inference-path") { sparams.inference_path = argv[++i]; } else if ( arg == "--inference-path") { sparams.inference_path = argv[++i]; }
else if ( arg == "--convert") { sparams.ffmpeg_converter = true; } else if ( arg == "--convert") { sparams.ffmpeg_converter = true; }
// Voice Activity Detection (VAD)
else if ( arg == "--vad") { params.vad = true; }
else if (arg == "-vm" || arg == "--vad-model") { params.vad_model = argv[++i]; }
else if (arg == "-vt" || arg == "--vad-threshold") { params.vad_threshold = std::stof(argv[++i]); }
else if (arg == "-vspd" || arg == "--vad-min-speech-duration-ms") { params.vad_min_speech_duration_ms = std::stoi(argv[++i]); }
else if (arg == "-vsd" || arg == "--vad-min-silence-duration-ms") { params.vad_min_speech_duration_ms = std::stoi(argv[++i]); }
else if (arg == "-vmsd" || arg == "--vad-max-speech-duration-s") { params.vad_max_speech_duration_s = std::stof(argv[++i]); }
else if (arg == "-vp" || arg == "--vad-speech-pad-ms") { params.vad_speech_pad_ms = std::stoi(argv[++i]); }
else if (arg == "-vo" || arg == "--vad-samples-overlap") { params.vad_samples_overlap = std::stof(argv[++i]); }
else { else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params, sparams); whisper_print_usage(argc, argv, params, sparams);
@ -567,45 +506,11 @@ void get_req_parameters(const Request & req, whisper_params & params)
{ {
params.suppress_nst = parse_str_to_bool(req.get_file_value("suppress_nst").content); params.suppress_nst = parse_str_to_bool(req.get_file_value("suppress_nst").content);
} }
if (req.has_file("no_context"))
{
params.no_context = parse_str_to_bool(req.get_file_value("no_context").content);
}
if (req.has_file("vad"))
{
params.vad = parse_str_to_bool(req.get_file_value("vad").content);
}
if (req.has_file("vad_threshold"))
{
params.vad_threshold = std::stof(req.get_file_value("vad_threshold").content);
}
if (req.has_file("vad_min_speech_duration_ms"))
{
params.vad_min_speech_duration_ms = std::stof(req.get_file_value("vad_min_speech_duration_ms").content);
}
if (req.has_file("vad_min_silence_duration_ms"))
{
params.vad_min_silence_duration_ms = std::stof(req.get_file_value("vad_min_silence_duration_ms").content);
}
if (req.has_file("vad_max_speech_duration_s"))
{
params.vad_max_speech_duration_s = std::stof(req.get_file_value("vad_max_speech_duration_s").content);
}
if (req.has_file("vad_speech_pad_ms"))
{
params.vad_speech_pad_ms = std::stoi(req.get_file_value("vad_speech_pad_ms").content);
}
if (req.has_file("vad_samples_overlap"))
{
params.vad_samples_overlap = std::stof(req.get_file_value("vad_samples_overlap").content);
}
} }
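`get_req_parameters` leans on `parse_str_to_bool` for the boolean form fields. A sketch of that helper under the assumption that it accepts the usual spellings; the actual implementation may differ:

```
#include <string>

// Assumed semantics: the common truthy spellings map to true.
static bool parse_str_to_bool(const std::string & s) {
    return s == "true" || s == "1" || s == "yes" || s == "on";
}
```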
} // namespace } // namespace
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
ggml_backend_load_all();
whisper_params params; whisper_params params;
server_params sparams; server_params sparams;
@ -674,9 +579,6 @@ int main(int argc, char ** argv) {
if (params.dtw == "large.v3") { if (params.dtw == "large.v3") {
cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V3; cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V3;
} }
if (params.dtw == "large.v3.turbo") {
cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V3_TURBO;
}
if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE) { if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE) {
fprintf(stderr, "error: unknown DTW preset '%s'\n", params.dtw.c_str()); fprintf(stderr, "error: unknown DTW preset '%s'\n", params.dtw.c_str());
@ -684,9 +586,6 @@ int main(int argc, char ** argv) {
} }
} }
std::unique_ptr<httplib::Server> svr = std::make_unique<httplib::Server>();
std::atomic<server_state> state{SERVER_STATE_LOADING_MODEL};
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams); struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
if (ctx == nullptr) { if (ctx == nullptr) {
@ -696,10 +595,9 @@ int main(int argc, char ** argv) {
// initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured // initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured
whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr); whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);
state.store(SERVER_STATE_READY);
Server svr;
svr->set_default_headers({{"Server", "whisper.cpp"}, svr.set_default_headers({{"Server", "whisper.cpp"},
{"Access-Control-Allow-Origin", "*"}, {"Access-Control-Allow-Origin", "*"},
{"Access-Control-Allow-Headers", "content-type, authorization"}}); {"Access-Control-Allow-Headers", "content-type, authorization"}});
@ -778,15 +676,15 @@ int main(int argc, char ** argv) {
whisper_params default_params = params; whisper_params default_params = params;
// this is only called if no index.html is found in the public --path // this is only called if no index.html is found in the public --path
svr->Get(sparams.request_path + "/", [&](const Request &, Response &res){ svr.Get(sparams.request_path + "/", [&default_content](const Request &, Response &res){
res.set_content(default_content, "text/html"); res.set_content(default_content, "text/html");
return false; return false;
}); });
svr->Options(sparams.request_path + sparams.inference_path, [&](const Request &, Response &){ svr.Options(sparams.request_path + sparams.inference_path, [&](const Request &, Response &){
}); });
svr->Post(sparams.request_path + sparams.inference_path, [&](const Request &req, Response &res){ svr.Post(sparams.request_path + sparams.inference_path, [&](const Request &req, Response &res){
// acquire whisper model mutex lock // acquire whisper model mutex lock
std::lock_guard<std::mutex> lock(whisper_mutex); std::lock_guard<std::mutex> lock(whisper_mutex);
@ -920,20 +818,9 @@ int main(int argc, char ** argv) {
wparams.no_timestamps = params.no_timestamps; wparams.no_timestamps = params.no_timestamps;
wparams.token_timestamps = !params.no_timestamps && params.response_format == vjson_format; wparams.token_timestamps = !params.no_timestamps && params.response_format == vjson_format;
wparams.no_context = params.no_context;
wparams.suppress_nst = params.suppress_nst; wparams.suppress_nst = params.suppress_nst;
wparams.vad = params.vad;
wparams.vad_model_path = params.vad_model.c_str();
wparams.vad_params.threshold = params.vad_threshold;
wparams.vad_params.min_speech_duration_ms = params.vad_min_speech_duration_ms;
wparams.vad_params.min_silence_duration_ms = params.vad_min_silence_duration_ms;
wparams.vad_params.max_speech_duration_s = params.vad_max_speech_duration_s;
wparams.vad_params.speech_pad_ms = params.vad_speech_pad_ms;
wparams.vad_params.samples_overlap = params.vad_samples_overlap;
whisper_print_user_data user_data = { &params, &pcmf32s, 0 }; whisper_print_user_data user_data = { &params, &pcmf32s, 0 };
// this callback is called on each new segment // this callback is called on each new segment
@ -947,25 +834,33 @@ int main(int argc, char ** argv) {
wparams.progress_callback_user_data = &user_data; wparams.progress_callback_user_data = &user_data;
} }
// tell whisper to abort if the HTTP connection closed
wparams.abort_callback = [](void *user_data) {
    // user_data is a pointer to our Request
    auto req_ptr = static_cast<const httplib::Request*>(user_data);
    return req_ptr->is_connection_closed();
};
wparams.abort_callback_user_data = (void*)&req;

// examples for abort mechanism
// in examples below, we do not abort the processing, but we could if the flag is set to true

// the callback is called before every encoder run - if it returns false, the processing is aborted
{
    static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
    wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
        bool is_aborted = *(bool*)user_data;
        return !is_aborted;
    };
    wparams.encoder_begin_callback_user_data = &is_aborted;
}
// the callback is called before every computation - if it returns true, the computation is aborted
{
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
wparams.abort_callback = [](void * user_data) {
bool is_aborted = *(bool*)user_data;
return is_aborted;
};
wparams.abort_callback_user_data = &is_aborted;
}
if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) { if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
// handle failure or early abort
if (req.is_connection_closed()) {
// log client disconnect
fprintf(stderr, "client disconnected, aborted processing\n");
res.status = 499; // Client Closed Request (nginx convention)
res.set_content("{\"error\":\"client disconnected\"}", "application/json");
return;
}
fprintf(stderr, "%s: failed to process audio\n", argv[0]); fprintf(stderr, "%s: failed to process audio\n", argv[0]);
res.status = 500; // Internal Server Error
const std::string error_resp = "{\"error\":\"failed to process audio\"}"; const std::string error_resp = "{\"error\":\"failed to process audio\"}";
res.set_content(error_resp, "application/json"); res.set_content(error_resp, "application/json");
return; return;
@ -1024,25 +919,13 @@ int main(int argc, char ** argv) {
} else if (params.response_format == vjson_format) { } else if (params.response_format == vjson_format) {
/* try to match openai/whisper's Python format */ /* try to match openai/whisper's Python format */
std::string results = output_str(ctx, params, pcmf32s); std::string results = output_str(ctx, params, pcmf32s);
// Get language probabilities
std::vector<float> lang_probs(whisper_lang_max_id() + 1, 0.0f);
const auto detected_lang_id = whisper_lang_auto_detect(ctx, 0, params.n_threads, lang_probs.data());
json jres = json{ json jres = json{
{"task", params.translate ? "translate" : "transcribe"}, {"task", params.translate ? "translate" : "transcribe"},
{"language", whisper_lang_str_full(whisper_full_lang_id(ctx))}, {"language", whisper_lang_str_full(whisper_full_lang_id(ctx))},
{"duration", float(pcmf32.size())/WHISPER_SAMPLE_RATE}, {"duration", float(pcmf32.size())/WHISPER_SAMPLE_RATE},
{"text", results}, {"text", results},
{"segments", json::array()}, {"segments", json::array()}
{"detected_language", whisper_lang_str_full(detected_lang_id)},
{"detected_language_probability", lang_probs[detected_lang_id]},
{"language_probabilities", json::object()}
}; };
// Add all language probabilities
for (int i = 0; i <= whisper_lang_max_id(); ++i) {
if (lang_probs[i] > 0.001f) { // Only include non-negligible probabilities
jres["language_probabilities"][whisper_lang_str(i)] = lang_probs[i];
}
}
const int n_segments = whisper_full_n_segments(ctx); const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) for (int i = 0; i < n_segments; ++i)
{ {
@ -1102,9 +985,8 @@ int main(int argc, char ** argv) {
// reset params to their defaults // reset params to their defaults
params = default_params; params = default_params;
}); });
svr->Post(sparams.request_path + "/load", [&](const Request &req, Response &res){ svr.Post(sparams.request_path + "/load", [&](const Request &req, Response &res){
std::lock_guard<std::mutex> lock(whisper_mutex); std::lock_guard<std::mutex> lock(whisper_mutex);
state.store(SERVER_STATE_LOADING_MODEL);
if (!req.has_file("model")) if (!req.has_file("model"))
{ {
fprintf(stderr, "error: no 'model' field in the request\n"); fprintf(stderr, "error: no 'model' field in the request\n");
@ -1136,25 +1018,18 @@ int main(int argc, char ** argv) {
// initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured // initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured
whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr); whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);
state.store(SERVER_STATE_READY);
const std::string success = "Load was successful!"; const std::string success = "Load was successful!";
res.set_content(success, "application/text"); res.set_content(success, "application/text");
// check if the model is in the file system // check if the model is in the file system
}); });
svr->Get(sparams.request_path + "/health", [&](const Request &, Response &res){ svr.Get(sparams.request_path + "/health", [&](const Request &, Response &res){
server_state current_state = state.load();
if (current_state == SERVER_STATE_READY) {
const std::string health_response = "{\"status\":\"ok\"}"; const std::string health_response = "{\"status\":\"ok\"}";
res.set_content(health_response, "application/json"); res.set_content(health_response, "application/json");
} else {
res.set_content("{\"status\":\"loading model\"}", "application/json");
res.status = 503;
}
}); });
svr->set_exception_handler([](const Request &, Response &res, std::exception_ptr ep) { svr.set_exception_handler([](const Request &, Response &res, std::exception_ptr ep) {
const char fmt[] = "500 Internal Server Error\n%s"; const char fmt[] = "500 Internal Server Error\n%s";
char buf[BUFSIZ]; char buf[BUFSIZ];
try { try {
@ -1168,7 +1043,7 @@ int main(int argc, char ** argv) {
res.status = 500; res.status = 500;
}); });
svr->set_error_handler([](const Request &req, Response &res) { svr.set_error_handler([](const Request &req, Response &res) {
if (res.status == 400) { if (res.status == 400) {
res.set_content("Invalid request", "text/plain"); res.set_content("Invalid request", "text/plain");
} else if (res.status != 500) { } else if (res.status != 500) {
@ -1178,10 +1053,10 @@ int main(int argc, char ** argv) {
}); });
// set timeouts and change hostname and port // set timeouts and change hostname and port
svr->set_read_timeout(sparams.read_timeout); svr.set_read_timeout(sparams.read_timeout);
svr->set_write_timeout(sparams.write_timeout); svr.set_write_timeout(sparams.write_timeout);
if (!svr->bind_to_port(sparams.hostname, sparams.port)) if (!svr.bind_to_port(sparams.hostname, sparams.port))
{ {
fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n",
sparams.hostname.c_str(), sparams.port); sparams.hostname.c_str(), sparams.port);
@ -1189,50 +1064,18 @@ int main(int argc, char ** argv) {
} }
// Set the base directory for serving static files // Set the base directory for serving static files
svr->set_base_dir(sparams.public_path); svr.set_base_dir(sparams.public_path);
// to make it ctrl+clickable: // to make it ctrl+clickable:
printf("\nwhisper server listening at http://%s:%d\n\n", sparams.hostname.c_str(), sparams.port); printf("\nwhisper server listening at http://%s:%d\n\n", sparams.hostname.c_str(), sparams.port);
shutdown_handler = [&](int signal) { if (!svr.listen_after_bind())
printf("\nCaught signal %d, shutting down gracefully...\n", signal); {
if (svr) { return 1;
svr->stop();
} }
};
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
struct sigaction sigint_action;
sigint_action.sa_handler = signal_handler;
sigemptyset (&sigint_action.sa_mask);
sigint_action.sa_flags = 0;
sigaction(SIGINT, &sigint_action, NULL);
sigaction(SIGTERM, &sigint_action, NULL);
#elif defined (_WIN32)
auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;
};
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
#endif
// clean up function, to be called before exit
auto clean_up = [&]() {
whisper_print_timings(ctx); whisper_print_timings(ctx);
whisper_free(ctx); whisper_free(ctx);
};
std::thread t([&] {
if (!svr->listen_after_bind()) {
fprintf(stderr, "error: server listen failed\n");
}
});
svr->wait_until_ready();
t.join();
clean_up();
return 0; return 0;
} }
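Editor's note on the abort examples above: both callbacks poll a plain `static bool`, and the inline NOTE already flags the data race. A minimal sketch of the atomic variant, assuming only the `whisper_full_params` callback fields that appear in this diff:

```
// a minimal sketch, assuming the whisper.h callback signatures used above;
// flip g_is_aborted from another thread (e.g. a signal handler) to stop work
#include <atomic>

#include "whisper.h"

static std::atomic<bool> g_is_aborted{false};

static void setup_abort_callbacks(struct whisper_full_params & wparams) {
    // called before every encoder run - returning false aborts the processing
    wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/,
                                        struct whisper_state * /*state*/,
                                        void * user_data) {
        return !static_cast<std::atomic<bool> *>(user_data)->load();
    };
    wparams.encoder_begin_callback_user_data = &g_is_aborted;

    // called before every computation - returning true aborts it
    wparams.abort_callback = [](void * user_data) {
        return static_cast<std::atomic<bool> *>(user_data)->load();
    };
    wparams.abort_callback_user_data = &g_is_aborted;
}
```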

View File

@@ -35,7 +35,7 @@ set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
     -s INITIAL_MEMORY=1024MB \
     -s TOTAL_MEMORY=1024MB \
     -s FORCE_FILESYSTEM=1 \
-    -s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap', 'HEAPU8']\" \
+    -s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \
     ${EXTRA_FLAGS} \
 ")

View File

@@ -26,10 +26,5 @@ to the server's HTTP path:
 ```
 # copy the produced page to your HTTP path
 cp bin/stream.wasm/*       /path/to/html/
-cp bin/libstream.js        /path/to/html/
 cp bin/libstream.worker.js /path/to/html/
 ```
-
-> 📝 **Note:** As of Emscripten 3.1.58 (April 2024), separate worker.js files are no
-> longer generated and the worker is embedded in the main JS file. So the worker
-> file will not be generated for versions later than `3.1.58`.

View File

@@ -116,8 +116,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
 }

 int main(int argc, char ** argv) {
-    ggml_backend_load_all();
-
     whisper_params params;

     if (whisper_params_parse(argc, argv, params) == false) {
@@ -163,10 +161,6 @@ int main(int argc, char ** argv) {
     cparams.flash_attn = params.flash_attn;

     struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
-    if (ctx == nullptr) {
-        fprintf(stderr, "error: failed to initialize whisper context\n");
-        return 2;
-    }

     std::vector<float> pcmf32    (n_samples_30s, 0.0f);
     std::vector<float> pcmf32_old;
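The `-` lines in these two hunks are defensive initialization being dropped: runtime backend discovery and a null check on the returned context. A hedged sketch of that pattern in one helper (`init_whisper_ctx` is a made-up name; `ggml_backend_load_all()` and `whisper_init_from_file_with_params()` are the calls shown above):

```
#include <cstdio>

#include "ggml-backend.h"
#include "whisper.h"

// sketch: load all available ggml backends, then fail loudly on a bad model
static struct whisper_context * init_whisper_ctx(const char * model_path,
                                                 struct whisper_context_params cparams) {
    ggml_backend_load_all(); // discover CPU/GPU backends at runtime

    struct whisper_context * ctx = whisper_init_from_file_with_params(model_path, cparams);
    if (ctx == nullptr) {
        fprintf(stderr, "error: failed to initialize whisper context\n");
    }
    return ctx; // caller checks for nullptr and frees with whisper_free()
}
```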

View File

@@ -12,18 +12,11 @@ if (WHISPER_SDL2)
         llama-context.cpp
         llama-cparams.cpp
         llama-grammar.cpp
-        llama-graph.cpp
         llama-hparams.cpp
         llama-impl.cpp
-        llama-io.cpp
-        llama-kv-cache-unified.cpp
-        llama-kv-cache-unified-iswa.cpp
-        llama-memory-recurrent.cpp
-        llama-memory-hybrid.cpp
-        llama-memory.cpp
+        llama-kv-cache.cpp
         llama-mmap.cpp
         llama-model-loader.cpp
-        llama-model-saver.cpp
         llama-model.cpp
         llama-quant.cpp
         llama-sampling.cpp

View File

@@ -4,13 +4,14 @@
 #include "llama-mmap.h"
 #include "llama-model.h"

+#include <algorithm>
 #include <map>
 #include <cassert>
 #include <stdexcept>

 // vec

-ggml_tensor * llama_adapter_cvec::tensor_for(int il) const {
+struct ggml_tensor * llama_adapter_cvec::tensor_for(int il) const {
     if (il < 0 || il < layer_start || il > layer_end || (size_t) il >= tensors.size()) {
         return nullptr;
     }
@@ -18,7 +19,7 @@ ggml_tensor * llama_adapter_cvec::tensor_for(int il) const {
     return tensors[il];
 }

-ggml_tensor * llama_adapter_cvec::apply_to(ggml_context * ctx, ggml_tensor * cur, int il) const {
+struct ggml_tensor * llama_adapter_cvec::apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const {
     ggml_tensor * layer_dir = tensor_for(il);
     if (layer_dir != nullptr) {
         cur = ggml_add(ctx, cur, layer_dir);
@@ -39,7 +40,7 @@ bool llama_adapter_cvec::init(const llama_model & model) {
     auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
-            ggml_init_params params = {
+            struct ggml_init_params params = {
                 /*.mem_size   =*/ hparams.n_layer*ggml_tensor_overhead(),
                 /*.mem_buffer =*/ NULL,
                 /*.no_alloc   =*/ true,
@@ -90,7 +91,7 @@ bool llama_adapter_cvec::init(const llama_model & model) {
     return true;
 }

-bool llama_adapter_cvec::apply(
+int32_t llama_adapter_cvec::apply(
         const llama_model & model,
         const float * data,
         size_t len,
@@ -103,17 +104,17 @@ bool llama_adapter_cvec::apply(
         // disable the current control vector (but leave allocated for later)
         layer_start = -1;
         layer_end   = -1;
-        return true;
+        return 0;
     }

     if (n_embd != (int) hparams.n_embd) {
         LLAMA_LOG_ERROR("%s: control vector n_embd does not match model\n", __func__);
-        return false;
+        return 1;
     }

     if (tensors.empty()) {
         if (!init(model)) {
-            return false;
+            return 1;
         }
     }
@@ -129,12 +130,12 @@ bool llama_adapter_cvec::apply(
         }
     }

-    return true;
+    return 0;
 }

 // lora

-llama_adapter_lora_weight * llama_adapter_lora::get_weight(ggml_tensor * w) {
+llama_adapter_lora_weight * llama_adapter_lora::get_weight(struct ggml_tensor * w) {
     const std::string name(w->name);

     const auto pos = ab_map.find(name);
@@ -145,11 +146,11 @@ llama_adapter_lora_weight * llama_adapter_lora::get_weight(ggml_tensor * w) {
     return nullptr;
 }

-static void llama_adapter_lora_init_impl(llama_model & model, const char * path_lora, llama_adapter_lora & adapter) {
+static void llama_adapter_lora_init_impl(struct llama_model & model, const char * path_lora, struct llama_adapter_lora & adapter) {
     LLAMA_LOG_INFO("%s: loading lora adapter from '%s' ...\n", __func__, path_lora);

     ggml_context * ctx_init;
-    gguf_init_params meta_gguf_params = {
+    struct gguf_init_params meta_gguf_params = {
         /* .no_alloc = */ true,
         /* .ctx      = */ &ctx_init,
     };
@@ -200,7 +201,7 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
             // add a new context
-            ggml_init_params params = {
+            struct ggml_init_params params = {
                 /*.mem_size   =*/ n_tensors*ggml_tensor_overhead(),
                 /*.mem_buffer =*/ NULL,
                 /*.no_alloc   =*/ true,
@@ -247,29 +248,6 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
         }
     }

-    // get extra buffer types of the CPU
-    // TODO: a more general solution for non-CPU extra buft should be implemented in the future
-    // ref: https://github.com/ggml-org/llama.cpp/pull/12593#pullrequestreview-2718659948
-    std::vector<ggml_backend_buffer_type_t> buft_extra;
-    {
-        auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
-        if (!cpu_dev) {
-            throw std::runtime_error(format("%s: no CPU backend found", __func__));
-        }
-        auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
-
-        auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
-            ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");
-
-        if (ggml_backend_dev_get_extra_bufts_fn) {
-            ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev);
-            while (extra_bufts && *extra_bufts) {
-                buft_extra.emplace_back(*extra_bufts);
-                ++extra_bufts;
-            }
-        }
-    }
-
     // add tensors
     for (auto & it : ab_map) {
         const std::string & name = it.first;
@@ -286,26 +264,7 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
             throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model (hint: maybe wrong base model?)");
         }

-        auto * buft = ggml_backend_buffer_get_type(model_tensor->buffer);
-
-        // do not load loras to extra buffer types (i.e. bufts for repacking) -> use the CPU in that case
-        for (auto & ex : buft_extra) {
-            if (ex == buft) {
-                LLAMA_LOG_WARN("%s: lora for '%s' cannot use buft '%s', fallback to CPU\n", __func__, model_tensor->name, ggml_backend_buft_name(buft));
-
-                auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
-                if (!cpu_dev) {
-                    throw std::runtime_error(format("%s: no CPU backend found", __func__));
-                }
-                buft = ggml_backend_dev_buffer_type(cpu_dev);
-
-                break;
-            }
-        }
-
-        LLAMA_LOG_DEBUG("%s: lora for '%s' -> '%s'\n", __func__, model_tensor->name, ggml_backend_buft_name(buft));
-
-        ggml_context * dev_ctx = ctx_for_buft(buft);
+        struct ggml_context * dev_ctx = ctx_for_buft(ggml_backend_buffer_get_type(model_tensor->buffer));
         // validate tensor shape
         if (is_token_embd) {
             // expect B to be non-transposed, A and B are flipped; see llm_build_inp_embd()
@@ -322,8 +281,8 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
         }

         // save tensor to adapter
-        ggml_tensor * tensor_a = ggml_dup_tensor(dev_ctx, w.a);
-        ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b);
+        struct ggml_tensor * tensor_a = ggml_dup_tensor(dev_ctx, w.a);
+        struct ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b);
         ggml_set_name(tensor_a, w.a->name);
         ggml_set_name(tensor_b, w.b->name);
         adapter.ab_map[name] = llama_adapter_lora_weight(tensor_a, tensor_b);
@@ -349,7 +308,7 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
     {
         llama_file gguf_file(path_lora, "rb");
         std::vector<uint8_t> read_buf;
-        auto set_tensor = [&](ggml_tensor * orig, ggml_tensor * dev) {
+        auto set_tensor = [&](struct ggml_tensor * orig, struct ggml_tensor * dev) {
             size_t offs = gguf_get_data_offset(ctx_gguf.get()) + gguf_get_tensor_offset(ctx_gguf.get(), gguf_find_tensor(ctx_gguf.get(), orig->name));
             size_t size = ggml_nbytes(orig);
             read_buf.resize(size);
@@ -368,8 +327,8 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
     LLAMA_LOG_INFO("%s: loaded %zu tensors from lora file\n", __func__, adapter.ab_map.size()*2);
 }

-llama_adapter_lora * llama_adapter_lora_init(llama_model * model, const char * path_lora) {
-    llama_adapter_lora * adapter = new llama_adapter_lora();
+struct llama_adapter_lora * llama_adapter_lora_init(struct llama_model * model, const char * path_lora) {
+    struct llama_adapter_lora * adapter = new llama_adapter_lora();

     try {
         llama_adapter_lora_init_impl(*model, path_lora, *adapter);
@@ -383,6 +342,6 @@ llama_adapter_lora * llama_adapter_lora_init(llama_model * model, const char * p
     return nullptr;
 }

-void llama_adapter_lora_free(llama_adapter_lora * adapter) {
+void llama_adapter_lora_free(struct llama_adapter_lora * adapter) {
     delete adapter;
 }
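For orientation, a usage sketch of the entry points defined above, assuming the public llama.h wrappers (`llama_adapter_lora_init`, `llama_set_adapter_lora`, `llama_adapter_lora_free`); error handling trimmed to the essentials:

```
#include "llama.h"

// sketch: load a LoRA adapter from a GGUF file and attach it to a context
static bool attach_lora(llama_model * model, llama_context * lctx, const char * path_lora) {
    llama_adapter_lora * adapter = llama_adapter_lora_init(model, path_lora);
    if (adapter == nullptr) {
        return false; // init_impl logged the reason and the wrapper returned nullptr
    }
    llama_set_adapter_lora(lctx, adapter, 1.0f); // scale 1.0f = full adapter effect
    return true;  // later: llama_adapter_lora_free(adapter) after freeing the context
}
```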

View File

@@ -15,11 +15,11 @@
 //

 struct llama_adapter_cvec {
-    ggml_tensor * tensor_for(int il) const;
+    struct ggml_tensor * tensor_for(int il) const;

-    ggml_tensor * apply_to(ggml_context * ctx, ggml_tensor * cur, int il) const;
+    struct ggml_tensor * apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const;

-    bool apply(
+    int32_t apply(
             const llama_model & model,
             const float * data,
             size_t len,
@@ -36,7 +36,7 @@ private:
     std::vector<ggml_context_ptr> ctxs;
     std::vector<ggml_backend_buffer_ptr> bufs;

-    std::vector<ggml_tensor *> tensors; // per layer
+    std::vector<struct ggml_tensor *> tensors; // per layer
 };

 //
@@ -44,8 +44,8 @@ private:
 //

 struct llama_adapter_lora_weight {
-    ggml_tensor * a = nullptr;
-    ggml_tensor * b = nullptr;
+    struct ggml_tensor * a = nullptr;
+    struct ggml_tensor * b = nullptr;

     // get actual scale based on rank and alpha
     float get_scale(float alpha, float adapter_scale) const {
@@ -55,12 +55,12 @@ struct llama_adapter_lora_weight {
     }

     llama_adapter_lora_weight() = default;
-    llama_adapter_lora_weight(ggml_tensor * a, ggml_tensor * b) : a(a), b(b) {}
+    llama_adapter_lora_weight(struct ggml_tensor * a, struct ggml_tensor * b) : a(a), b(b) {}
 };

 struct llama_adapter_lora {
     // map tensor name to lora_a_b
-    std::unordered_map<std::string, llama_adapter_lora_weight> ab_map;
+    std::unordered_map<std::string, struct llama_adapter_lora_weight> ab_map;

     std::vector<ggml_context_ptr> ctxs;
     std::vector<ggml_backend_buffer_ptr> bufs;
@@ -70,7 +70,5 @@ struct llama_adapter_lora {
     llama_adapter_lora() = default;
     ~llama_adapter_lora() = default;

-    llama_adapter_lora_weight * get_weight(ggml_tensor * w);
+    llama_adapter_lora_weight * get_weight(struct ggml_tensor * w);
 };
-
-using llama_adapter_loras = std::unordered_map<llama_adapter_lora *, float>;
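The `get_scale` helper kept by both sides follows the usual LoRA convention: the effective update is `scale * B*A`, with `scale` derived from `alpha`, the user scale, and the rank of the decomposition. A hypothetical helper showing where that scale would land in a ggml graph (`lora_delta` is illustrative, not part of this header):

```
#include "ggml.h"
#include "llama-adapter.h"

// sketch: compute the low-rank delta for one projection input x;
// the caller adds the result to the base projection's output
static ggml_tensor * lora_delta(ggml_context * ctx, const llama_adapter_lora_weight & w,
                                ggml_tensor * x, float alpha, float adapter_scale) {
    const float scale = w.get_scale(alpha, adapter_scale); // ~ adapter_scale * alpha / rank
    ggml_tensor * t = ggml_mul_mat(ctx, w.a, x);           // A * x
    t = ggml_mul_mat(ctx, w.b, t);                         // B * (A * x)
    return ggml_scale(ctx, t, scale);                      // scale * (B * A) * x
}
```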

View File

@@ -6,7 +6,6 @@
 static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_LLAMA, "llama" },
-    { LLM_ARCH_LLAMA4, "llama4" },
     { LLM_ARCH_DECI, "deci" },
     { LLM_ARCH_FALCON, "falcon" },
     { LLM_ARCH_GROK, "grok" },
@@ -19,8 +18,6 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_REFACT, "refact" },
     { LLM_ARCH_BERT, "bert" },
     { LLM_ARCH_NOMIC_BERT, "nomic-bert" },
-    { LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" },
-    { LLM_ARCH_NEO_BERT, "neo-bert" },
     { LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" },
     { LLM_ARCH_BLOOM, "bloom" },
     { LLM_ARCH_STABLELM, "stablelm" },
@@ -28,8 +25,6 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_QWEN2, "qwen2" },
     { LLM_ARCH_QWEN2MOE, "qwen2moe" },
     { LLM_ARCH_QWEN2VL, "qwen2vl" },
-    { LLM_ARCH_QWEN3, "qwen3" },
-    { LLM_ARCH_QWEN3MOE, "qwen3moe" },
     { LLM_ARCH_PHI2, "phi2" },
     { LLM_ARCH_PHI3, "phi3" },
     { LLM_ARCH_PHIMOE, "phimoe" },
@@ -41,8 +36,6 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_MINICPM3, "minicpm3" },
     { LLM_ARCH_GEMMA, "gemma" },
     { LLM_ARCH_GEMMA2, "gemma2" },
-    { LLM_ARCH_GEMMA3, "gemma3" },
-    { LLM_ARCH_GEMMA3N, "gemma3n" },
     { LLM_ARCH_STARCODER2, "starcoder2" },
     { LLM_ARCH_MAMBA, "mamba" },
     { LLM_ARCH_XVERSE, "xverse" },
@@ -57,7 +50,6 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_DEEPSEEK, "deepseek" },
     { LLM_ARCH_DEEPSEEK2, "deepseek2" },
     { LLM_ARCH_CHATGLM, "chatglm" },
-    { LLM_ARCH_GLM4, "glm4" },
     { LLM_ARCH_BITNET, "bitnet" },
     { LLM_ARCH_T5, "t5" },
     { LLM_ARCH_T5ENCODER, "t5encoder" },
@@ -66,17 +58,10 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_EXAONE, "exaone" },
     { LLM_ARCH_RWKV6, "rwkv6" },
     { LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" },
-    { LLM_ARCH_RWKV7, "rwkv7" },
-    { LLM_ARCH_ARWKV7, "arwkv7" },
     { LLM_ARCH_GRANITE, "granite" },
     { LLM_ARCH_GRANITE_MOE, "granitemoe" },
     { LLM_ARCH_CHAMELEON, "chameleon" },
     { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
-    { LLM_ARCH_PLM, "plm" },
-    { LLM_ARCH_BAILINGMOE, "bailingmoe" },
-    { LLM_ARCH_DOTS1, "dots1" },
-    { LLM_ARCH_ARCEE, "arcee" },
-    { LLM_ARCH_ERNIE4_5, "ernie4_5" },
     { LLM_ARCH_UNKNOWN, "(unknown)" },
 };
@@ -85,7 +70,6 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" },
     { LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" },
     { LLM_KV_GENERAL_ALIGNMENT, "general.alignment" },
-    { LLM_KV_GENERAL_FILE_TYPE, "general.file_type" },
     { LLM_KV_GENERAL_NAME, "general.name" },
     { LLM_KV_GENERAL_AUTHOR, "general.author" },
     { LLM_KV_GENERAL_VERSION, "general.version" },
@@ -112,7 +96,6 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" },
     { LLM_KV_EXPERT_WEIGHTS_NORM, "%s.expert_weights_norm" },
     { LLM_KV_EXPERT_GATING_FUNC, "%s.expert_gating_func" },
-    { LLM_KV_MOE_EVERY_N_LAYERS, "%s.moe_every_n_layers" },
     { LLM_KV_POOLING_TYPE, "%s.pooling_type" },
     { LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
     { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },
@@ -125,7 +108,6 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_RESIDUAL_SCALE, "%s.residual_scale" },
     { LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" },
     { LLM_KV_TOKEN_SHIFT_COUNT, "%s.token_shift_count" },
-    { LLM_KV_INTERLEAVE_MOE_LAYER_STEP, "%s.interleave_moe_layer_step" },

     { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
     { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
@@ -140,16 +122,9 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" },
     { LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" },
     { LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" },
-    { LLM_KV_ATTENTION_DECAY_LORA_RANK, "%s.attention.decay_lora_rank" },
-    { LLM_KV_ATTENTION_ICLR_LORA_RANK, "%s.attention.iclr_lora_rank" },
-    { LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK, "%s.attention.value_residual_mix_lora_rank" },
-    { LLM_KV_ATTENTION_GATE_LORA_RANK, "%s.attention.gate_lora_rank" },
     { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
     { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
     { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
-    { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
-    { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
-    { LLM_KV_ATTENTION_LAYER_INDICES, "%s.attention.layer_indices" },

     { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
     { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
@@ -180,8 +155,6 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_CONVNEXT_EMBEDDING_LENGTH, "%s.convnext.embedding_length" },
     { LLM_KV_CONVNEXT_BLOCK_COUNT, "%s.convnext.block_count" },

-    { LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" },
-
     { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
     { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" },
     { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" },
@@ -200,13 +173,13 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_TOKENIZER_MASK_ID, "tokenizer.ggml.mask_token_id" },
     { LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" },
     { LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" },
-    { LLM_KV_TOKENIZER_ADD_SEP, "tokenizer.ggml.add_sep_token" },
     { LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" },
     { LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, "tokenizer.ggml.remove_extra_whitespaces" },
     { LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" },
     { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" },
     { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" },
     { LLM_KV_TOKENIZER_CHAT_TEMPLATE, "tokenizer.chat_template" },
+    { LLM_KV_TOKENIZER_CHAT_TEMPLATE_N, "tokenizer.chat_template.%s" },
     { LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" },
     { LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" },
     { LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" },
@@ -250,53 +223,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
             { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
         },
     },
-    {
-        LLM_ARCH_ARCEE,
-        {
-            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
-            { LLM_TENSOR_OUTPUT, "output" },
-            { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
-            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
-            { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
-            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_LLAMA4,
-        {
-            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
-            { LLM_TENSOR_OUTPUT, "output" },
-            { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
-            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
-            { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
-            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
-            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
-            { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
-            { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
-            { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
-            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
-            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
-            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
-            { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
-            { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
-            { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
-        },
-    },
     {
         LLM_ARCH_DECI,
         {
@@ -474,7 +400,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
             { LLM_TENSOR_TOKEN_TYPES, "token_types" },
             { LLM_TENSOR_POS_EMBD, "position_embd" },
             { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
-            { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
             { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
             { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
             { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
@@ -501,39 +426,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
             { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
         },
     },
-    {
-        LLM_ARCH_NOMIC_BERT_MOE,
-        {
-            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
-            { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
-            { LLM_TENSOR_TOKEN_TYPES, "token_types" },
-            { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
-            { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
-            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
-            { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
-            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
-            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
-            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
-            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
-        },
-    },
-    {
-        LLM_ARCH_NEO_BERT,
-        {
-            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
-            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
-            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
-            { LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" },
-            { LLM_TENSOR_CLS, "cls" },
-            { LLM_TENSOR_CLS_OUT, "cls.output" },
-        },
-    },
     {
         LLM_ARCH_JINA_BERT_V2,
         {
@@ -662,45 +554,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
             { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
         },
     },
-    {
-        LLM_ARCH_QWEN3,
-        {
-            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
-            { LLM_TENSOR_OUTPUT, "output" },
-            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
-            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
-            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_QWEN3MOE,
-        {
-            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
-            { LLM_TENSOR_OUTPUT, "output" },
-            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
-            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
-            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
-            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
-            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
-            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
-        },
-    },
     {
         LLM_ARCH_PHI2,
         {
@@ -913,63 +766,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
             { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
         },
     },
-    {
-        LLM_ARCH_GEMMA3,
-        {
-            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
-            { LLM_TENSOR_OUTPUT, "output" },
-            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
-            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
-            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
-            { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
-            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
-            { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
-        },
-    },
-    {
-        LLM_ARCH_GEMMA3N,
-        {
-            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
-            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
-            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
-            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
-            { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
-            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
-            { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
-            { LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "per_layer_token_embd" },
-            { LLM_TENSOR_PER_LAYER_MODEL_PROJ, "per_layer_model_proj" },
-            { LLM_TENSOR_PER_LAYER_PROJ_NORM, "per_layer_proj_norm" },
-            { LLM_TENSOR_ALTUP_UNEMBD_PROJ, "altup_unembd_proj" },
-            { LLM_TENSOR_ALTUP_PROJ, "altup_proj" },
-            { LLM_TENSOR_PER_LAYER_INP_GATE, "blk.%d.inp_gate" },
-            { LLM_TENSOR_PER_LAYER_PROJ, "blk.%d.proj" },
-            { LLM_TENSOR_PER_LAYER_POST_NORM, "blk.%d.post_norm" },
-            { LLM_TENSOR_ALTUP_CORRECT_COEF, "blk.%d.altup_correct_coef" },
-            { LLM_TENSOR_ALTUP_CORRECT_SCALE, "blk.%d.altup_correct_scale" },
-            { LLM_TENSOR_ALTUP_PREDICT_COEF, "blk.%d.altup_predict_coef" },
-            { LLM_TENSOR_ALTUP_ROUTER, "blk.%d.altup_router" },
-            { LLM_TENSOR_ALTUP_ROUTER_NORM, "blk.%d.altup_router_norm" },
-            { LLM_TENSOR_LAUREL_L, "blk.%d.laurel_l" },
-            { LLM_TENSOR_LAUREL_R, "blk.%d.laurel_r" },
-            { LLM_TENSOR_LAUREL_POST_NORM, "blk.%d.laurel_post_norm" },
-        },
-    },
     {
         LLM_ARCH_STARCODER2,
         {
@@ -1203,8 +999,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
             { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" },
             { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
             { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
-            { LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" },
-            { LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" },
             { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
             { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
             { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
@@ -1221,22 +1015,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
             { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
         },
     },
-    {
-        LLM_ARCH_PLM,
-        {
-            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
-            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
-            { LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" },
-            { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
-            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
-        },
-    },
     {
         LLM_ARCH_CHATGLM,
         {
@@ -1255,25 +1033,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
             { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
         },
     },
-    {
-        LLM_ARCH_GLM4,
-        {
-            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
-            { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
-            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
-            { LLM_TENSOR_OUTPUT, "output" },
-            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
-            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
-            { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
-            { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
-        },
-    },
     {
         LLM_ARCH_BITNET,
         {
@@ -1458,74 +1217,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
             { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
         },
     },
-    {
-        LLM_ARCH_RWKV7,
-        {
-            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
-            { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
-            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
-            { LLM_TENSOR_OUTPUT, "output" },
-            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" },
-            { LLM_TENSOR_TIME_MIX_W0, "blk.%d.time_mix_w0" },
-            { LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" },
-            { LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" },
-            { LLM_TENSOR_TIME_MIX_A0, "blk.%d.time_mix_a0" },
-            { LLM_TENSOR_TIME_MIX_A1, "blk.%d.time_mix_a1" },
-            { LLM_TENSOR_TIME_MIX_A2, "blk.%d.time_mix_a2" },
-            { LLM_TENSOR_TIME_MIX_V0, "blk.%d.time_mix_v0" },
-            { LLM_TENSOR_TIME_MIX_V1, "blk.%d.time_mix_v1" },
-            { LLM_TENSOR_TIME_MIX_V2, "blk.%d.time_mix_v2" },
-            { LLM_TENSOR_TIME_MIX_G1, "blk.%d.time_mix_g1" },
-            { LLM_TENSOR_TIME_MIX_G2, "blk.%d.time_mix_g2" },
-            { LLM_TENSOR_TIME_MIX_K_K, "blk.%d.time_mix_k_k" },
-            { LLM_TENSOR_TIME_MIX_K_A, "blk.%d.time_mix_k_a" },
-            { LLM_TENSOR_TIME_MIX_R_K, "blk.%d.time_mix_r_k" },
-            { LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" },
-            { LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" },
-            { LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" },
-            { LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" },
-            { LLM_TENSOR_TIME_MIX_LN, "blk.%d.time_mix_ln" },
-            { LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" },
-            { LLM_TENSOR_CHANNEL_MIX_LERP_K, "blk.%d.channel_mix_lerp_k" },
-            { LLM_TENSOR_CHANNEL_MIX_KEY, "blk.%d.channel_mix_key" },
-            { LLM_TENSOR_CHANNEL_MIX_VALUE, "blk.%d.channel_mix_value" },
-        },
-    },
-    {
-        LLM_ARCH_ARWKV7,
-        {
-            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
-            { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
-            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
-            { LLM_TENSOR_OUTPUT, "output" },
-            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
-            { LLM_TENSOR_TIME_MIX_W0, "blk.%d.time_mix_w0" },
-            { LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" },
-            { LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" },
-            { LLM_TENSOR_TIME_MIX_A0, "blk.%d.time_mix_a0" },
-            { LLM_TENSOR_TIME_MIX_A1, "blk.%d.time_mix_a1" },
-            { LLM_TENSOR_TIME_MIX_A2, "blk.%d.time_mix_a2" },
-            { LLM_TENSOR_TIME_MIX_V0, "blk.%d.time_mix_v0" },
-            { LLM_TENSOR_TIME_MIX_V1, "blk.%d.time_mix_v1" },
-            { LLM_TENSOR_TIME_MIX_V2, "blk.%d.time_mix_v2" },
-            { LLM_TENSOR_TIME_MIX_G1, "blk.%d.time_mix_g1" },
-            { LLM_TENSOR_TIME_MIX_G2, "blk.%d.time_mix_g2" },
-            { LLM_TENSOR_TIME_MIX_K_K, "blk.%d.time_mix_k_k" },
-            { LLM_TENSOR_TIME_MIX_K_A, "blk.%d.time_mix_k_a" },
-            { LLM_TENSOR_TIME_MIX_R_K, "blk.%d.time_mix_r_k" },
-            { LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" },
-            { LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" },
-            { LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" },
-            { LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" },
-            { LLM_TENSOR_TIME_MIX_LN, "blk.%d.time_mix_ln" },
-            { LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" },
-            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
-        },
-    },
     {
         LLM_ARCH_GRANITE,
         {
@@ -1559,9 +1250,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
             { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
             { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
             { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
-            { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
-            { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
-            { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
         },
     },
     {
@@ -1608,74 +1296,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
             { LLM_TENSOR_POS_NET_ATTN_OUT, "posnet.%d.attn_output" },
         },
     },
-    {
-        LLM_ARCH_BAILINGMOE,
-        {
-            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
-            { LLM_TENSOR_OUTPUT, "output" },
-            { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
-            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
-            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
-            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
-            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
-            { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
-            { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
-            { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
-            { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
-        },
-    },
-    {
-        LLM_ARCH_DOTS1,
-        {
-            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
-            { LLM_TENSOR_OUTPUT, "output" },
-            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
-            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
-            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
-            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
-            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
-            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
-            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
-            { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
-            { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
-            { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
-            { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
-            { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
-        }
-    },
-    {
-        LLM_ARCH_ERNIE4_5,
-        {
-            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
-            { LLM_TENSOR_OUTPUT, "output" },
-            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
-        },
-    },
     {
         LLM_ARCH_UNKNOWN,
         {
@@ -1713,8 +1333,23 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_K_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_V_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_FFN_DOWN_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_FFN_GATE_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_FFN_UP_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_Q_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_DEC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
@@ -1741,12 +1376,6 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_SSM_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_TIME_MIX_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
    {LLM_TENSOR_TIME_MIX_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_TIME_MIX_A1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_TIME_MIX_A2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_TIME_MIX_V1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_TIME_MIX_V2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_TIME_MIX_G1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_TIME_MIX_G2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_TIME_MIX_DECAY_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_TIME_MIX_DECAY_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_TIME_MIX_KEY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
@@ -1765,9 +1394,6 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_CHANNEL_MIX_LERP_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
-    {LLM_TENSOR_TIME_MIX_K_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
-    {LLM_TENSOR_TIME_MIX_K_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
-    {LLM_TENSOR_TIME_MIX_R_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_TIME_MIX_LERP_W, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
     {LLM_TENSOR_TIME_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
     {LLM_TENSOR_TIME_MIX_LERP_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
@@ -1775,9 +1401,6 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_TIME_MIX_LERP_G, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
     {LLM_TENSOR_TIME_MIX_LERP_FUSED, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
     {LLM_TENSOR_TIME_MIX_DECAY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
-    {LLM_TENSOR_TIME_MIX_W0, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
-    {LLM_TENSOR_TIME_MIX_A0, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
-    {LLM_TENSOR_TIME_MIX_V0, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
     {LLM_TENSOR_TIME_MIX_FIRST, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_RWKV_WKV6}},
     {LLM_TENSOR_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_ATTN_NORM_2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
@@ -1804,23 +1427,6 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_FFN_GATE_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
     {LLM_TENSOR_FFN_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
     {LLM_TENSOR_FFN_EXP_PROBS_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
-    // altup / laurel (gemma 3n)
-    {LLM_TENSOR_PER_LAYER_TOKEN_EMBD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
-    {LLM_TENSOR_PER_LAYER_MODEL_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_PER_LAYER_PROJ_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
-    {LLM_TENSOR_ALTUP_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ALTUP_UNEMBD_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_PER_LAYER_INP_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_PER_LAYER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_PER_LAYER_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
-    {LLM_TENSOR_ALTUP_CORRECT_COEF, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ALTUP_CORRECT_SCALE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
-    {LLM_TENSOR_ALTUP_PREDICT_COEF, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ALTUP_ROUTER, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ALTUP_ROUTER_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
-    {LLM_TENSOR_LAUREL_L, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_LAUREL_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_LAUREL_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     // this tensor is loaded for T5, but never used
     {LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
     {LLM_TENSOR_CONV1D, {LLM_TENSOR_LAYER_INPUT, GGML_OP_IM2COL}},
@ -1844,14 +1450,8 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}
std::string LLM_KV::operator()(llm_kv kv) const { std::string LLM_KV::operator()(llm_kv kv) const {
std::string name = ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch)); return suffix ? ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch), suffix)
: ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
if (suffix != nullptr) {
name += ".";
name += suffix;
}
return name;
} }
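Note on the hunk above: the removed implementation builds the key name and then appends ".<suffix>" by hand, while the restored ternary passes the suffix as an extra format argument, which only works if the suffixed key pattern carries its own `%s` slot. A minimal standalone sketch of the two strategies; the `format()` helper here is a stand-in for the project's `::format()`, and the key patterns are hypothetical:

```cpp
// Standalone sketch (not the project's code) of the two suffix strategies
// seen in the LLM_KV::operator() hunk above.
#include <cstdio>
#include <string>

// Hypothetical stand-in for the project's ::format() helper.
template <typename... Args>
static std::string format(const char * fmt, Args... args) {
    char buf[256];
    std::snprintf(buf, sizeof(buf), fmt, args...);
    return std::string(buf);
}

int main() {
    const char * arch   = "llama";
    const char * suffix = "draft";

    // Removed variant: format the base key, then append ".<suffix>" manually.
    std::string appended = format("%s.example.key", arch);
    if (suffix != nullptr) {
        appended += ".";
        appended += suffix;
    }

    // Restored variant: the suffixed key pattern carries its own "%s" slot,
    // so the suffix goes straight into the format call.
    std::string slotted = suffix
        ? format("%s.example.key.%s", arch, suffix)
        : format("%s.example.key", arch);

    // Both print "llama.example.key.draft".
    std::printf("appended: %s\nslotted:  %s\n", appended.c_str(), slotted.c_str());
    return 0;
}
```

The practical difference is where the suffix knowledge lives: in the caller (append) or in the name table (format slot). The `LLM_KV_TOKENIZER_CHAT_TEMPLATE_N` change in llama-arch.h further down appears to be the same trade-off surfacing in the key enum, since a format-slot scheme needs a separate enum entry and pattern for the suffixed variant.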
 std::string LLM_TN_IMPL::str() const {
@@ -1890,25 +1490,3 @@ llm_arch llm_arch_from_string(const std::string & name) {
 const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) {
     return LLM_TENSOR_INFOS.at(tensor);
 }
-
-bool llm_arch_is_recurrent(const llm_arch & arch) {
-    switch (arch) {
-        case LLM_ARCH_MAMBA:
-        case LLM_ARCH_RWKV6:
-        case LLM_ARCH_RWKV6QWEN2:
-        case LLM_ARCH_RWKV7:
-        case LLM_ARCH_ARWKV7:
-            return true;
-        default:
-            return false;
-    }
-}
-
-bool llm_arch_is_hybrid(const llm_arch & arch) {
-    // TODO: There are currently no hybrid models! Once there are, this will be
-    // the place to identify them
-    switch (arch) {
-        default:
-            return false;
-    }
-}
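The two predicates deleted above classify architectures by how they hold sequence state: recurrent families (Mamba and the RWKV variants) keep fixed-size per-layer state rather than a growing attention KV cache, and `llm_arch_is_hybrid()` was a placeholder for architectures mixing both. A hedged sketch of the kind of dispatch such a predicate enables; apart from the shape of `llm_arch_is_recurrent()`, every name here (`memory_i`, `kv_cache`, `state_cache`, `make_memory`) is hypothetical:

```cpp
// Hypothetical caller of the removed predicate: everything here except the
// body of llm_arch_is_recurrent() is invented for illustration.
#include <memory>

enum llm_arch { LLM_ARCH_LLAMA, LLM_ARCH_MAMBA, LLM_ARCH_RWKV6 };

// Same shape as the deleted helper, trimmed to the enum values defined here.
static bool llm_arch_is_recurrent(const llm_arch & arch) {
    switch (arch) {
        case LLM_ARCH_MAMBA:
        case LLM_ARCH_RWKV6:
            return true;
        default:
            return false;
    }
}

struct memory_i    { virtual ~memory_i() = default; };
struct kv_cache    : memory_i {}; // grows with context (attention models)
struct state_cache : memory_i {}; // fixed-size recurrent state (Mamba/RWKV)

// Dispatch on the architecture class rather than on individual arch values.
static std::unique_ptr<memory_i> make_memory(llm_arch arch) {
    if (llm_arch_is_recurrent(arch)) {
        return std::make_unique<state_cache>();
    }
    return std::make_unique<kv_cache>();
}

int main() {
    auto mem = make_memory(LLM_ARCH_MAMBA); // picks state_cache
    (void) mem;
    return 0;
}
```

Centralizing the classification in one predicate keeps call sites from growing their own per-architecture switch statements.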

View File

@@ -10,7 +10,6 @@
 enum llm_arch {
     LLM_ARCH_LLAMA,
-    LLM_ARCH_LLAMA4,
     LLM_ARCH_DECI,
     LLM_ARCH_FALCON,
     LLM_ARCH_BAICHUAN,
@@ -23,8 +22,6 @@ enum llm_arch {
     LLM_ARCH_REFACT,
     LLM_ARCH_BERT,
     LLM_ARCH_NOMIC_BERT,
-    LLM_ARCH_NOMIC_BERT_MOE,
-    LLM_ARCH_NEO_BERT,
     LLM_ARCH_JINA_BERT_V2,
     LLM_ARCH_BLOOM,
     LLM_ARCH_STABLELM,
@@ -32,8 +29,6 @@ enum llm_arch {
     LLM_ARCH_QWEN2,
     LLM_ARCH_QWEN2MOE,
     LLM_ARCH_QWEN2VL,
-    LLM_ARCH_QWEN3,
-    LLM_ARCH_QWEN3MOE,
     LLM_ARCH_PHI2,
     LLM_ARCH_PHI3,
     LLM_ARCH_PHIMOE,
@@ -45,8 +40,6 @@ enum llm_arch {
     LLM_ARCH_MINICPM3,
     LLM_ARCH_GEMMA,
     LLM_ARCH_GEMMA2,
-    LLM_ARCH_GEMMA3,
-    LLM_ARCH_GEMMA3N,
     LLM_ARCH_STARCODER2,
     LLM_ARCH_MAMBA,
     LLM_ARCH_XVERSE,
@@ -61,7 +54,6 @@ enum llm_arch {
     LLM_ARCH_DEEPSEEK,
     LLM_ARCH_DEEPSEEK2,
     LLM_ARCH_CHATGLM,
-    LLM_ARCH_GLM4,
     LLM_ARCH_BITNET,
     LLM_ARCH_T5,
     LLM_ARCH_T5ENCODER,
@@ -70,17 +62,10 @@ enum llm_arch {
     LLM_ARCH_EXAONE,
     LLM_ARCH_RWKV6,
     LLM_ARCH_RWKV6QWEN2,
-    LLM_ARCH_RWKV7,
-    LLM_ARCH_ARWKV7,
     LLM_ARCH_GRANITE,
     LLM_ARCH_GRANITE_MOE,
     LLM_ARCH_CHAMELEON,
     LLM_ARCH_WAVTOKENIZER_DEC,
-    LLM_ARCH_PLM,
-    LLM_ARCH_BAILINGMOE,
-    LLM_ARCH_DOTS1,
-    LLM_ARCH_ARCEE,
-    LLM_ARCH_ERNIE4_5,
     LLM_ARCH_UNKNOWN,
 };
@@ -89,7 +74,6 @@ enum llm_kv {
     LLM_KV_GENERAL_ARCHITECTURE,
     LLM_KV_GENERAL_QUANTIZATION_VERSION,
     LLM_KV_GENERAL_ALIGNMENT,
-    LLM_KV_GENERAL_FILE_TYPE,
     LLM_KV_GENERAL_NAME,
     LLM_KV_GENERAL_AUTHOR,
     LLM_KV_GENERAL_VERSION,
@@ -116,7 +100,6 @@ enum llm_kv {
     LLM_KV_EXPERT_WEIGHTS_SCALE,
     LLM_KV_EXPERT_WEIGHTS_NORM,
     LLM_KV_EXPERT_GATING_FUNC,
-    LLM_KV_MOE_EVERY_N_LAYERS,
     LLM_KV_POOLING_TYPE,
     LLM_KV_LOGIT_SCALE,
     LLM_KV_DECODER_START_TOKEN_ID,
@@ -129,7 +112,6 @@ enum llm_kv {
     LLM_KV_RESIDUAL_SCALE,
     LLM_KV_EMBEDDING_SCALE,
     LLM_KV_TOKEN_SHIFT_COUNT,
-    LLM_KV_INTERLEAVE_MOE_LAYER_STEP,
 
     LLM_KV_ATTENTION_HEAD_COUNT,
     LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -144,16 +126,9 @@ enum llm_kv {
     LLM_KV_ATTENTION_CAUSAL,
     LLM_KV_ATTENTION_Q_LORA_RANK,
     LLM_KV_ATTENTION_KV_LORA_RANK,
-    LLM_KV_ATTENTION_DECAY_LORA_RANK,
-    LLM_KV_ATTENTION_ICLR_LORA_RANK,
-    LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK,
-    LLM_KV_ATTENTION_GATE_LORA_RANK,
     LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
     LLM_KV_ATTENTION_SLIDING_WINDOW,
     LLM_KV_ATTENTION_SCALE,
-    LLM_KV_ATTENTION_KEY_LENGTH_MLA,
-    LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
-    LLM_KV_ATTENTION_LAYER_INDICES,
 
     LLM_KV_ROPE_DIMENSION_COUNT,
     LLM_KV_ROPE_DIMENSION_SECTIONS,
@@ -196,13 +171,13 @@ enum llm_kv {
     LLM_KV_TOKENIZER_MASK_ID,
     LLM_KV_TOKENIZER_ADD_BOS,
     LLM_KV_TOKENIZER_ADD_EOS,
-    LLM_KV_TOKENIZER_ADD_SEP,
     LLM_KV_TOKENIZER_ADD_PREFIX,
     LLM_KV_TOKENIZER_REMOVE_EXTRA_WS,
     LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP,
     LLM_KV_TOKENIZER_HF_JSON,
     LLM_KV_TOKENIZER_RWKV,
     LLM_KV_TOKENIZER_CHAT_TEMPLATE,
+    LLM_KV_TOKENIZER_CHAT_TEMPLATE_N,
     LLM_KV_TOKENIZER_FIM_PRE_ID,
     LLM_KV_TOKENIZER_FIM_SUF_ID,
     LLM_KV_TOKENIZER_FIM_MID_ID,
@@ -219,8 +194,6 @@ enum llm_kv {
     LLM_KV_CONVNEXT_EMBEDDING_LENGTH,
     LLM_KV_CONVNEXT_BLOCK_COUNT,
 
-    LLM_KV_CLASSIFIER_OUTPUT_LABELS,
-
     // deprecated:
     LLM_KV_TOKENIZER_PREFIX_ID,
     LLM_KV_TOKENIZER_SUFFIX_ID,
@@ -269,24 +242,6 @@ enum llm_tensor {
     LLM_TENSOR_ATTN_Q_NORM,
     LLM_TENSOR_ATTN_K_NORM,
     LLM_TENSOR_LAYER_OUT_NORM,
-    LLM_TENSOR_POST_ATTN_NORM,
-    LLM_TENSOR_POST_MLP_NORM,
-    LLM_TENSOR_PER_LAYER_TOKEN_EMBD, // gemma3n
-    LLM_TENSOR_PER_LAYER_MODEL_PROJ, // gemma3n
-    LLM_TENSOR_PER_LAYER_INP_GATE, // gemma3n
-    LLM_TENSOR_PER_LAYER_PROJ, // gemma3n
-    LLM_TENSOR_PER_LAYER_PROJ_NORM, // gemma3n
-    LLM_TENSOR_PER_LAYER_POST_NORM, // gemma3n
-    LLM_TENSOR_ALTUP_PROJ, // gemma3n
-    LLM_TENSOR_ALTUP_UNEMBD_PROJ, // gemma3n
-    LLM_TENSOR_ALTUP_CORRECT_COEF, // gemma3n
-    LLM_TENSOR_ALTUP_CORRECT_SCALE, // gemma3n
-    LLM_TENSOR_ALTUP_PREDICT_COEF, // gemma3n
-    LLM_TENSOR_ALTUP_ROUTER, // gemma3n
-    LLM_TENSOR_ALTUP_ROUTER_NORM, // gemma3n
-    LLM_TENSOR_LAUREL_L, // gemma3n
-    LLM_TENSOR_LAUREL_R, // gemma3n
-    LLM_TENSOR_LAUREL_POST_NORM, // gemma3n
     LLM_TENSOR_SSM_IN,
     LLM_TENSOR_SSM_CONV1D,
     LLM_TENSOR_SSM_X,
@@ -294,20 +249,8 @@ enum llm_tensor {
     LLM_TENSOR_SSM_A,
     LLM_TENSOR_SSM_D,
     LLM_TENSOR_SSM_OUT,
-    LLM_TENSOR_TIME_MIX_W0,
     LLM_TENSOR_TIME_MIX_W1,
     LLM_TENSOR_TIME_MIX_W2,
-    LLM_TENSOR_TIME_MIX_A0,
-    LLM_TENSOR_TIME_MIX_A1,
-    LLM_TENSOR_TIME_MIX_A2,
-    LLM_TENSOR_TIME_MIX_V0,
-    LLM_TENSOR_TIME_MIX_V1,
-    LLM_TENSOR_TIME_MIX_V2,
-    LLM_TENSOR_TIME_MIX_G1,
-    LLM_TENSOR_TIME_MIX_G2,
-    LLM_TENSOR_TIME_MIX_K_K,
-    LLM_TENSOR_TIME_MIX_K_A,
-    LLM_TENSOR_TIME_MIX_R_K,
     LLM_TENSOR_TIME_MIX_LERP_X,
     LLM_TENSOR_TIME_MIX_LERP_W,
     LLM_TENSOR_TIME_MIX_LERP_K,
@@ -334,8 +277,6 @@ enum llm_tensor {
     LLM_TENSOR_ATTN_Q_B,
     LLM_TENSOR_ATTN_KV_A_MQA,
     LLM_TENSOR_ATTN_KV_B,
-    LLM_TENSOR_ATTN_K_B,
-    LLM_TENSOR_ATTN_V_B,
     LLM_TENSOR_ATTN_Q_A_NORM,
     LLM_TENSOR_ATTN_KV_A_NORM,
     LLM_TENSOR_ATTN_SUB_NORM,
@@ -459,6 +400,3 @@ const char * llm_arch_name(llm_arch arch);
 llm_arch llm_arch_from_string(const std::string & name);
 
 const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor);
-
-bool llm_arch_is_recurrent(const llm_arch & arch);
-bool llm_arch_is_hybrid  (const llm_arch & arch);

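Context for the declarations kept above: `llm_tensor_info_for()` is a thin lookup into the `LLM_TENSOR_INFOS` table shown in llama-arch.cpp earlier in this diff, which maps each tensor id to the layer class it belongs to and the ggml op that consumes it. A self-contained sketch of the pattern; the `llm_tensor_info` field names and the trimmed enums are assumptions for illustration:

```cpp
// Sketch of the table-driven tensor-info lookup; field names are assumptions.
#include <cstdio>
#include <map>

enum llm_tensor { LLM_TENSOR_ATTN_NORM, LLM_TENSOR_FFN_UP_EXPS };
enum llm_tensor_layer { LLM_TENSOR_LAYER_INPUT, LLM_TENSOR_LAYER_REPEATING, LLM_TENSOR_LAYER_OUTPUT };
enum ggml_op { GGML_OP_MUL, GGML_OP_MUL_MAT_ID };

struct llm_tensor_info {
    llm_tensor_layer layer; // which part of the model the tensor belongs to
    ggml_op          op;    // the op that consumes the tensor
};

// Mirrors the LLM_TENSOR_INFOS map in the diff above, trimmed to two entries.
static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
    {LLM_TENSOR_ATTN_NORM,   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
    {LLM_TENSOR_FFN_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
};

// Same contract as the function kept in the diff: an unknown tensor id makes
// std::map::at() throw std::out_of_range.
static const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) {
    return LLM_TENSOR_INFOS.at(tensor);
}

int main() {
    const auto & info = llm_tensor_info_for(LLM_TENSOR_ATTN_NORM);
    std::printf("layer=%d op=%d\n", (int) info.layer, (int) info.op);
    return 0;
}
```

Because `std::map::at()` throws for a missing key, an unregistered tensor id fails loudly at load time instead of silently misclassifying the tensor.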
File diff suppressed because it is too large Load Diff

Some files were not shown because too many files have changed in this diff.