mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-07-08 16:46:55 +02:00
Compare commits
10 Commits
master
...
sync-ggml-
Author | SHA1 | Date | |
---|---|---|---|
0055356fbc | |||
eeaa1cd035 | |||
a652c8bf72 | |||
0630539c8a | |||
a7988d76db | |||
37ac0264ef | |||
5a9ccde7da | |||
cde0e50536 | |||
df458380d6 | |||
87b88ed01c |
@ -16,7 +16,6 @@ ENV CUDA_DOCKER_ARCH=${CUDA_DOCKER_ARCH}
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y build-essential libsdl2-dev wget cmake git \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
|
||||
|
||||
# Ref: https://stackoverflow.com/a/53464012
|
||||
@ -27,12 +26,6 @@ COPY .. .
|
||||
# Enable cuBLAS
|
||||
RUN make base.en CMAKE_ARGS="-DGGML_CUDA=1"
|
||||
|
||||
RUN find /app/build -name "*.o" -delete && \
|
||||
find /app/build -name "*.a" -delete && \
|
||||
rm -rf /app/build/CMakeFiles && \
|
||||
rm -rf /app/build/cmake_install.cmake && \
|
||||
rm -rf /app/build/_deps
|
||||
|
||||
FROM ${BASE_CUDA_RUN_CONTAINER} AS runtime
|
||||
ENV CUDA_MAIN_VERSION=12.3
|
||||
ENV LD_LIBRARY_PATH /usr/local/cuda-${CUDA_MAIN_VERSION}/compat:$LD_LIBRARY_PATH
|
||||
@ -40,11 +33,8 @@ WORKDIR /app
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y curl ffmpeg wget cmake git \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
|
||||
|
||||
COPY --from=build /app /app
|
||||
RUN du -sh /app/*
|
||||
RUN find /app -type f -size +100M
|
||||
ENV PATH=/app/build/bin:$PATH
|
||||
ENTRYPOINT [ "bash", "-c" ]
|
||||
|
@ -1,28 +0,0 @@
|
||||
ARG ONEAPI_VERSION=2025.1.1-0-devel-ubuntu24.04
|
||||
|
||||
FROM intel/oneapi-basekit:$ONEAPI_VERSION AS build
|
||||
WORKDIR /app
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y build-essential libsdl2-dev wget cmake git \
|
||||
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
|
||||
|
||||
COPY .. .
|
||||
# Enable SYCL
|
||||
ARG GGML_SYCL_F16=OFF
|
||||
RUN if [ "${GGML_SYCL_F16}" = "ON" ]; then \
|
||||
echo "GGML_SYCL_F16 is set" \
|
||||
&& export OPT_SYCL_F16="-DGGML_SYCL_F16=ON"; \
|
||||
fi && \
|
||||
make base.en CMAKE_ARGS="-DGGML_SYCL=1 -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx ${OPT_SYCL_F16}"
|
||||
|
||||
FROM intel/oneapi-basekit:$ONEAPI_VERSION AS runtime
|
||||
WORKDIR /app
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y curl ffmpeg libsdl2-dev wget cmake git \
|
||||
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
|
||||
|
||||
COPY --from=build /app /app
|
||||
ENV PATH=/app/build/bin:$PATH
|
||||
ENTRYPOINT [ "bash", "-c" ]
|
@ -1,40 +1,29 @@
|
||||
ARG UBUNTU_VERSION=22.04
|
||||
# This needs to generally match the container host's environment.
|
||||
ARG MUSA_VERSION=rc4.0.1
|
||||
ARG MUSA_VERSION=rc3.1.1
|
||||
# Target the MUSA build image
|
||||
ARG BASE_MUSA_DEV_CONTAINER=mthreads/musa:${MUSA_VERSION}-mudnn-devel-ubuntu${UBUNTU_VERSION}
|
||||
ARG BASE_MUSA_DEV_CONTAINER=mthreads/musa:${MUSA_VERSION}-devel-ubuntu${UBUNTU_VERSION}
|
||||
# Target the MUSA runtime image
|
||||
ARG BASE_MUSA_RUN_CONTAINER=mthreads/musa:${MUSA_VERSION}-mudnn-runtime-ubuntu${UBUNTU_VERSION}
|
||||
ARG BASE_MUSA_RUN_CONTAINER=mthreads/musa:${MUSA_VERSION}-runtime-ubuntu${UBUNTU_VERSION}
|
||||
|
||||
FROM ${BASE_MUSA_DEV_CONTAINER} AS build
|
||||
WORKDIR /app
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y build-essential libsdl2-dev wget cmake git && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* /tmp/* /var/tmp/*
|
||||
apt-get install -y build-essential libsdl2-dev wget cmake git \
|
||||
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
|
||||
|
||||
COPY .. .
|
||||
# Enable muBLAS
|
||||
RUN make base.en CMAKE_ARGS="-DGGML_MUSA=1"
|
||||
|
||||
RUN find /app/build -name "*.o" -delete && \
|
||||
find /app/build -name "*.a" -delete && \
|
||||
rm -rf /app/build/CMakeFiles && \
|
||||
rm -rf /app/build/cmake_install.cmake && \
|
||||
rm -rf /app/build/_deps
|
||||
|
||||
FROM ${BASE_MUSA_RUN_CONTAINER} AS runtime
|
||||
WORKDIR /app
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y curl ffmpeg wget cmake git && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* /tmp/* /var/tmp/*
|
||||
|
||||
COPY --from=build /app/build/bin /app/build/bin
|
||||
COPY --from=build /app/samples /app/samples
|
||||
COPY --from=build /app/models /app/models
|
||||
apt-get install -y curl ffmpeg wget cmake git \
|
||||
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
|
||||
|
||||
COPY --from=build /app /app
|
||||
ENV PATH=/app/build/bin:$PATH
|
||||
ENTRYPOINT [ "bash", "-c" ]
|
||||
|
204
.github/workflows/build.yml
vendored
204
.github/workflows/build.yml
vendored
@ -4,8 +4,6 @@ on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
tags:
|
||||
- 'v*'
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened]
|
||||
workflow_dispatch:
|
||||
@ -43,7 +41,6 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
tag_name: ${{ steps.tag.outputs.name }}
|
||||
should_release: ${{ steps.tag.outputs.should_release }}
|
||||
|
||||
steps:
|
||||
- name: Checkout with full history
|
||||
@ -58,7 +55,6 @@ jobs:
|
||||
BUILD_NUMBER=$(git rev-list --count HEAD)
|
||||
SHORT_HASH=$(git rev-parse --short=7 HEAD)
|
||||
CUSTOM_TAG="${{ github.event.inputs.pre_release_tag }}"
|
||||
SHOULD_RELEASE="false"
|
||||
|
||||
echo "Raw values:"
|
||||
echo "BUILD_NUMBER: $BUILD_NUMBER"
|
||||
@ -66,34 +62,21 @@ jobs:
|
||||
echo "BRANCH_NAME: ${{ env.BRANCH_NAME }}"
|
||||
echo "CUSTOM_TAG: $CUSTOM_TAG"
|
||||
|
||||
if [[ "${{ github.ref_type }}" == "tag" ]]; then
|
||||
echo "Using pushed tag name"
|
||||
TAG_NAME="${{ github.ref_name }}"
|
||||
SHOULD_RELEASE="true"
|
||||
elif [[ -n "$CUSTOM_TAG" ]]; then
|
||||
# Use custom tag if provided
|
||||
if [[ -n "$CUSTOM_TAG" ]]; then
|
||||
echo "Using custom tag"
|
||||
TAG_NAME="${CUSTOM_TAG}"
|
||||
SHOULD_RELEASE="true"
|
||||
elif [[ "${{ github.event.inputs.create_release }}" == "true" ]]; then
|
||||
echo "Manual release requested"
|
||||
SHOULD_RELEASE="true"
|
||||
TAG_NAME="b${BUILD_NUMBER}"
|
||||
elif [[ "${{ env.BRANCH_NAME }}" == "master" ]]; then
|
||||
echo "Using master branch format"
|
||||
TAG_NAME="b${BUILD_NUMBER}"
|
||||
SHOULD_RELEASE="false"
|
||||
else
|
||||
echo "Using non-master branch format"
|
||||
SAFE_NAME=$(echo "${{ env.BRANCH_NAME }}" | tr '/' '-')
|
||||
TAG_NAME="${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}"
|
||||
SHOULD_RELEASE="false"
|
||||
fi
|
||||
|
||||
echo "Final tag name: $TAG_NAME"
|
||||
echo "Should release: $SHOULD_RELEASE"
|
||||
echo "name=$TAG_NAME" >> $GITHUB_OUTPUT
|
||||
echo "should_release=$SHOULD_RELEASE" >> $GITHUB_OUTPUT
|
||||
|
||||
|
||||
ubuntu-22:
|
||||
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
|
||||
@ -118,10 +101,6 @@ jobs:
|
||||
-v ${{ github.workspace }}:/workspace \
|
||||
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
|
||||
set -e
|
||||
export DEBIAN_FRONTEND=noninteractive
|
||||
sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
|
||||
sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
|
||||
|
||||
apt update
|
||||
apt install -y build-essential libsdl2-dev cmake git
|
||||
cmake -B build
|
||||
@ -150,14 +129,6 @@ jobs:
|
||||
-v ${{ github.workspace }}:/workspace \
|
||||
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
|
||||
set -e
|
||||
export DEBIAN_FRONTEND=noninteractive
|
||||
sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
|
||||
sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
|
||||
|
||||
apt-get update
|
||||
apt-get install -y ca-certificates
|
||||
sed -i "s|http://ports.ubuntu.com|https://mirror.kumi.systems|g" /etc/apt/sources.list
|
||||
|
||||
apt update
|
||||
apt install -y build-essential libsdl2-dev cmake git
|
||||
cmake -B build -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8-a
|
||||
@ -186,14 +157,6 @@ jobs:
|
||||
-v ${{ github.workspace }}:/workspace \
|
||||
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
|
||||
set -e
|
||||
export DEBIAN_FRONTEND=noninteractive
|
||||
sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
|
||||
sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
|
||||
|
||||
apt-get update
|
||||
apt-get install -y ca-certificates
|
||||
sed -i "s|http://ports.ubuntu.com|https://mirror.kumi.systems|g" /etc/apt/sources.list
|
||||
|
||||
apt update
|
||||
apt install -y build-essential libsdl2-dev cmake git
|
||||
cmake -B build -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv7-a+fp
|
||||
@ -279,10 +242,6 @@ jobs:
|
||||
-v ${{ github.workspace }}:/workspace \
|
||||
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
|
||||
set -e
|
||||
export DEBIAN_FRONTEND=noninteractive
|
||||
sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
|
||||
sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
|
||||
|
||||
apt update
|
||||
apt install -y build-essential cmake libsdl2-dev git
|
||||
cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }}
|
||||
@ -313,14 +272,6 @@ jobs:
|
||||
-v ${{ github.workspace }}:/workspace \
|
||||
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
|
||||
set -e
|
||||
export DEBIAN_FRONTEND=noninteractive
|
||||
sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
|
||||
sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
|
||||
|
||||
apt-get update
|
||||
apt-get install -y ca-certificates
|
||||
sed -i "s|http://ports.ubuntu.com|https://mirror.kumi.systems|g" /etc/apt/sources.list
|
||||
|
||||
apt update
|
||||
apt install -y build-essential cmake libsdl2-dev git
|
||||
cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8-a
|
||||
@ -351,14 +302,6 @@ jobs:
|
||||
-v ${{ github.workspace }}:/workspace \
|
||||
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
|
||||
set -e
|
||||
export DEBIAN_FRONTEND=noninteractive
|
||||
sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
|
||||
sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
|
||||
|
||||
apt-get update
|
||||
apt-get install -y ca-certificates
|
||||
sed -i "s|http://ports.ubuntu.com|https://mirror.kumi.systems|g" /etc/apt/sources.list
|
||||
|
||||
apt update
|
||||
apt install -y build-essential cmake libsdl2-dev git
|
||||
cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv7-a+fp
|
||||
@ -392,14 +335,6 @@ jobs:
|
||||
-v ${{ github.workspace }}:/workspace \
|
||||
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
|
||||
set -e
|
||||
export DEBIAN_FRONTEND=noninteractive
|
||||
sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
|
||||
sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
|
||||
|
||||
apt-get update
|
||||
apt-get install -y ca-certificates
|
||||
sed -i "s|http://ports.ubuntu.com|https://mirror.kumi.systems|g" /etc/apt/sources.list
|
||||
|
||||
apt update
|
||||
apt install -y clang build-essential cmake libsdl2-dev git
|
||||
cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_COMPILER=clang
|
||||
@ -430,10 +365,6 @@ jobs:
|
||||
-v ${{ github.workspace }}:/workspace \
|
||||
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
|
||||
set -e
|
||||
export DEBIAN_FRONTEND=noninteractive
|
||||
sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
|
||||
sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
|
||||
|
||||
apt update
|
||||
apt install -y build-essential cmake git
|
||||
cmake . -DCMAKE_BUILD_TYPE=Debug \
|
||||
@ -596,7 +527,6 @@ jobs:
|
||||
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
|
||||
github.event.inputs.run_type == 'full-ci' }}
|
||||
runs-on: windows-latest
|
||||
needs: determine-tag
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
@ -674,17 +604,12 @@ jobs:
|
||||
name: ggml_cpu_${{ matrix.arch }}.dll
|
||||
path: build/bin/${{ matrix.build }}/ggml-cpu.dll
|
||||
|
||||
- name: Pack bin artifacts
|
||||
shell: pwsh
|
||||
run: |
|
||||
Compress-Archive -Path "build/bin/${{ matrix.build }}" -DestinationPath "whisper-bin-${{ matrix.arch }}.zip"
|
||||
|
||||
- name: Upload binaries
|
||||
if: matrix.sdl2 == 'ON' && ${{ needs.determine-tag.outputs.should_release }}
|
||||
if: matrix.sdl2 == 'ON'
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: whisper-bin-${{ matrix.arch }}.zip
|
||||
path: whisper-bin-${{ matrix.arch }}.zip
|
||||
name: whisper-bin-${{ matrix.arch }}
|
||||
path: build/bin/${{ matrix.build }}
|
||||
|
||||
windows-blas:
|
||||
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
|
||||
@ -697,14 +622,11 @@ jobs:
|
||||
arch: [Win32, x64]
|
||||
blas: [ON]
|
||||
sdl2: [ON]
|
||||
blasver: [0.3.29]
|
||||
include:
|
||||
- arch: Win32
|
||||
s2arc: x86
|
||||
blasfile: x86
|
||||
- arch: x64
|
||||
s2arc: x64
|
||||
blasfile: x64_64
|
||||
- sdl2: ON
|
||||
s2ver: 2.28.5
|
||||
|
||||
@ -725,8 +647,7 @@ jobs:
|
||||
- name: Install OpenBLAS and pkgconfiglite
|
||||
if: matrix.blas == 'ON'
|
||||
run: |
|
||||
Invoke-WebRequest "https://github.com/OpenMathLib/OpenBLAS/releases/download/v${{matrix.blasver}}/OpenBLAS-${{matrix.blasver}}_${{matrix.blasfile}}.zip" -OutFile "OpenBLAS-${{matrix.blasver}}.zip"
|
||||
Expand-Archive "OpenBLAS-${{matrix.blasver}}.zip" -DestinationPath "OpenBLAS-${{matrix.blasver}}"
|
||||
vcpkg install --triplet=${{ matrix.s2arc }}-windows openblas
|
||||
choco install pkgconfiglite
|
||||
|
||||
- name: Fetch SDL2 and set SDL2_DIR
|
||||
@ -743,8 +664,6 @@ jobs:
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build }}
|
||||
-DGGML_BLAS=${{ matrix.blas }}
|
||||
-DGGML_BLAS_VENDOR=OpenBLAS
|
||||
-DBLAS_LIBRARIES="$env:GITHUB_WORKSPACE/OpenBLAS-${{matrix.blasver}}/lib/libopenblas.lib"
|
||||
-DBLAS_INCLUDE_DIRS="$env:GITHUB_WORKSPACE/OpenBLAS-${{matrix.blasver}}/include"
|
||||
-DWHISPER_SDL2=${{ matrix.sdl2 }}
|
||||
|
||||
- name: Build
|
||||
@ -754,37 +673,30 @@ jobs:
|
||||
|
||||
- name: Copy openblas.dll
|
||||
if: matrix.blas == 'ON'
|
||||
run: copy "$env:GITHUB_WORKSPACE/OpenBLAS-${{matrix.blasver}}/bin/libopenblas.dll" build/bin/${{ matrix.build }}
|
||||
run: copy "C:/vcpkg/packages/openblas_${{ matrix.s2arc }}-windows/bin/openblas.dll" build/bin/${{ matrix.build }}
|
||||
|
||||
- name: Copy SDL2.dll
|
||||
if: matrix.sdl2 == 'ON'
|
||||
run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
|
||||
|
||||
- name: Pack bin artifacts
|
||||
shell: pwsh
|
||||
run: |
|
||||
Compress-Archive -Path "build/bin/${{ matrix.build }}" -DestinationPath "whisper-blas-bin-${{ matrix.arch }}.zip"
|
||||
|
||||
- name: Upload binaries
|
||||
if: matrix.blas == 'ON' && matrix.sdl2 == 'ON' && ${{ needs.determine-tag.outputs.should_release }}
|
||||
if: matrix.blas == 'ON' && matrix.sdl2 == 'ON'
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: whisper-blas-bin-${{ matrix.arch }}.zip
|
||||
path: whisper-blas-bin-${{ matrix.arch }}.zip
|
||||
name: whisper-blas-bin-${{ matrix.arch }}
|
||||
path: build/bin/${{ matrix.build }}
|
||||
|
||||
windows-cublas:
|
||||
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
|
||||
github.event.inputs.run_type == 'full-ci' }}
|
||||
runs-on: windows-2022
|
||||
needs: determine-tag
|
||||
runs-on: windows-2019
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
build: [Release]
|
||||
arch: [x64]
|
||||
cublas: [ON]
|
||||
sdl2: [ON]
|
||||
cuda-toolkit: [12.4.0, 11.8.0]
|
||||
cuda-toolkit: [12.2.0, 11.8.0]
|
||||
include:
|
||||
- arch: x64
|
||||
sdl2: ON
|
||||
@ -852,7 +764,7 @@ jobs:
|
||||
xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y
|
||||
|
||||
# Visual Studio integration
|
||||
xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\visual_studio_integration\MSBuildExtensions\*" "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\MSBuild\Microsoft\VC\v170\BuildCustomizations" /E /I /H /Y
|
||||
xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\visual_studio_integration\MSBuildExtensions\*" "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\MSBuild\Microsoft\VC\v160\BuildCustomizations" /E /I /H /Y
|
||||
|
||||
# Set environment variables
|
||||
echo "$CUDA_TOOLKIT_DIR\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
@ -860,23 +772,23 @@ jobs:
|
||||
echo "CUDA_PATH=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8
|
||||
echo "CUDA_PATH_V11_8=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8
|
||||
|
||||
- name: Install Cuda Toolkit 12.4.0
|
||||
if: ${{ matrix.cuda-toolkit == '12.4.0' }}
|
||||
- name: Install Cuda Toolkit 12.2.0
|
||||
if: ${{ matrix.cuda-toolkit == '12.2.0' }}
|
||||
run: |
|
||||
$CUDA_VERSION = ${{ matrix.cuda-toolkit }}
|
||||
$CUDA_TOOLKIT_DIR = "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v$CUDA_VERSION"
|
||||
$CUDA_DOWNLOAD = "https://developer.download.nvidia.com/compute/cuda/redist"
|
||||
|
||||
# Components versions
|
||||
$CUDART_VER = "12.4.127"
|
||||
$NVCC_VER = "12.4.131"
|
||||
$NVRTC_VER = "12.4.127"
|
||||
$CUBLAS_VER = "12.4.5.8"
|
||||
$NVTX_VER = "12.4.127"
|
||||
$PROFILER_VER = "12.4.127"
|
||||
$VS_VER = "12.4.127"
|
||||
$NVPROF_VER = "12.4.128"
|
||||
$CCCL_VER = "12.4.127"
|
||||
$CUDART_VER = "12.2.140"
|
||||
$NVCC_VER = "12.2.140"
|
||||
$NVRTC_VER = "12.2.140"
|
||||
$CUBLAS_VER = "12.2.5.6"
|
||||
$NVTX_VER = "12.2.140"
|
||||
$PROFILER_VER = "12.2.140"
|
||||
$VS_VER = "12.2.140"
|
||||
$NVPROF_VER = "12.2.142"
|
||||
$CCCL_VER = "12.2.140"
|
||||
|
||||
# Create the directory where the CUDA Toolkit will be installed
|
||||
mkdir -p $CUDA_TOOLKIT_DIR
|
||||
@ -910,7 +822,7 @@ jobs:
|
||||
xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y
|
||||
|
||||
# Visual Studio integration
|
||||
xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\visual_studio_integration\MSBuildExtensions\*" "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\MSBuild\Microsoft\VC\v170\BuildCustomizations" /E /I /H /Y
|
||||
xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\visual_studio_integration\MSBuildExtensions\*" "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\MSBuild\Microsoft\VC\v160\BuildCustomizations" /E /I /H /Y
|
||||
|
||||
# Set environment variables
|
||||
echo "$CUDA_TOOLKIT_DIR\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
@ -938,21 +850,14 @@ jobs:
|
||||
- name: Build Project
|
||||
shell: cmd
|
||||
run: |
|
||||
call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvars64.bat"
|
||||
call "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Auxiliary\Build\vcvars64.bat"
|
||||
cmake --version
|
||||
where cmake
|
||||
if "${{ matrix.cuda-toolkit }}" == "11.8.0" (
|
||||
set CUDA_FLAGS=-allow-unsupported-compiler -D_ALLOW_COMPILER_AND_STL_VERSION_MISMATCH -D_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR
|
||||
) else (
|
||||
set CUDA_FLAGS=
|
||||
)
|
||||
cmake -S . -B build -G "Ninja Multi-Config" ^
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build }} ^
|
||||
-DGGML_CUDA=${{ matrix.cublas }} ^
|
||||
-DWHISPER_SDL2=${{ matrix.sdl2 }} ^
|
||||
-DSDL2_DIR="%SDL2_DIR%" ^
|
||||
-DCMAKE_POLICY_VERSION_MINIMUM=3.5 ^
|
||||
-DCMAKE_CUDA_FLAGS="%CUDA_FLAGS%"
|
||||
-DSDL2_DIR="%SDL2_DIR%"
|
||||
set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1
|
||||
cmake --build build --config ${{ matrix.build }} -j %NUMBER_OF_PROCESSORS%
|
||||
|
||||
@ -969,17 +874,11 @@ jobs:
|
||||
if: matrix.sdl2 == 'ON'
|
||||
run: copy "$env:SDL2_DIR/../lib/${{ matrix.arch }}/SDL2.dll" build/bin/${{ matrix.build }}
|
||||
|
||||
- name: Pack bin artifacts
|
||||
shell: pwsh
|
||||
run: |
|
||||
Compress-Archive -Path "build/bin/${{ matrix.build }}" -DestinationPath "whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}.zip"
|
||||
|
||||
- name: Upload binaries
|
||||
if: ${{ needs.determine-tag.outputs.should_release }}
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}.zip
|
||||
path: whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}.zip
|
||||
name: whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}
|
||||
path: build/bin/${{ matrix.build }}
|
||||
|
||||
emscripten:
|
||||
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
|
||||
@ -1052,11 +951,16 @@ jobs:
|
||||
|
||||
- name: Pack artifacts
|
||||
id: pack_artifacts
|
||||
if: ${{ (github.event_name == 'push' && github.ref == 'refs/heads/master') ||
|
||||
github.event.inputs.create_release == 'true' ||
|
||||
github.event.inputs.pre_release_tag != '' }}
|
||||
run: |
|
||||
zip --symlinks -r whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework.zip build-apple/whisper.xcframework
|
||||
|
||||
- name: Upload artifacts
|
||||
if: ${{ needs.determine-tag.outputs.should_release }}
|
||||
if: ${{ (github.event_name == 'push' && github.ref == 'refs/heads/master') ||
|
||||
github.event.inputs.create_release == 'true' ||
|
||||
github.event.inputs.pre_release_tag != '' }}
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
path: whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework.zip
|
||||
@ -1194,16 +1098,11 @@ jobs:
|
||||
chmod +x ./gradlew
|
||||
./gradlew build --info
|
||||
|
||||
- name: Pack jar artifacts
|
||||
shell: pwsh
|
||||
run: |
|
||||
Compress-Archive -Path "bindings/java/build/libs/whispercpp-*.jar" -DestinationPath "whispercpp.jar.zip"
|
||||
|
||||
- name: Upload jar
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: whispercpp.jar.zip
|
||||
path: whispercpp.jar.zip
|
||||
name: whispercpp.jar
|
||||
path: bindings/java/build/libs/whispercpp-*.jar
|
||||
|
||||
# - name: Publish package
|
||||
# if: ${{ github.ref == 'refs/heads/master' }}
|
||||
@ -1234,16 +1133,13 @@ jobs:
|
||||
./build/bin/quantize models/ggml-tiny.en.bin models/ggml-tiny.en-q4_0.bin q4_0
|
||||
|
||||
release:
|
||||
if: ${{ github.event.inputs.create_release == 'true' || github.event.inputs.pre_release_tag != '' || startsWith(github.ref, 'refs/tags/v') }}
|
||||
if: ${{ github.event.inputs.create_release == 'true' || github.event.inputs.pre_release_tag != '' }}
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
needs:
|
||||
- determine-tag
|
||||
- ios-xcode-build
|
||||
- windows
|
||||
- windows-blas
|
||||
- windows-cublas
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
@ -1277,7 +1173,6 @@ jobs:
|
||||
with:
|
||||
tag_name: ${{ needs.determine-tag.outputs.tag_name }}
|
||||
prerelease: ${{ github.event.inputs.pre_release_tag != '' }}
|
||||
draft: true
|
||||
|
||||
- name: Upload release
|
||||
id: upload_release
|
||||
@ -1304,8 +1199,7 @@ jobs:
|
||||
coreml-base-en:
|
||||
if: ${{ (github.event_name == 'push' && github.ref == 'refs/heads/master') ||
|
||||
github.event.inputs.create_release == 'true' ||
|
||||
github.event.inputs.pre_release_tag != '' ||
|
||||
startsWith(github.ref, 'refs/tags/v') }}
|
||||
github.event.inputs.pre_release_tag != '' }}
|
||||
runs-on: macos-latest
|
||||
needs: determine-tag
|
||||
|
||||
@ -1329,23 +1223,3 @@ jobs:
|
||||
source venv/bin/activate
|
||||
pip install ane_transformers openai-whisper coremltools
|
||||
./models/generate-coreml-model.sh ${{ env.MODEL_NAME }}
|
||||
|
||||
vad:
|
||||
if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' ||
|
||||
github.event.inputs.run_type == 'full-ci' }}
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Build
|
||||
shell: bash
|
||||
run: |
|
||||
cmake -B build
|
||||
cmake --build build --config Release
|
||||
|
||||
- name: Test
|
||||
shell: bash
|
||||
run: |
|
||||
ctest -R ^test-vad$ --test-dir build --output-on-failure -VV
|
||||
|
42
.github/workflows/docker.yml
vendored
42
.github/workflows/docker.yml
vendored
@ -15,13 +15,13 @@ jobs:
|
||||
env:
|
||||
COMMIT_SHA: ${{ github.sha }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
config:
|
||||
- { tag: "main", dockerfile: ".devops/main.Dockerfile", platform: "linux/amd64" }
|
||||
- { tag: "main-musa", dockerfile: ".devops/main-musa.Dockerfile", platform: "linux/amd64" }
|
||||
- { tag: "main-intel", dockerfile: ".devops/main-intel.Dockerfile", platform: "linux/amd64" }
|
||||
- { tag: "main-cuda", dockerfile: ".devops/main-cuda.Dockerfile", platform: "linux/amd64" }
|
||||
#TODO: the cuda image keeps failing - disable for now
|
||||
# https://github.com/ggerganov/whisper.cpp/actions/runs/11019444428/job/30602020339
|
||||
#- { tag: "main-cuda", dockerfile: ".devops/main-cuda.Dockerfile", platform: "linux/amd64" }
|
||||
|
||||
steps:
|
||||
- name: Check out the repo
|
||||
@ -42,35 +42,21 @@ jobs:
|
||||
username: ${{ github.repository_owner }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Free up disk space
|
||||
run: |
|
||||
sudo apt-get remove -y '^dotnet-.*' '^llvm-.*' '^mysql-.*' '^postgresql-.*'
|
||||
sudo apt-get autoremove -y
|
||||
sudo apt-get autoclean
|
||||
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo rm -rf /opt/ghc
|
||||
sudo rm -rf /opt/hostedtoolcache/CodeQL
|
||||
|
||||
docker system prune -af
|
||||
|
||||
df -h
|
||||
|
||||
- name: Generate tags
|
||||
id: tags
|
||||
run: |
|
||||
TAGS="ghcr.io/${{ github.repository }}:${{ matrix.config.tag }}"
|
||||
if [ "${{ github.event_name }}" == "push" ]; then
|
||||
TAGS="$TAGS,ghcr.io/${{ github.repository }}:${{ matrix.config.tag }}-${{ env.COMMIT_SHA }}"
|
||||
fi
|
||||
echo "tags=$TAGS" >> $GITHUB_OUTPUT
|
||||
- name: Build and push Docker image (versioned)
|
||||
if: github.event_name == 'push'
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
platforms: ${{ matrix.config.platform }}
|
||||
tags: "ghcr.io/${{ github.repository }}:${{ matrix.config.tag }}-${{ env.COMMIT_SHA }}"
|
||||
file: ${{ matrix.config.dockerfile }}
|
||||
|
||||
- name: Build and push Docker image (tagged)
|
||||
uses: docker/build-push-action@v5
|
||||
uses: docker/build-push-action@v4
|
||||
with:
|
||||
context: .
|
||||
push: ${{ github.event_name == 'push' }}
|
||||
platforms: ${{ matrix.config.platform }}
|
||||
tags: ${{ steps.tags.outputs.tags }}
|
||||
tags: "ghcr.io/${{ github.repository }}:${{ matrix.config.tag }}"
|
||||
file: ${{ matrix.config.dockerfile }}
|
||||
|
3
.gitignore
vendored
3
.gitignore
vendored
@ -14,7 +14,6 @@
|
||||
|
||||
build/
|
||||
build-*/
|
||||
build_*/
|
||||
|
||||
# SPM
|
||||
.build/
|
||||
@ -50,8 +49,6 @@ extra/bench-gg.txt
|
||||
models/*.mlmodel
|
||||
models/*.mlmodelc
|
||||
models/*.mlpackage
|
||||
models/*-encoder-openvino.xml
|
||||
models/*-encoder-openvino-cache/
|
||||
bindings/java/.gradle/
|
||||
bindings/java/.idea/
|
||||
.idea/
|
||||
|
@ -1,6 +1,6 @@
|
||||
cmake_minimum_required(VERSION 3.5) # for add_link_options and implicit target directories.
|
||||
project("whisper.cpp" C CXX)
|
||||
project("whisper.cpp" VERSION 1.7.6)
|
||||
project("whisper.cpp" VERSION 1.7.5)
|
||||
include(CheckIncludeFileCXX)
|
||||
|
||||
set(SOVERSION 1)
|
||||
@ -59,6 +59,9 @@ option(BUILD_SHARED_LIBS "build shared libraries" ${BUILD_SHARED_LIBS_DEFAULT})
|
||||
# option list
|
||||
#
|
||||
|
||||
# general
|
||||
option(WHISPER_CCACHE "whisper: use ccache if available" ON)
|
||||
|
||||
# debug
|
||||
option(WHISPER_ALL_WARNINGS "whisper: enable all compiler warnings" ON)
|
||||
option(WHISPER_ALL_WARNINGS_3RD_PARTY "whisper: enable all compiler warnings in 3rd party libs" OFF)
|
||||
@ -93,6 +96,7 @@ option(WHISPER_OPENVINO "whisper: support for OpenVINO" OFF)
|
||||
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/build-info.cmake)
|
||||
|
||||
# override ggml options
|
||||
set(GGML_CCACHE ${WHISPER_CCACHE})
|
||||
set(GGML_SANITIZE_THREAD ${WHISPER_SANITIZE_THREAD})
|
||||
set(GGML_SANITIZE_ADDRESS ${WHISPER_SANITIZE_ADDRESS})
|
||||
set(GGML_SANITIZE_UNDEFINED ${WHISPER_SANITIZE_UNDEFINED})
|
||||
@ -117,12 +121,6 @@ whisper_option_depr(WARNING WHISPER_OPENMP GGML_OPENMP)
|
||||
whisper_option_depr(WARNING WHISPER_RPC GGML_RPC)
|
||||
whisper_option_depr(WARNING WHISPER_SYCL GGML_SYCL)
|
||||
whisper_option_depr(WARNING WHISPER_SYCL_F16 GGML_SYCL_F16)
|
||||
whisper_option_depr(WARNING WHISPER_CCACHE GGML_CCACHE)
|
||||
|
||||
if (GGML_CUDA AND NOT MSVC)
|
||||
#GGML_CUDA enabled, add the necessary compile options -Wno-deprecated-gpu-targets
|
||||
add_compile_options(-Wno-deprecated-gpu-targets)
|
||||
endif()
|
||||
|
||||
#
|
||||
# build the library
|
||||
@ -178,10 +176,6 @@ get_directory_property(WHISPER_TRANSIENT_DEFINES COMPILE_DEFINITIONS)
|
||||
set_target_properties(whisper PROPERTIES PUBLIC_HEADER ${CMAKE_CURRENT_SOURCE_DIR}/include/whisper.h)
|
||||
install(TARGETS whisper LIBRARY PUBLIC_HEADER)
|
||||
|
||||
target_compile_definitions(whisper PRIVATE
|
||||
WHISPER_VERSION="${PROJECT_VERSION}"
|
||||
)
|
||||
|
||||
configure_package_config_file(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cmake/whisper-config.cmake.in
|
||||
${CMAKE_CURRENT_BINARY_DIR}/whisper-config.cmake
|
||||
@ -250,6 +244,5 @@ if (MSVC)
|
||||
disable_msvc_warnings(whisper-talk-llama)
|
||||
disable_msvc_warnings(whisper-bench)
|
||||
disable_msvc_warnings(quantize)
|
||||
disable_msvc_warnings(vad-speech-segments)
|
||||
endif()
|
||||
endif()
|
||||
|
106
README.md
106
README.md
@ -7,7 +7,7 @@
|
||||
[](https://conan.io/center/whisper-cpp)
|
||||
[](https://www.npmjs.com/package/whisper.cpp/)
|
||||
|
||||
Stable: [v1.7.6](https://github.com/ggml-org/whisper.cpp/releases/tag/v1.7.6) / [Roadmap](https://github.com/orgs/ggml-org/projects/4/)
|
||||
Stable: [v1.7.5](https://github.com/ggml-org/whisper.cpp/releases/tag/v1.7.5) / [Roadmap](https://github.com/orgs/ggml-org/projects/4/)
|
||||
|
||||
High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:
|
||||
|
||||
@ -25,7 +25,6 @@ High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisp
|
||||
- [Ascend NPU Support](#ascend-npu-support)
|
||||
- [Moore Threads GPU Support](#moore-threads-gpu-support)
|
||||
- [C-style API](https://github.com/ggml-org/whisper.cpp/blob/master/include/whisper.h)
|
||||
- [Voice Activity Detection (VAD)](#voice-activity-detection-vad)
|
||||
|
||||
Supported platforms:
|
||||
|
||||
@ -35,7 +34,7 @@ Supported platforms:
|
||||
- [x] [Java](bindings/java/README.md)
|
||||
- [x] Linux / [FreeBSD](https://github.com/ggml-org/whisper.cpp/issues/56#issuecomment-1350920264)
|
||||
- [x] [WebAssembly](examples/whisper.wasm)
|
||||
- [x] Windows ([MSVC](https://github.com/ggml-org/whisper.cpp/blob/master/.github/workflows/build.yml#L117-L144) and [MinGW](https://github.com/ggml-org/whisper.cpp/issues/168))
|
||||
- [x] Windows ([MSVC](https://github.com/ggml-org/whisper.cpp/blob/master/.github/workflows/build.yml#L117-L144) and [MinGW](https://github.com/ggml-org/whisper.cpp/issues/168)]
|
||||
- [x] [Raspberry Pi](https://github.com/ggml-org/whisper.cpp/discussions/166)
|
||||
- [x] [Docker](https://github.com/ggml-org/whisper.cpp/pkgs/container/whisper.cpp)
|
||||
|
||||
@ -80,7 +79,7 @@ Now build the [whisper-cli](examples/cli) example and transcribe an audio file l
|
||||
```bash
|
||||
# build the project
|
||||
cmake -B build
|
||||
cmake --build build -j --config Release
|
||||
cmake --build build --config Release
|
||||
|
||||
# transcribe an audio file
|
||||
./build/bin/whisper-cli -f samples/jfk.wav
|
||||
@ -149,7 +148,7 @@ standard cmake setup with:
|
||||
```bash
|
||||
# build with GGML_BLAS defined
|
||||
cmake -B build -DGGML_BLAS=1
|
||||
cmake --build build -j --config Release
|
||||
cmake --build build --config Release
|
||||
./build/bin/whisper-cli [ .. etc .. ]
|
||||
```
|
||||
|
||||
@ -163,7 +162,7 @@ Here are the steps for creating and using a quantized model:
|
||||
```bash
|
||||
# quantize a model with Q5_0 method
|
||||
cmake -B build
|
||||
cmake --build build -j --config Release
|
||||
cmake --build build --config Release
|
||||
./build/bin/quantize models/ggml-base.en.bin models/ggml-base.en-q5_0.bin q5_0
|
||||
|
||||
# run the examples as usual, specifying the quantized model file
|
||||
@ -267,7 +266,7 @@ This can result in significant speedup in encoder performance. Here are the inst
|
||||
|
||||
- Build `whisper.cpp` with OpenVINO support:
|
||||
|
||||
Download OpenVINO package from [release page](https://github.com/openvinotoolkit/openvino/releases). The recommended version to use is [2024.6.0](https://github.com/openvinotoolkit/openvino/releases/tag/2024.6.0). Ready to use Binaries of the required libraries can be found in the [OpenVino Archives](https://storage.openvinotoolkit.org/repositories/openvino/packages/2024.6/)
|
||||
Download OpenVINO package from [release page](https://github.com/openvinotoolkit/openvino/releases). The recommended version to use is [2023.0.0](https://github.com/openvinotoolkit/openvino/releases/tag/2023.0.0).
|
||||
|
||||
After downloading & extracting package onto your development system, set up required environment by sourcing setupvars script. For example:
|
||||
|
||||
@ -386,7 +385,7 @@ Run the inference examples as usual, for example:
|
||||
## Moore Threads GPU support
|
||||
|
||||
With Moore Threads cards the processing of the models is done efficiently on the GPU via muBLAS and custom MUSA kernels.
|
||||
First, make sure you have installed `MUSA SDK rc4.0.1`: https://developer.mthreads.com/sdk/download/musa?equipment=&os=&driverVersion=&version=4.0.1
|
||||
First, make sure you have installed `MUSA SDK rc3.1.1`: https://developer.mthreads.com/sdk/download/musa?equipment=&os=&driverVersion=&version=rc3.1.1
|
||||
|
||||
Now build `whisper.cpp` with MUSA support:
|
||||
|
||||
@ -489,7 +488,7 @@ You will need to have [sdl2](https://wiki.libsdl.org/SDL2/Installation) installe
|
||||
|
||||
```bash
|
||||
cmake -B build -DWHISPER_SDL2=ON
|
||||
cmake --build build -j --config Release
|
||||
cmake --build build --config Release
|
||||
./build/bin/whisper-stream -m ./models/ggml-base.en.bin -t 8 --step 500 --length 5000
|
||||
```
|
||||
|
||||
@ -600,7 +599,7 @@ main: processing './samples/a13.wav' (480000 samples, 30.0 sec), 4 threads, 1 pr
|
||||
## Karaoke-style movie generation (experimental)
|
||||
|
||||
The [whisper-cli](examples/cli) example provides support for output of karaoke-style movies, where the
|
||||
currently pronounced word is highlighted. Use the `-owts` argument and run the generated bash script.
|
||||
currently pronounced word is highlighted. Use the `-wts` argument and run the generated bash script.
|
||||
This requires to have `ffmpeg` installed.
|
||||
|
||||
Here are a few _"typical"_ examples:
|
||||
@ -709,9 +708,7 @@ For more details, see the conversion script [models/convert-pt-to-ggml.py](model
|
||||
## XCFramework
|
||||
The XCFramework is a precompiled version of the library for iOS, visionOS, tvOS,
|
||||
and macOS. It can be used in Swift projects without the need to compile the
|
||||
library from source. For example, the v1.7.5 version of the XCFramework can be
|
||||
used as follows:
|
||||
|
||||
library from source. For examples:
|
||||
```swift
|
||||
// swift-tools-version: 5.10
|
||||
// The swift-tools-version declares the minimum version of Swift required to build this package.
|
||||
@ -735,89 +732,6 @@ let package = Package(
|
||||
)
|
||||
```
|
||||
|
||||
## Voice Activity Detection (VAD)
|
||||
Support for Voice Activity Detection (VAD) can be enabled using the `--vad`
|
||||
argument to `whisper-cli`. In addition to this option a VAD model is also
|
||||
required.
|
||||
|
||||
The way this works is that first the audio samples are passed through
|
||||
the VAD model which will detect speech segments. Using this information the
|
||||
only the speech segments that are detected are extracted from the original audio
|
||||
input and passed to whisper for processing. This reduces the amount of audio
|
||||
data that needs to be processed by whisper and can significantly speed up the
|
||||
transcription process.
|
||||
|
||||
The following VAD models are currently supported:
|
||||
|
||||
### Silero-VAD
|
||||
[Silero-vad](https://github.com/snakers4/silero-vad) is a lightweight VAD model
|
||||
written in Python that is fast and accurate.
|
||||
|
||||
Models can be downloaded by running the following command on Linux or MacOS:
|
||||
```console
|
||||
$ ./models/download-vad-model.sh silero-v5.1.2
|
||||
Downloading ggml model silero-v5.1.2 from 'https://huggingface.co/ggml-org/whisper-vad' ...
|
||||
ggml-silero-v5.1.2.bin 100%[==============================================>] 864.35K --.-KB/s in 0.04s
|
||||
Done! Model 'silero-v5.1.2' saved in '/path/models/ggml-silero-v5.1.2.bin'
|
||||
You can now use it like this:
|
||||
|
||||
$ ./build/bin/whisper-cli -vm /path/models/ggml-silero-v5.1.2.bin --vad -f samples/jfk.wav -m models/ggml-base.en.bin
|
||||
|
||||
```
|
||||
And the following command on Windows:
|
||||
```console
|
||||
> .\models\download-vad-model.cmd silero-v5.1.2
|
||||
Downloading vad model silero-v5.1.2...
|
||||
Done! Model silero-v5.1.2 saved in C:\Users\danie\work\ai\whisper.cpp\ggml-silero-v5.1.2.bin
|
||||
You can now use it like this:
|
||||
|
||||
C:\path\build\bin\Release\whisper-cli.exe -vm C:\path\ggml-silero-v5.1.2.bin --vad -m models/ggml-base.en.bin -f samples\jfk.wav
|
||||
|
||||
```
|
||||
|
||||
To see a list of all available models, run the above commands without any
|
||||
arguments.
|
||||
|
||||
This model can be also be converted manually to ggml using the following command:
|
||||
```console
|
||||
$ python3 -m venv venv && source venv/bin/activate
|
||||
$ (venv) pip install silero-vad
|
||||
$ (venv) $ python models/convert-silero-vad-to-ggml.py --output models/silero.bin
|
||||
Saving GGML Silero-VAD model to models/silero-v5.1.2-ggml.bin
|
||||
```
|
||||
And it can then be used with whisper as follows:
|
||||
```console
|
||||
$ ./build/bin/whisper-cli \
|
||||
--file ./samples/jfk.wav \
|
||||
--model ./models/ggml-base.en.bin \
|
||||
--vad \
|
||||
--vad-model ./models/silero-v5.1.2-ggml.bin
|
||||
```
|
||||
|
||||
### VAD Options
|
||||
|
||||
* --vad-threshold: Threshold probability for speech detection. A probability
|
||||
for a speech segment/frame above this threshold will be considered as speech.
|
||||
|
||||
* --vad-min-speech-duration-ms: Minimum speech duration in milliseconds. Speech
|
||||
segments shorter than this value will be discarded to filter out brief noise or
|
||||
false positives.
|
||||
|
||||
* --vad-min-silence-duration-ms: Minimum silence duration in milliseconds. Silence
|
||||
periods must be at least this long to end a speech segment. Shorter silence
|
||||
periods will be ignored and included as part of the speech.
|
||||
|
||||
* --vad-max-speech-duration-s: Maximum speech duration in seconds. Speech segments
|
||||
longer than this will be automatically split into multiple segments at silence
|
||||
points exceeding 98ms to prevent excessively long segments.
|
||||
|
||||
* --vad-speech-pad-ms: Speech padding in milliseconds. Adds this amount of padding
|
||||
before and after each detected speech segment to avoid cutting off speech edges.
|
||||
|
||||
* --vad-samples-overlap: Amount of audio to extend from each speech segment into
|
||||
the next one, in seconds (e.g., 0.10 = 100ms overlap). This ensures speech isn't
|
||||
cut off abruptly between segments when they're concatenated together.
|
||||
|
||||
## Examples
|
||||
|
||||
There are various examples of using the library for different projects in the [examples](examples) folder.
|
||||
|
498
README_sycl.md
498
README_sycl.md
@ -1,249 +1,249 @@
|
||||
# whisper.cpp for SYCL
|
||||
|
||||
[Background](#background)
|
||||
|
||||
[OS](#os)
|
||||
|
||||
[Intel GPU](#intel-gpu)
|
||||
|
||||
[Linux](#linux)
|
||||
|
||||
[Environment Variable](#environment-variable)
|
||||
|
||||
[Known Issue](#known-issue)
|
||||
|
||||
[Todo](#todo)
|
||||
|
||||
## Background
|
||||
|
||||
SYCL is a higher-level programming model to improve programming productivity on various hardware accelerators—such as CPUs, GPUs, and FPGAs. It is a single-source embedded domain-specific language based on pure C++17.
|
||||
|
||||
oneAPI is a specification that is open and standards-based, supporting multiple architecture types including but not limited to GPU, CPU, and FPGA. The spec has both direct programming and API-based programming paradigms.
|
||||
|
||||
Intel uses the SYCL as direct programming language to support CPU, GPUs and FPGAs.
|
||||
|
||||
To avoid re-inventing the wheel, this code refers other code paths in llama.cpp (like OpenBLAS, cuBLAS, CLBlast). We use a open-source tool [SYCLomatic](https://github.com/oneapi-src/SYCLomatic) (Commercial release [Intel® DPC++ Compatibility Tool](https://www.intel.com/content/www/us/en/developer/tools/oneapi/dpc-compatibility-tool.html)) migrate to SYCL.
|
||||
|
||||
The whisper.cpp for SYCL is used to support Intel GPUs.
|
||||
|
||||
For Intel CPU, recommend to use whisper.cpp for X86 (Intel MKL build).
|
||||
|
||||
## OS
|
||||
|
||||
|OS|Status|Verified|
|
||||
|-|-|-|
|
||||
|Linux|Support|Ubuntu 22.04|
|
||||
|Windows|Ongoing| |
|
||||
|
||||
|
||||
## Intel GPU
|
||||
|
||||
|Intel GPU| Status | Verified Model|
|
||||
|-|-|-|
|
||||
|Intel Data Center Max Series| Support| Max 1550|
|
||||
|Intel Data Center Flex Series| Support| Flex 170|
|
||||
|Intel Arc Series| Support| Arc 770|
|
||||
|Intel built-in Arc GPU| Support| built-in Arc GPU in Meteor Lake|
|
||||
|Intel iGPU| Support| iGPU in i5-1250P, i7-1165G7|
|
||||
|
||||
|
||||
## Linux
|
||||
|
||||
### Setup Environment
|
||||
|
||||
1. Install Intel GPU driver.
|
||||
|
||||
a. Please install Intel GPU driver by official guide: [Install GPU Drivers](https://dgpu-docs.intel.com/driver/installation.html).
|
||||
|
||||
Note: for iGPU, please install the client GPU driver.
|
||||
|
||||
b. Add user to group: video, render.
|
||||
|
||||
```
|
||||
sudo usermod -aG render username
|
||||
sudo usermod -aG video username
|
||||
```
|
||||
|
||||
Note: re-login to enable it.
|
||||
|
||||
c. Check
|
||||
|
||||
```
|
||||
sudo apt install clinfo
|
||||
sudo clinfo -l
|
||||
```
|
||||
|
||||
Output (example):
|
||||
|
||||
```
|
||||
Platform #0: Intel(R) OpenCL Graphics
|
||||
`-- Device #0: Intel(R) Arc(TM) A770 Graphics
|
||||
|
||||
|
||||
Platform #0: Intel(R) OpenCL HD Graphics
|
||||
`-- Device #0: Intel(R) Iris(R) Xe Graphics [0x9a49]
|
||||
```
|
||||
|
||||
2. Install Intel® oneAPI Base toolkit.
|
||||
|
||||
|
||||
a. Please follow the procedure in [Get the Intel® oneAPI Base Toolkit ](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html).
|
||||
|
||||
Recommend to install to default folder: **/opt/intel/oneapi**.
|
||||
|
||||
Following guide use the default folder as example. If you use other folder, please modify the following guide info with your folder.
|
||||
|
||||
b. Check
|
||||
|
||||
```
|
||||
source /opt/intel/oneapi/setvars.sh
|
||||
|
||||
sycl-ls
|
||||
```
|
||||
|
||||
There should be one or more level-zero devices. Like **[ext_oneapi_level_zero:gpu:0]**.
|
||||
|
||||
Output (example):
|
||||
```
|
||||
[opencl:acc:0] Intel(R) FPGA Emulation Platform for OpenCL(TM), Intel(R) FPGA Emulation Device OpenCL 1.2 [2023.16.10.0.17_160000]
|
||||
[opencl:cpu:1] Intel(R) OpenCL, 13th Gen Intel(R) Core(TM) i7-13700K OpenCL 3.0 (Build 0) [2023.16.10.0.17_160000]
|
||||
[opencl:gpu:2] Intel(R) OpenCL Graphics, Intel(R) Arc(TM) A770 Graphics OpenCL 3.0 NEO [23.30.26918.50]
|
||||
[ext_oneapi_level_zero:gpu:0] Intel(R) Level-Zero, Intel(R) Arc(TM) A770 Graphics 1.3 [1.3.26918]
|
||||
|
||||
```
|
||||
|
||||
2. Build locally:
|
||||
|
||||
```
|
||||
mkdir -p build
|
||||
cd build
|
||||
source /opt/intel/oneapi/setvars.sh
|
||||
|
||||
#for FP16
|
||||
#cmake .. -DWHISPER_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DWHISPER_SYCL_F16=ON
|
||||
|
||||
#for FP32
|
||||
cmake .. -DWHISPER_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx
|
||||
|
||||
#build example/main only
|
||||
#cmake --build . --config Release --target main
|
||||
|
||||
#build all binary
|
||||
cmake --build . --config Release -v
|
||||
|
||||
```
|
||||
|
||||
or
|
||||
|
||||
```
|
||||
./examples/sycl/build.sh
|
||||
```
|
||||
|
||||
Note:
|
||||
|
||||
- By default, it will build for all binary files. It will take more time. To reduce the time, we recommend to build for **example/main** only.
|
||||
|
||||
### Run
|
||||
|
||||
1. Put model file to folder **models**
|
||||
|
||||
2. Enable oneAPI running environment
|
||||
|
||||
```
|
||||
source /opt/intel/oneapi/setvars.sh
|
||||
```
|
||||
|
||||
3. List device ID
|
||||
|
||||
Run without parameter:
|
||||
|
||||
```
|
||||
./build/bin/ls-sycl-device
|
||||
|
||||
or
|
||||
|
||||
./build/bin/main
|
||||
```
|
||||
|
||||
Check the ID in startup log, like:
|
||||
|
||||
```
|
||||
found 4 SYCL devices:
|
||||
Device 0: Intel(R) Arc(TM) A770 Graphics, compute capability 1.3,
|
||||
max compute_units 512, max work group size 1024, max sub group size 32, global mem size 16225243136
|
||||
Device 1: Intel(R) FPGA Emulation Device, compute capability 1.2,
|
||||
max compute_units 24, max work group size 67108864, max sub group size 64, global mem size 67065057280
|
||||
Device 2: 13th Gen Intel(R) Core(TM) i7-13700K, compute capability 3.0,
|
||||
max compute_units 24, max work group size 8192, max sub group size 64, global mem size 67065057280
|
||||
Device 3: Intel(R) Arc(TM) A770 Graphics, compute capability 3.0,
|
||||
max compute_units 512, max work group size 1024, max sub group size 32, global mem size 16225243136
|
||||
|
||||
```
|
||||
|
||||
|Attribute|Note|
|
||||
|-|-|
|
||||
|compute capability 1.3|Level-zero running time, recommended |
|
||||
|compute capability 3.0|OpenCL running time, slower than level-zero in most cases|
|
||||
|
||||
4. Set device ID and execute whisper.cpp
|
||||
|
||||
Set device ID = 0 by **GGML_SYCL_DEVICE=0**
|
||||
|
||||
```
|
||||
GGML_SYCL_DEVICE=0 ./build/bin/main -m models/ggml-base.en.bin -f samples/jfk.wav
|
||||
```
|
||||
or run by script:
|
||||
|
||||
```
|
||||
./examples/sycl/run_whisper.sh
|
||||
```
|
||||
|
||||
|
||||
|
||||
5. Check the device ID in output
|
||||
|
||||
Like:
|
||||
```
|
||||
Using device **0** (Intel(R) Arc(TM) A770 Graphics) as main device
|
||||
```
|
||||
|
||||
|
||||
## Environment Variable
|
||||
|
||||
#### Build
|
||||
|
||||
|Name|Value|Function|
|
||||
|-|-|-|
|
||||
|WHISPER_SYCL|ON (mandatory)|Enable build with SYCL code path. <br>For FP32/FP16, WHISPER_SYCL=ON is mandatory.|
|
||||
|WHISPER_SYCL_F16|ON (optional)|Enable FP16 build with SYCL code path.For FP32, do not set it.|
|
||||
|CMAKE_C_COMPILER|icx|Use icx compiler for SYCL code path|
|
||||
|CMAKE_CXX_COMPILER|icpx|use icpx for SYCL code path|
|
||||
|
||||
#### Running
|
||||
|
||||
|
||||
|Name|Value|Function|
|
||||
|-|-|-|
|
||||
|GGML_SYCL_DEVICE|0 (default) or 1|Set the device id used. Check the device ids by default running output|
|
||||
|GGML_SYCL_DEBUG|0 (default) or 1|Enable log function by macro: GGML_SYCL_DEBUG|
|
||||
|
||||
## Known Issue
|
||||
|
||||
- Error: `error while loading shared libraries: libsycl.so.7: cannot open shared object file: No such file or directory`.
|
||||
|
||||
Miss to enable oneAPI running environment.
|
||||
|
||||
Install oneAPI base toolkit and enable it by: `source /opt/intel/oneapi/setvars.sh`.
|
||||
|
||||
|
||||
- Hang during startup
|
||||
|
||||
llama.cpp use mmap as default way to read model file and copy to GPU. In some system, memcpy will be abnormal and block.
|
||||
|
||||
Solution: add **--no-mmap**.
|
||||
|
||||
## Todo
|
||||
|
||||
- Support to build in Windows.
|
||||
|
||||
- Support multiple cards.
|
||||
# whisper.cpp for SYCL
|
||||
|
||||
[Background](#background)
|
||||
|
||||
[OS](#os)
|
||||
|
||||
[Intel GPU](#intel-gpu)
|
||||
|
||||
[Linux](#linux)
|
||||
|
||||
[Environment Variable](#environment-variable)
|
||||
|
||||
[Known Issue](#known-issue)
|
||||
|
||||
[Todo](#todo)
|
||||
|
||||
## Background
|
||||
|
||||
SYCL is a higher-level programming model to improve programming productivity on various hardware accelerators<EFBFBD>such as CPUs, GPUs, and FPGAs. It is a single-source embedded domain-specific language based on pure C++17.
|
||||
|
||||
oneAPI is a specification that is open and standards-based, supporting multiple architecture types including but not limited to GPU, CPU, and FPGA. The spec has both direct programming and API-based programming paradigms.
|
||||
|
||||
Intel uses the SYCL as direct programming language to support CPU, GPUs and FPGAs.
|
||||
|
||||
To avoid re-inventing the wheel, this code refers other code paths in llama.cpp (like OpenBLAS, cuBLAS, CLBlast). We use a open-source tool [SYCLomatic](https://github.com/oneapi-src/SYCLomatic) (Commercial release [Intel<EFBFBD> DPC++ Compatibility Tool](https://www.intel.com/content/www/us/en/developer/tools/oneapi/dpc-compatibility-tool.html)) migrate to SYCL.
|
||||
|
||||
The whisper.cpp for SYCL is used to support Intel GPUs.
|
||||
|
||||
For Intel CPU, recommend to use whisper.cpp for X86 (Intel MKL build).
|
||||
|
||||
## OS
|
||||
|
||||
|OS|Status|Verified|
|
||||
|-|-|-|
|
||||
|Linux|Support|Ubuntu 22.04|
|
||||
|Windows|Ongoing| |
|
||||
|
||||
|
||||
## Intel GPU
|
||||
|
||||
|Intel GPU| Status | Verified Model|
|
||||
|-|-|-|
|
||||
|Intel Data Center Max Series| Support| Max 1550|
|
||||
|Intel Data Center Flex Series| Support| Flex 170|
|
||||
|Intel Arc Series| Support| Arc 770|
|
||||
|Intel built-in Arc GPU| Support| built-in Arc GPU in Meteor Lake|
|
||||
|Intel iGPU| Support| iGPU in i5-1250P, i7-1165G7|
|
||||
|
||||
|
||||
## Linux
|
||||
|
||||
### Setup Environment
|
||||
|
||||
1. Install Intel GPU driver.
|
||||
|
||||
a. Please install Intel GPU driver by official guide: [Install GPU Drivers](https://dgpu-docs.intel.com/driver/installation.html).
|
||||
|
||||
Note: for iGPU, please install the client GPU driver.
|
||||
|
||||
b. Add user to group: video, render.
|
||||
|
||||
```
|
||||
sudo usermod -aG render username
|
||||
sudo usermod -aG video username
|
||||
```
|
||||
|
||||
Note: re-login to enable it.
|
||||
|
||||
c. Check
|
||||
|
||||
```
|
||||
sudo apt install clinfo
|
||||
sudo clinfo -l
|
||||
```
|
||||
|
||||
Output (example):
|
||||
|
||||
```
|
||||
Platform #0: Intel(R) OpenCL Graphics
|
||||
`-- Device #0: Intel(R) Arc(TM) A770 Graphics
|
||||
|
||||
|
||||
Platform #0: Intel(R) OpenCL HD Graphics
|
||||
`-- Device #0: Intel(R) Iris(R) Xe Graphics [0x9a49]
|
||||
```
|
||||
|
||||
2. Install Intel<EFBFBD> oneAPI Base toolkit.
|
||||
|
||||
|
||||
a. Please follow the procedure in [Get the Intel<EFBFBD> oneAPI Base Toolkit ](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html).
|
||||
|
||||
Recommend to install to default folder: **/opt/intel/oneapi**.
|
||||
|
||||
Following guide use the default folder as example. If you use other folder, please modify the following guide info with your folder.
|
||||
|
||||
b. Check
|
||||
|
||||
```
|
||||
source /opt/intel/oneapi/setvars.sh
|
||||
|
||||
sycl-ls
|
||||
```
|
||||
|
||||
There should be one or more level-zero devices. Like **[ext_oneapi_level_zero:gpu:0]**.
|
||||
|
||||
Output (example):
|
||||
```
|
||||
[opencl:acc:0] Intel(R) FPGA Emulation Platform for OpenCL(TM), Intel(R) FPGA Emulation Device OpenCL 1.2 [2023.16.10.0.17_160000]
|
||||
[opencl:cpu:1] Intel(R) OpenCL, 13th Gen Intel(R) Core(TM) i7-13700K OpenCL 3.0 (Build 0) [2023.16.10.0.17_160000]
|
||||
[opencl:gpu:2] Intel(R) OpenCL Graphics, Intel(R) Arc(TM) A770 Graphics OpenCL 3.0 NEO [23.30.26918.50]
|
||||
[ext_oneapi_level_zero:gpu:0] Intel(R) Level-Zero, Intel(R) Arc(TM) A770 Graphics 1.3 [1.3.26918]
|
||||
|
||||
```
|
||||
|
||||
2. Build locally:
|
||||
|
||||
```
|
||||
mkdir -p build
|
||||
cd build
|
||||
source /opt/intel/oneapi/setvars.sh
|
||||
|
||||
#for FP16
|
||||
#cmake .. -DWHISPER_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DWHISPER_SYCL_F16=ON
|
||||
|
||||
#for FP32
|
||||
cmake .. -DWHISPER_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx
|
||||
|
||||
#build example/main only
|
||||
#cmake --build . --config Release --target main
|
||||
|
||||
#build all binary
|
||||
cmake --build . --config Release -v
|
||||
|
||||
```
|
||||
|
||||
or
|
||||
|
||||
```
|
||||
./examples/sycl/build.sh
|
||||
```
|
||||
|
||||
Note:
|
||||
|
||||
- By default, it will build for all binary files. It will take more time. To reduce the time, we recommend to build for **example/main** only.
|
||||
|
||||
### Run
|
||||
|
||||
1. Put model file to folder **models**
|
||||
|
||||
2. Enable oneAPI running environment
|
||||
|
||||
```
|
||||
source /opt/intel/oneapi/setvars.sh
|
||||
```
|
||||
|
||||
3. List device ID
|
||||
|
||||
Run without parameter:
|
||||
|
||||
```
|
||||
./build/bin/ls-sycl-device
|
||||
|
||||
or
|
||||
|
||||
./build/bin/main
|
||||
```
|
||||
|
||||
Check the ID in startup log, like:
|
||||
|
||||
```
|
||||
found 4 SYCL devices:
|
||||
Device 0: Intel(R) Arc(TM) A770 Graphics, compute capability 1.3,
|
||||
max compute_units 512, max work group size 1024, max sub group size 32, global mem size 16225243136
|
||||
Device 1: Intel(R) FPGA Emulation Device, compute capability 1.2,
|
||||
max compute_units 24, max work group size 67108864, max sub group size 64, global mem size 67065057280
|
||||
Device 2: 13th Gen Intel(R) Core(TM) i7-13700K, compute capability 3.0,
|
||||
max compute_units 24, max work group size 8192, max sub group size 64, global mem size 67065057280
|
||||
Device 3: Intel(R) Arc(TM) A770 Graphics, compute capability 3.0,
|
||||
max compute_units 512, max work group size 1024, max sub group size 32, global mem size 16225243136
|
||||
|
||||
```
|
||||
|
||||
|Attribute|Note|
|
||||
|-|-|
|
||||
|compute capability 1.3|Level-zero running time, recommended |
|
||||
|compute capability 3.0|OpenCL running time, slower than level-zero in most cases|
|
||||
|
||||
4. Set device ID and execute whisper.cpp
|
||||
|
||||
Set device ID = 0 by **GGML_SYCL_DEVICE=0**
|
||||
|
||||
```
|
||||
GGML_SYCL_DEVICE=0 ./build/bin/main -m models/ggml-base.en.bin -f samples/jfk.wav
|
||||
```
|
||||
or run by script:
|
||||
|
||||
```
|
||||
./examples/sycl/run_whisper.sh
|
||||
```
|
||||
|
||||
|
||||
|
||||
5. Check the device ID in output
|
||||
|
||||
Like:
|
||||
```
|
||||
Using device **0** (Intel(R) Arc(TM) A770 Graphics) as main device
|
||||
```
|
||||
|
||||
|
||||
## Environment Variable
|
||||
|
||||
#### Build
|
||||
|
||||
|Name|Value|Function|
|
||||
|-|-|-|
|
||||
|WHISPER_SYCL|ON (mandatory)|Enable build with SYCL code path. <br>For FP32/FP16, WHISPER_SYCL=ON is mandatory.|
|
||||
|WHISPER_SYCL_F16|ON (optional)|Enable FP16 build with SYCL code path.For FP32, do not set it.|
|
||||
|CMAKE_C_COMPILER|icx|Use icx compiler for SYCL code path|
|
||||
|CMAKE_CXX_COMPILER|icpx|use icpx for SYCL code path|
|
||||
|
||||
#### Running
|
||||
|
||||
|
||||
|Name|Value|Function|
|
||||
|-|-|-|
|
||||
|GGML_SYCL_DEVICE|0 (default) or 1|Set the device id used. Check the device ids by default running output|
|
||||
|GGML_SYCL_DEBUG|0 (default) or 1|Enable log function by macro: GGML_SYCL_DEBUG|
|
||||
|
||||
## Known Issue
|
||||
|
||||
- Error: `error while loading shared libraries: libsycl.so.7: cannot open shared object file: No such file or directory`.
|
||||
|
||||
Miss to enable oneAPI running environment.
|
||||
|
||||
Install oneAPI base toolkit and enable it by: `source /opt/intel/oneapi/setvars.sh`.
|
||||
|
||||
|
||||
- Hang during startup
|
||||
|
||||
llama.cpp use mmap as default way to read model file and copy to GPU. In some system, memcpy will be abnormal and block.
|
||||
|
||||
Solution: add **--no-mmap**.
|
||||
|
||||
## Todo
|
||||
|
||||
- Support to build in Windows.
|
||||
|
||||
- Support multiple cards.
|
@ -23,42 +23,26 @@ import io.github.ggerganov.whispercpp.WhisperCpp;
|
||||
public class Example {
|
||||
|
||||
public static void main(String[] args) {
|
||||
|
||||
WhisperCpp whisper = new WhisperCpp();
|
||||
// By default, models are loaded from ~/.cache/whisper/ and are usually named "ggml-${name}.bin"
|
||||
// or you can provide the absolute path to the model file.
|
||||
long context = whisper.initContext("base.en");
|
||||
try {
|
||||
// By default, models are loaded from ~/.cache/whisper/ and are usually named "ggml-${name}.bin"
|
||||
// or you can provide the absolute path to the model file.
|
||||
whisper.initContext("../ggml-base.en.bin");
|
||||
WhisperFullParams.ByValue whisperParams = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
|
||||
|
||||
// custom configuration if required
|
||||
//whisperParams.n_threads = 8;
|
||||
whisperParams.temperature = 0.0f;
|
||||
whisperParams.temperature_inc = 0.2f;
|
||||
//whisperParams.language = "en";
|
||||
|
||||
float[] samples = readAudio(); // divide each value by 32767.0f
|
||||
List<WhisperSegment> whisperSegmentList = whisper.fullTranscribeWithTime(whisperParams, samples);
|
||||
|
||||
for (WhisperSegment whisperSegment : whisperSegmentList) {
|
||||
var whisperParams = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
|
||||
// custom configuration if required
|
||||
whisperParams.temperature_inc = 0f;
|
||||
|
||||
long start = whisperSegment.getStart();
|
||||
long end = whisperSegment.getEnd();
|
||||
var samples = readAudio(); // divide each value by 32767.0f
|
||||
whisper.fullTranscribe(whisperParams, samples);
|
||||
|
||||
String text = whisperSegment.getSentence();
|
||||
|
||||
System.out.println("start: "+start);
|
||||
System.out.println("end: "+end);
|
||||
System.out.println("text: "+text);
|
||||
|
||||
int segmentCount = whisper.getTextSegmentCount(context);
|
||||
for (int i = 0; i < segmentCount; i++) {
|
||||
String text = whisper.getTextSegment(context, i);
|
||||
System.out.println(segment.getText());
|
||||
}
|
||||
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
} finally {
|
||||
whisper.close();
|
||||
whisper.freeContext(context);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
```
|
||||
|
@ -168,26 +168,23 @@ public class WhisperCpp implements AutoCloseable {
|
||||
return str.toString().trim();
|
||||
}
|
||||
|
||||
/**
|
||||
* Full transcribe with time list.
|
||||
*
|
||||
* @param whisperParams the whisper params
|
||||
* @param audioData the audio data
|
||||
* @return the list
|
||||
* @throws IOException the io exception
|
||||
*/
|
||||
public List<WhisperSegment> fullTranscribeWithTime(WhisperFullParams.ByValue whisperParams, float[] audioData) throws IOException {
|
||||
public List<WhisperSegment> fullTranscribeWithTime(WhisperFullParams whisperParams, float[] audioData) throws IOException {
|
||||
if (ctx == null) {
|
||||
throw new IllegalStateException("Model not initialised");
|
||||
}
|
||||
|
||||
if (lib.whisper_full(ctx, whisperParams, audioData, audioData.length) != 0) {
|
||||
WhisperFullParams.ByValue valueParams = new WhisperFullParams.ByValue(
|
||||
lib.whisper_full_default_params_by_ref(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH.ordinal()));
|
||||
valueParams.read();
|
||||
|
||||
if (lib.whisper_full(ctx, valueParams, audioData, audioData.length) != 0) {
|
||||
throw new IOException("Failed to process audio");
|
||||
}
|
||||
|
||||
int nSegments = lib.whisper_full_n_segments(ctx);
|
||||
List<WhisperSegment> segments= new ArrayList<>(nSegments);
|
||||
|
||||
|
||||
for (int i = 0; i < nSegments; i++) {
|
||||
long t0 = lib.whisper_full_get_segment_t0(ctx, i);
|
||||
String text = lib.whisper_full_get_segment_text(ctx, i);
|
||||
|
@ -118,7 +118,7 @@ class WhisperCppTest {
|
||||
float[] floats = new float[b.length / 2];
|
||||
|
||||
//WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
|
||||
WhisperFullParams.ByValue params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
|
||||
WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
|
||||
params.setProgressCallback((ctx, state, progress, user_data) -> System.out.println("progress: " + progress));
|
||||
params.print_progress = CBool.FALSE;
|
||||
//params.initial_prompt = "and so my fellow Americans um, like";
|
||||
|
@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "whisper.cpp",
|
||||
"version": "1.7.6",
|
||||
"version": "1.7.5",
|
||||
"description": "Whisper speech recognition",
|
||||
"main": "whisper.js",
|
||||
"scripts": {
|
||||
|
9
bindings/ruby/.gitignore
vendored
9
bindings/ruby/.gitignore
vendored
@ -1,9 +1,6 @@
|
||||
LICENSE
|
||||
pkg/
|
||||
lib/whisper.*
|
||||
ext/examples/
|
||||
ext/ggml/
|
||||
ext/include/
|
||||
ext/scripts/
|
||||
ext/src/
|
||||
test/fixtures/
|
||||
ext/sources/*
|
||||
!ext/sources/CMakeGraphVizOptions.cmake
|
||||
ext/mkmf.log
|
||||
|
@ -24,21 +24,7 @@ or,
|
||||
|
||||
$ gem install whispercpp -- --enable-ggml-cuda
|
||||
|
||||
See whisper.cpp's [README](https://github.com/ggml-org/whisper.cpp/blob/master/README.md) for available options. You need convert options present the README to Ruby-style options, for example:
|
||||
|
||||
Boolean options:
|
||||
|
||||
* `-DGGML_BLAS=1` -> `--enable-ggml-blas`
|
||||
* `-DWHISER_COREML=OFF` -> `--disable-whisper-coreml`
|
||||
|
||||
Argument options:
|
||||
|
||||
* `-DGGML_CUDA_COMPRESSION_MODE=size` -> `--ggml-cuda-compression-mode=size`
|
||||
|
||||
Combination:
|
||||
|
||||
* `-DGGML_CUDA=1 -DCMAKE_CUDA_ARCHITECTURES="86"` -> `--enable-ggml-cuda --cmake_cuda-architectures="86"`
|
||||
|
||||
See whisper.cpp's [README](https://github.com/ggml-org/whisper.cpp/blob/master/README.md) for available options. You need convert options present the README to Ruby-style options.
|
||||
For boolean options like `GGML_CUDA`, the README says `-DGGML_CUDA=1`. You need strip `-D`, prepend `--enable-` for `1` or `ON` (`--disable-` for `0` or `OFF`) and make it kebab-case: `--enable-ggml-cuda`.
|
||||
For options which require arguments like `CMAKE_CUDA_ARCHITECTURES`, the README says `-DCMAKE_CUDA_ARCHITECTURES="86"`. You need strip `-D`, prepend `--`, make it kebab-case, append `=` and append argument: `--cmake-cuda-architectures="86"`.
|
||||
|
||||
@ -70,6 +56,17 @@ end
|
||||
|
||||
Some models are prepared up-front:
|
||||
|
||||
```ruby
|
||||
base_en = Whisper::Model.pre_converted_models["base.en"]
|
||||
whisper = Whisper::Context.new(base_en)
|
||||
```
|
||||
|
||||
At first time you use a model, it is downloaded automatically. After that, downloaded cached file is used. To clear cache, call `#clear_cache`:
|
||||
|
||||
```ruby
|
||||
Whisper::Model.pre_converted_models["base"].clear_cache
|
||||
```
|
||||
|
||||
You also can use shorthand for pre-converted models:
|
||||
|
||||
```ruby
|
||||
@ -94,19 +91,6 @@ puts Whisper::Model.pre_converted_models.keys
|
||||
# :
|
||||
```
|
||||
|
||||
You can also retrieve each model:
|
||||
|
||||
```ruby
|
||||
base_en = Whisper::Model.pre_converted_models["base.en"]
|
||||
whisper = Whisper::Context.new(base_en)
|
||||
```
|
||||
|
||||
At first time you use a model, it is downloaded automatically. After that, downloaded cached file is used. To clear cache, call `#clear_cache`:
|
||||
|
||||
```ruby
|
||||
Whisper::Model.pre_converted_models["base"].clear_cache
|
||||
```
|
||||
|
||||
You can also use local model files you prepared:
|
||||
|
||||
```ruby
|
||||
@ -127,80 +111,9 @@ See [models][] page for details.
|
||||
|
||||
Currently, whisper.cpp accepts only 16-bit WAV files.
|
||||
|
||||
### Voice Activity Detection (VAD) ###
|
||||
|
||||
Support for Voice Activity Detection (VAD) can be enabled by setting `Whisper::Params`'s `vad` argument to `true` and specifying VAD model:
|
||||
|
||||
```ruby
|
||||
Whisper::Params.new(
|
||||
vad: true,
|
||||
vad_model_path: "silero-v5.1.2",
|
||||
# other arguments...
|
||||
)
|
||||
```
|
||||
|
||||
When you pass the model name (`"silero-v5.1.2"`) or URI (`https://huggingface.co/ggml-org/whisper-vad/resolve/main/ggml-silero-v5.1.2.bin`), it will be downloaded automatically.
|
||||
Currently, "silero-v5.1.2" is registered as pre-converted model like ASR models. You also specify file path or URI of model.
|
||||
|
||||
If you need configure VAD behavior, pass params for that:
|
||||
|
||||
```ruby
|
||||
Whisper::Params.new(
|
||||
vad: true,
|
||||
vad_model_path: "silero-v5.1.2",
|
||||
vad_params: Whisper::VAD::Params.new(
|
||||
threshold: 1.0, # defaults to 0.5
|
||||
min_speech_duration_ms: 500, # defaults to 250
|
||||
min_silence_duration_ms: 200, # defaults to 100
|
||||
max_speech_duration_s: 30000, # default is FLT_MAX,
|
||||
speech_pad_ms: 50, # defaults to 30
|
||||
samples_overlap: 0.5 # defaults to 0.1
|
||||
),
|
||||
# other arguments...
|
||||
)
|
||||
```
|
||||
|
||||
For details on VAD, see [whisper.cpp's README](https://github.com/ggml-org/whisper.cpp?tab=readme-ov-file#voice-activity-detection-vad).
|
||||
|
||||
### Output ###
|
||||
|
||||
whispercpp supports SRT and WebVTT output:
|
||||
|
||||
```ruby
|
||||
puts whisper.transcribe("path/to/audio.wav", Whisper::Params.new).to_webvtt
|
||||
# =>
|
||||
WEBVTT
|
||||
|
||||
1
|
||||
00:00:00.000 --> 00:00:03.860
|
||||
My thought I have nobody by a beauty and will as you poured.
|
||||
|
||||
2
|
||||
00:00:03.860 --> 00:00:09.840
|
||||
Mr. Rochester is sub in that so-don't find simplest, and devoted about, to let might in
|
||||
|
||||
3
|
||||
00:00:09.840 --> 00:00:09.940
|
||||
a
|
||||
|
||||
```
|
||||
|
||||
You may call `#to_srt`, too
|
||||
|
||||
|
||||
API
|
||||
---
|
||||
|
||||
### Transcription ###
|
||||
|
||||
By default, `Whisper::Context#transcribe` works in a single thread. You can make it work in parallel by passing `n_processors` option:
|
||||
|
||||
```ruby
|
||||
whisper.transcribe("path/to/audio.wav", params, n_processors: Etc.nprocessors)
|
||||
```
|
||||
|
||||
Note that transcription occasionally might be low accuracy when it works in parallel.
|
||||
|
||||
### Segments ###
|
||||
|
||||
Once `Whisper::Context#transcribe` called, you can retrieve segments by `#each_segment`:
|
||||
@ -222,7 +135,7 @@ whisper
|
||||
ed: format_time(segment.end_time),
|
||||
text: segment.text
|
||||
}
|
||||
line << " (speaker turned)" if segment.speaker_turn_next?
|
||||
line << " (speaker turned)" if segment.speaker_next_turn?
|
||||
puts line
|
||||
end
|
||||
|
||||
@ -238,7 +151,7 @@ params.on_new_segment do |segment|
|
||||
ed: format_time(segment.end_time),
|
||||
text: segment.text
|
||||
}
|
||||
line << " (speaker turned)" if segment.speaker_turn_next?
|
||||
line << " (speaker turned)" if segment.speaker_next_turn?
|
||||
puts line
|
||||
end
|
||||
|
||||
@ -335,11 +248,6 @@ First call of `rake test` builds an extension and downloads a model for testing.
|
||||
|
||||
If something seems wrong on build, running `rake clean` solves some cases.
|
||||
|
||||
### Need help ###
|
||||
|
||||
* Windows support
|
||||
* Refinement of C/C++ code, especially memory management
|
||||
|
||||
License
|
||||
-------
|
||||
|
||||
|
@ -67,30 +67,17 @@ file LIB_FILE => [SO_FILE, "lib"] do |t|
|
||||
end
|
||||
CLEAN.include LIB_FILE
|
||||
|
||||
Rake::TestTask.new
|
||||
|
||||
TEST_FIXTURE_AUDIO = "test/fixtures/jfk.wav"
|
||||
TEST_FIXTURE_AUDIO_SRC = File.expand_path(File.join(__dir__, "..", "..", "samples", "jfk.wav"))
|
||||
TEST_FIXTURE_AUDIO_DIR = TEST_FIXTURE_AUDIO.pathmap("%d")
|
||||
directory TEST_FIXTURE_AUDIO_DIR
|
||||
if File.exist? TEST_FIXTURE_AUDIO_SRC
|
||||
file TEST_FIXTURE_AUDIO => [TEST_FIXTURE_AUDIO_SRC, TEST_FIXTURE_AUDIO_DIR] do |t|
|
||||
symlink t.source, t.name
|
||||
end
|
||||
else
|
||||
require "open-uri"
|
||||
file TEST_FIXTURE_AUDIO => TEST_FIXTURE_AUDIO_DIR do |t|
|
||||
File.write t.name, URI("https://github.com/ggml-org/whisper.cpp/raw/refs/heads/master/samples/jfk.wav").read
|
||||
end
|
||||
Rake::TestTask.new do |t|
|
||||
t.test_files = FileList["tests/test_*.rb"]
|
||||
end
|
||||
|
||||
TEST_MEMORY_VIEW = "test/jfk_reader/jfk_reader.#{RbConfig::CONFIG['DLEXT']}"
|
||||
file TEST_MEMORY_VIEW => "test/jfk_reader/jfk_reader.c" do |t|
|
||||
chdir "test/jfk_reader" do
|
||||
TEST_MEMORY_VIEW = "tests/jfk_reader/jfk_reader.#{RbConfig::CONFIG['DLEXT']}"
|
||||
file TEST_MEMORY_VIEW => "tests/jfk_reader/jfk_reader.c" do |t|
|
||||
chdir "tests/jfk_reader" do
|
||||
ruby "extconf.rb"
|
||||
sh "make"
|
||||
end
|
||||
end
|
||||
CLEAN.include TEST_MEMORY_VIEW
|
||||
CLEAN.include "tests/jfk_reader/jfk_reader.{o,#{RbConfig::CONFIG['DLEXT']}}"
|
||||
|
||||
task test: [LIB_FILE, TEST_MEMORY_VIEW, TEST_FIXTURE_AUDIO]
|
||||
task test: [LIB_FILE, TEST_MEMORY_VIEW]
|
||||
|
10
bindings/ruby/ext/.gitignore
vendored
10
bindings/ruby/ext/.gitignore
vendored
@ -2,8 +2,10 @@ Makefile
|
||||
whisper.so
|
||||
whisper.bundle
|
||||
whisper.dll
|
||||
scripts/get-flags.mk
|
||||
*.o
|
||||
*.a
|
||||
sources/*
|
||||
!sources/CMakeGraphVizOptions.cmake
|
||||
mkmf.log
|
||||
/*/**/*.c
|
||||
/*/**/*.cpp
|
||||
/*/**/*.h
|
||||
/*/**/*.m
|
||||
/*/**/*.metal
|
||||
|
@ -1,32 +1,16 @@
|
||||
require "tsort"
|
||||
|
||||
class Dependencies
|
||||
include TSort
|
||||
|
||||
def initialize(cmake, options)
|
||||
@cmake = cmake
|
||||
@options = options
|
||||
@static_lib_shape = nil
|
||||
@nodes = {}
|
||||
@graph = Hash.new {|h, k| h[k] = []}
|
||||
|
||||
generate_dot
|
||||
parse_dot
|
||||
end
|
||||
|
||||
def libs
|
||||
tsort.filter_map {|node|
|
||||
label, shape = @nodes[node]
|
||||
if shape == @static_lib_shape
|
||||
label.gsub(/\\n\([^)]+\)/, '')
|
||||
else
|
||||
nil
|
||||
end
|
||||
}.reverse.collect {|lib| "lib#{lib}.a"}
|
||||
@libs = parse_dot
|
||||
end
|
||||
|
||||
def to_s
|
||||
libs.join(" ")
|
||||
@libs.join(" ")
|
||||
end
|
||||
|
||||
private
|
||||
@ -36,38 +20,42 @@ class Dependencies
|
||||
end
|
||||
|
||||
def generate_dot
|
||||
args = ["-S", "sources", "-B", "build", "--graphviz", dot_path, "-D", "BUILD_SHARED_LIBS=OFF"]
|
||||
args << @options.to_s unless @options.to_s.empty?
|
||||
system @cmake, *args, exception: true
|
||||
system @cmake, "-S", "sources", "-B", "build", "--graphviz", dot_path, "-D", "BUILD_SHARED_LIBS=OFF", @options.to_s, exception: true
|
||||
end
|
||||
|
||||
def parse_dot
|
||||
static_lib_shape = nil
|
||||
nodes = {}
|
||||
depends = Hash.new {|h, k| h[k] = []}
|
||||
|
||||
class << depends
|
||||
include TSort
|
||||
alias tsort_each_node each_key
|
||||
def tsort_each_child(node, &block)
|
||||
fetch(node, []).each(&block)
|
||||
end
|
||||
end
|
||||
|
||||
File.open(dot_path).each_line do |line|
|
||||
case line
|
||||
when /\[\s*label\s*=\s*"Static Library"\s*,\s*shape\s*=\s*(?<shape>\w+)\s*\]/
|
||||
@static_lib_shape = $~[:shape]
|
||||
static_lib_shape = $~[:shape]
|
||||
when /\A\s*"(?<node>\w+)"\s*\[\s*label\s*=\s*"(?<label>\S+)"\s*,\s*shape\s*=\s*(?<shape>\w+)\s*\]\s*;\s*\z/
|
||||
node = $~[:node]
|
||||
label = $~[:label]
|
||||
shape = $~[:shape]
|
||||
@nodes[node] = [label, shape]
|
||||
nodes[node] = [label, shape]
|
||||
when /\A\s*"(?<depender>\w+)"\s*->\s*"(?<dependee>\w+)"/
|
||||
depender = $~[:depender]
|
||||
dependee = $~[:dependee]
|
||||
@graph[depender] << dependee
|
||||
depends[depender] ||= []
|
||||
depends[depender] << dependee
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
def tsort_each_node
|
||||
@nodes.each_key do |node|
|
||||
yield node
|
||||
end
|
||||
end
|
||||
|
||||
def tsort_each_child(node)
|
||||
@graph[node].each do |child|
|
||||
yield child
|
||||
end
|
||||
depends.tsort.filter_map {|node|
|
||||
label, shape = nodes[node]
|
||||
shape == static_lib_shape ? label : nil
|
||||
}.collect {|lib| "lib#{lib}.a"}
|
||||
.reverse
|
||||
end
|
||||
end
|
||||
|
@ -3,7 +3,7 @@ require_relative "options"
|
||||
require_relative "dependencies"
|
||||
|
||||
cmake = find_executable("cmake") || abort
|
||||
options = Options.new(cmake)
|
||||
options = Options.new
|
||||
have_library("gomp") rescue nil
|
||||
libs = Dependencies.new(cmake, options)
|
||||
|
||||
|
@ -1,11 +1,25 @@
|
||||
class Options
|
||||
def initialize(cmake="cmake")
|
||||
@cmake = cmake
|
||||
def initialize
|
||||
@options = {}
|
||||
@pending_options = []
|
||||
@ignored_options = []
|
||||
|
||||
configure
|
||||
end
|
||||
|
||||
def help
|
||||
@options
|
||||
.collect_concat {|name, (type, value)|
|
||||
option = option_name(name)
|
||||
if type == :bool
|
||||
["--enable-#{option}", "--disable-#{option}"]
|
||||
else
|
||||
"--#{option}=#{type.upcase}"
|
||||
end
|
||||
}
|
||||
.join($/)
|
||||
end
|
||||
|
||||
def to_s
|
||||
@options
|
||||
.reject {|name, (type, value)| value.nil?}
|
||||
@ -18,68 +32,188 @@ class Options
|
||||
|
||||
output = nil
|
||||
Dir.chdir __dir__ do
|
||||
output = `#{@cmake.shellescape} -S sources -B build -L`
|
||||
output = `cmake -S sources -B build -L`
|
||||
end
|
||||
@cmake_options = output.lines.drop_while {|line| line.chomp != "-- Cache values"}.drop(1)
|
||||
.filter_map {|line|
|
||||
option, value = line.chomp.split("=", 2)
|
||||
name, type = option.split(":", 2)
|
||||
[
|
||||
name,
|
||||
[
|
||||
type,
|
||||
type == "BOOL" ? value == "ON" : value
|
||||
]
|
||||
]
|
||||
}.to_h
|
||||
started = false
|
||||
@cmake_options = output.lines.filter_map {|line|
|
||||
if line.chomp == "-- Cache values"
|
||||
started = true
|
||||
next
|
||||
end
|
||||
next unless started
|
||||
option, value = line.chomp.split("=", 2)
|
||||
name, type = option.split(":", 2)
|
||||
[name, type, value]
|
||||
}
|
||||
end
|
||||
|
||||
def missing_options
|
||||
cmake_options.collect {|name, type, value| name} -
|
||||
@options.keys - @pending_options - @ignored_options
|
||||
end
|
||||
|
||||
def extra_options
|
||||
@options.keys + @pending_options - @ignored_options -
|
||||
cmake_options.collect {|name, type, value| name}
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def configure
|
||||
cmake_options.each_pair do |name, (type, default_value)|
|
||||
option = option_name(name)
|
||||
value = type == "BOOL" ? enable_config(option) : arg_config("--#{option}")
|
||||
@options[name] = [type, value]
|
||||
end
|
||||
|
||||
configure_accelerate
|
||||
configure_metal
|
||||
configure_coreml
|
||||
end
|
||||
|
||||
# See ggml/src/ggml-cpu/CMakeLists.txt
|
||||
def configure_accelerate
|
||||
if RUBY_PLATFORM.match?(/darwin/) && enabled?("GGML_ACCELERATE")
|
||||
$LDFLAGS << " -framework Accelerate"
|
||||
end
|
||||
end
|
||||
|
||||
# See ggml/src/ggml-metal/CMakeLists.txt
|
||||
def configure_metal
|
||||
$LDFLAGS << " -framework Foundation -framework Metal -framework MetalKit" if enabled?("GGML_METAL")
|
||||
end
|
||||
|
||||
# See src/CmakeLists.txt
|
||||
def configure_coreml
|
||||
if enabled?("WHISPER_COREML")
|
||||
$LDFLAGS << " -framework Foundation -framework CoreML"
|
||||
$defs << "-DRUBY_WHISPER_USE_COREML"
|
||||
end
|
||||
filepath "ACCELERATE_FRAMEWORK"
|
||||
ignored "BUILD_SHARED_LIBS"
|
||||
ignored "BUILD_TESTING"
|
||||
ignored "CMAKE_BUILD_TYPE"
|
||||
ignored "CMAKE_INSTALL_PREFIX"
|
||||
string "CMAKE_OSX_ARCHITECTURES"
|
||||
ignored "CMAKE_OSX_DEPLOYMENT_TARGET"
|
||||
string "CMAKE_OSX_SYSROOT"
|
||||
filepath "FOUNDATION_LIBRARY"
|
||||
bool "GGML_ACCELERATE"
|
||||
bool "GGML_ALL_WARNINGS_3RD_PARTY"
|
||||
bool "GGML_AMX_BF16"
|
||||
bool "GGML_AMX_INT8"
|
||||
bool "GGML_AMX_TILE"
|
||||
bool "GGML_AVX"
|
||||
bool "GGML_AVX2"
|
||||
bool "GGML_AVX512"
|
||||
bool "GGML_AVX512_BF16"
|
||||
bool "GGML_AVX512_VBMI"
|
||||
bool "GGML_AVX512_VNNI"
|
||||
bool "GGML_AVX_VNNI"
|
||||
ignored "GGML_BACKEND_DL"
|
||||
ignored "GGML_BIN_INSTALL_DIR"
|
||||
bool "GGML_BLAS"
|
||||
string "GGML_BLAS_VENDOR"
|
||||
bool "GGML_BMI2"
|
||||
ignored "GGML_BUILD_EXAMPLES"
|
||||
ignored "GGML_BUILD_TESTS"
|
||||
filepath "GGML_CCACHE_FOUND"
|
||||
bool "GGML_CPU"
|
||||
bool "GGML_CPU_AARCH64"
|
||||
ignored "GGML_CPU_ALL_VARIANTS"
|
||||
string "GGML_CPU_ARM_ARCH"
|
||||
bool "GGML_CPU_HBM"
|
||||
bool "GGML_CPU_KLEIDIAI"
|
||||
string "GGML_CPU_POWERPC_CPUTYPE"
|
||||
bool "GGML_CUDA"
|
||||
string "GGML_CUDA_COMPRESSION_MODE"
|
||||
bool "GGML_CUDA_F16"
|
||||
bool "GGML_CUDA_FA"
|
||||
bool "GGML_CUDA_FA_ALL_QUANTS"
|
||||
bool "GGML_CUDA_FORCE_CUBLAS"
|
||||
bool "GGML_CUDA_FORCE_MMQ"
|
||||
ignored "GGML_CUDA_GRAPHS"
|
||||
bool "GGML_CUDA_NO_PEER_COPY"
|
||||
bool "GGML_CUDA_NO_VMM"
|
||||
string "GGML_CUDA_PEER_MAX_BATCH_SIZE"
|
||||
bool "GGML_F16C"
|
||||
bool "GGML_FMA"
|
||||
bool "GGML_GPROF"
|
||||
bool "GGML_HIP"
|
||||
bool "GGML_HIP_GRAPHS"
|
||||
bool "GGML_HIP_NO_VMM"
|
||||
bool "GGML_HIP_ROCWMMA_FATTN"
|
||||
ignored "GGML_INCLUDE_INSTALL_DIR"
|
||||
bool "GGML_KOMPUTE"
|
||||
bool "GGML_LASX"
|
||||
ignored "GGML_LIB_INSTALL_DIR"
|
||||
ignored "GGML_LLAMAFILE"
|
||||
bool "GGML_LSX"
|
||||
bool "GGML_LTO"
|
||||
bool "GGML_METAL"
|
||||
bool "GGML_METAL_EMBED_LIBRARY"
|
||||
string "GGML_METAL_MACOSX_VERSION_MIN"
|
||||
bool "GGML_METAL_NDEBUG"
|
||||
bool "GGML_METAL_SHADER_DEBUG"
|
||||
string "GGML_METAL_STD"
|
||||
bool "GGML_METAL_USE_BF16"
|
||||
bool "GGML_MUSA"
|
||||
bool "GGML_NATIVE"
|
||||
bool "GGML_OPENCL"
|
||||
bool "GGML_OPENCL_EMBED_KERNELS"
|
||||
bool "GGML_OPENCL_PROFILING"
|
||||
string "GGML_OPENCL_TARGET_VERSION"
|
||||
bool "GGML_OPENCL_USE_ADRENO_KERNELS"
|
||||
bool "GGML_OPENMP"
|
||||
bool "GGML_RPC"
|
||||
bool "GGML_RVV"
|
||||
bool "GGML_RV_ZFH"
|
||||
pending "GGML_SCCACHE_FOUND"
|
||||
string "GGML_SCHED_MAX_COPIES"
|
||||
bool "GGML_SSE42"
|
||||
ignored "GGML_STATIC"
|
||||
bool "GGML_SYCL"
|
||||
string "GGML_SYCL_DEVICE_ARCH"
|
||||
bool "GGML_SYCL_F16"
|
||||
bool "GGML_SYCL_GRAPH"
|
||||
string "GGML_SYCL_TARGET"
|
||||
bool "GGML_VULKAN"
|
||||
bool "GGML_VULKAN_CHECK_RESULTS"
|
||||
bool "GGML_VULKAN_DEBUG"
|
||||
bool "GGML_VULKAN_MEMORY_DEBUG"
|
||||
bool "GGML_VULKAN_PERF"
|
||||
ignored "GGML_VULKAN_RUN_TESTS"
|
||||
filepath "GGML_VULKAN_SHADERS_GEN_TOOLCHAIN"
|
||||
bool "GGML_VULKAN_SHADER_DEBUG_INFO"
|
||||
pending "GGML_VULKAN_VALIDATE"
|
||||
bool "GGML_VXE"
|
||||
filepath "GIT_EXE"
|
||||
filepath "MATH_LIBRARY"
|
||||
filepath "METALKIT_FRAMEWORK"
|
||||
filepath "METAL_FRAMEWORK"
|
||||
bool "WHISPER_ALL_WARNINGS"
|
||||
bool "WHISPER_ALL_WARNINGS_3RD_PARTY"
|
||||
ignored "WHISPER_BIN_INSTALL_DIR"
|
||||
ignored "WHISPER_BUILD_EXAMPLES"
|
||||
ignored "WHISPER_BUILD_SERVER"
|
||||
ignored"WHISPER_BUILD_TESTS"
|
||||
bool "WHISPER_CCACHE"
|
||||
bool "WHISPER_COREML"
|
||||
bool "WHISPER_COREML_ALLOW_FALLBACK"
|
||||
ignored "WHISPER_CURL"
|
||||
bool "WHISPER_FATAL_WARNINGS"
|
||||
ignored "WHISPER_FFMPEG"
|
||||
ignored "WHISPER_INCLUDE_INSTALL_DIR"
|
||||
ignored "WHISPER_LIB_INSTALL_DIR"
|
||||
bool "WHISPER_OPENVINO"
|
||||
bool "WHISPER_SANITIZE_ADDRESS"
|
||||
bool "WHISPER_SANITIZE_THREAD"
|
||||
bool "WHISPER_SANITIZE_UNDEFINED"
|
||||
ignored "WHISPER_SDL2"
|
||||
pending "WHISPER_USE_SYSTEM_GGML"
|
||||
end
|
||||
|
||||
def option_name(name)
|
||||
name.downcase.gsub("_", "-")
|
||||
end
|
||||
|
||||
def enabled?(option)
|
||||
op = @options[option]
|
||||
raise "Option not exist: #{option}" unless op
|
||||
raise "Option not boolean: #{option}(#{op[0]})" unless op[0] == "BOOL"
|
||||
if op[1].nil?
|
||||
cmake_options[option][1]
|
||||
else
|
||||
op[1]
|
||||
end
|
||||
def bool(name)
|
||||
option = option_name(name)
|
||||
value = enable_config(option)
|
||||
@options[name] = [:bool, value]
|
||||
end
|
||||
|
||||
def string(name, type=:string)
|
||||
option = "--#{option_name(name)}"
|
||||
value = arg_config(option)
|
||||
raise "String expected for #{option}" if value == true || value&.empty?
|
||||
@options[name] = [type, value]
|
||||
end
|
||||
|
||||
def path(name)
|
||||
string(name, :path)
|
||||
end
|
||||
|
||||
def filepath(name)
|
||||
string(name, :filepath)
|
||||
end
|
||||
|
||||
def pending(name)
|
||||
@pending_options << name
|
||||
end
|
||||
|
||||
def ignored(name)
|
||||
@ignored_options << name
|
||||
end
|
||||
end
|
||||
|
@ -3,10 +3,8 @@
|
||||
#include "ruby_whisper.h"
|
||||
|
||||
VALUE mWhisper;
|
||||
VALUE mVAD;
|
||||
VALUE cContext;
|
||||
VALUE cParams;
|
||||
VALUE cVADParams;
|
||||
VALUE eError;
|
||||
|
||||
VALUE cSegment;
|
||||
@ -22,9 +20,6 @@ ID id_new;
|
||||
ID id_to_path;
|
||||
ID id_URI;
|
||||
ID id_pre_converted_models;
|
||||
ID id_coreml_compiled_models;
|
||||
ID id_cache;
|
||||
ID id_n_processors;
|
||||
|
||||
static bool is_log_callback_finalized = false;
|
||||
|
||||
@ -36,7 +31,6 @@ extern void init_ruby_whisper_params(VALUE *mWhisper);
|
||||
extern void init_ruby_whisper_error(VALUE *mWhisper);
|
||||
extern void init_ruby_whisper_segment(VALUE *mWhisper, VALUE *cSegment);
|
||||
extern void init_ruby_whisper_model(VALUE *mWhisper);
|
||||
extern void init_ruby_whisper_vad_params(VALUE *mVAD);
|
||||
extern void register_callbacks(ruby_whisper_params *rwp, VALUE *context);
|
||||
|
||||
/*
|
||||
@ -86,14 +80,6 @@ static VALUE ruby_whisper_s_lang_str_full(VALUE self, VALUE id) {
|
||||
return rb_str_new2(str_full);
|
||||
}
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* system_info_str -> String
|
||||
*/
|
||||
static VALUE ruby_whisper_s_system_info_str(VALUE self) {
|
||||
return rb_str_new2(whisper_print_system_info());
|
||||
}
|
||||
|
||||
static VALUE ruby_whisper_s_finalize_log_callback(VALUE self, VALUE id) {
|
||||
is_log_callback_finalized = true;
|
||||
return Qnil;
|
||||
@ -130,6 +116,16 @@ static VALUE ruby_whisper_s_log_set(VALUE self, VALUE log_callback, VALUE user_d
|
||||
return Qnil;
|
||||
}
|
||||
|
||||
static void rb_whisper_model_mark(ruby_whisper_model *rwm) {
|
||||
rb_gc_mark(rwm->context);
|
||||
}
|
||||
|
||||
static VALUE ruby_whisper_model_allocate(VALUE klass) {
|
||||
ruby_whisper_model *rwm;
|
||||
rwm = ALLOC(ruby_whisper_model);
|
||||
return Data_Wrap_Struct(klass, rb_whisper_model_mark, RUBY_DEFAULT_FREE, rwm);
|
||||
}
|
||||
|
||||
void Init_whisper() {
|
||||
id_to_s = rb_intern("to_s");
|
||||
id_call = rb_intern("call");
|
||||
@ -141,14 +137,9 @@ void Init_whisper() {
|
||||
id_to_path = rb_intern("to_path");
|
||||
id_URI = rb_intern("URI");
|
||||
id_pre_converted_models = rb_intern("pre_converted_models");
|
||||
id_coreml_compiled_models = rb_intern("coreml_compiled_models");
|
||||
id_cache = rb_intern("cache");
|
||||
id_n_processors = rb_intern("n_processors");
|
||||
|
||||
mWhisper = rb_define_module("Whisper");
|
||||
mVAD = rb_define_module_under(mWhisper, "VAD");
|
||||
|
||||
rb_define_const(mWhisper, "VERSION", rb_str_new2(whisper_version()));
|
||||
rb_define_const(mWhisper, "LOG_LEVEL_NONE", INT2NUM(GGML_LOG_LEVEL_NONE));
|
||||
rb_define_const(mWhisper, "LOG_LEVEL_INFO", INT2NUM(GGML_LOG_LEVEL_INFO));
|
||||
rb_define_const(mWhisper, "LOG_LEVEL_WARN", INT2NUM(GGML_LOG_LEVEL_WARN));
|
||||
@ -160,7 +151,6 @@ void Init_whisper() {
|
||||
rb_define_singleton_method(mWhisper, "lang_id", ruby_whisper_s_lang_id, 1);
|
||||
rb_define_singleton_method(mWhisper, "lang_str", ruby_whisper_s_lang_str, 1);
|
||||
rb_define_singleton_method(mWhisper, "lang_str_full", ruby_whisper_s_lang_str_full, 1);
|
||||
rb_define_singleton_method(mWhisper, "system_info_str", ruby_whisper_s_system_info_str, 0);
|
||||
rb_define_singleton_method(mWhisper, "log_set", ruby_whisper_s_log_set, 2);
|
||||
rb_define_private_method(rb_singleton_class(mWhisper), "finalize_log_callback", ruby_whisper_s_finalize_log_callback, 1);
|
||||
|
||||
@ -169,9 +159,6 @@ void Init_whisper() {
|
||||
init_ruby_whisper_error(&mWhisper);
|
||||
init_ruby_whisper_segment(&mWhisper, &cContext);
|
||||
init_ruby_whisper_model(&mWhisper);
|
||||
init_ruby_whisper_vad_params(&mVAD);
|
||||
|
||||
rb_require("whisper/context");
|
||||
rb_require("whisper/segment");
|
||||
rb_require("whisper/model/uri");
|
||||
}
|
||||
|
@ -21,13 +21,8 @@ typedef struct {
|
||||
ruby_whisper_callback_container *progress_callback_container;
|
||||
ruby_whisper_callback_container *encoder_begin_callback_container;
|
||||
ruby_whisper_callback_container *abort_callback_container;
|
||||
VALUE vad_params;
|
||||
} ruby_whisper_params;
|
||||
|
||||
typedef struct {
|
||||
struct whisper_vad_params params;
|
||||
} ruby_whisper_vad_params;
|
||||
|
||||
typedef struct {
|
||||
VALUE context;
|
||||
int index;
|
||||
|
@ -11,21 +11,15 @@ extern ID id_new;
|
||||
extern ID id_to_path;
|
||||
extern ID id_URI;
|
||||
extern ID id_pre_converted_models;
|
||||
extern ID id_coreml_compiled_models;
|
||||
extern ID id_cache;
|
||||
extern ID id_n_processors;
|
||||
|
||||
extern VALUE cContext;
|
||||
extern VALUE eError;
|
||||
extern VALUE cModel;
|
||||
|
||||
extern const rb_data_type_t ruby_whisper_params_type;
|
||||
extern VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self);
|
||||
extern VALUE rb_whisper_model_s_new(VALUE context);
|
||||
extern VALUE rb_whisper_segment_s_new(VALUE context, int index);
|
||||
extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context);
|
||||
|
||||
ID transcribe_option_names[1];
|
||||
extern VALUE rb_whisper_model_initialize(VALUE context);
|
||||
extern VALUE rb_whisper_segment_initialize(VALUE context, int index);
|
||||
extern void register_callbacks(ruby_whisper_params *rwp, VALUE *context);
|
||||
|
||||
static void
|
||||
ruby_whisper_free(ruby_whisper *rw)
|
||||
@ -43,74 +37,19 @@ rb_whisper_mark(ruby_whisper *rw)
|
||||
}
|
||||
|
||||
void
|
||||
rb_whisper_free(void *p)
|
||||
rb_whisper_free(ruby_whisper *rw)
|
||||
{
|
||||
ruby_whisper *rw = (ruby_whisper *)p;
|
||||
ruby_whisper_free(rw);
|
||||
free(rw);
|
||||
}
|
||||
|
||||
static size_t
|
||||
ruby_whisper_memsize(const void *p)
|
||||
{
|
||||
const ruby_whisper *rw = (const ruby_whisper *)p;
|
||||
size_t size = sizeof(rw);
|
||||
if (!rw) {
|
||||
return 0;
|
||||
}
|
||||
if (rw->context) {
|
||||
size += sizeof(rw->context);
|
||||
}
|
||||
return size;
|
||||
}
|
||||
|
||||
const rb_data_type_t ruby_whisper_type = {
|
||||
"ruby_whisper",
|
||||
{0, rb_whisper_free, ruby_whisper_memsize,},
|
||||
0, 0,
|
||||
0
|
||||
};
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_allocate(VALUE klass)
|
||||
{
|
||||
ruby_whisper *rw;
|
||||
VALUE obj = TypedData_Make_Struct(klass, ruby_whisper, &ruby_whisper_type, rw);
|
||||
rw = ALLOC(ruby_whisper);
|
||||
rw->context = NULL;
|
||||
return obj;
|
||||
}
|
||||
|
||||
VALUE
|
||||
ruby_whisper_normalize_model_path(VALUE model_path)
|
||||
{
|
||||
VALUE pre_converted_models = rb_funcall(cModel, id_pre_converted_models, 0);
|
||||
VALUE pre_converted_model = rb_hash_aref(pre_converted_models, model_path);
|
||||
if (!NIL_P(pre_converted_model)) {
|
||||
model_path = pre_converted_model;
|
||||
#ifdef RUBY_WHISPER_USE_COREML
|
||||
VALUE coreml_converted_models = rb_funcall(cModel, id_coreml_compiled_models, 0);
|
||||
VALUE coreml_converted_model = rb_hash_aref(coreml_converted_models, pre_converted_model);
|
||||
if (!NIL_P(coreml_converted_model)) {
|
||||
rb_funcall(coreml_converted_model, id_cache, 0);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else if (TYPE(model_path) == T_STRING) {
|
||||
const char * model_path_str = StringValueCStr(model_path);
|
||||
if (strncmp("http://", model_path_str, 7) == 0 || strncmp("https://", model_path_str, 8) == 0) {
|
||||
VALUE uri_class = rb_const_get(cModel, id_URI);
|
||||
model_path = rb_class_new_instance(1, &model_path, uri_class);
|
||||
}
|
||||
}
|
||||
else if (rb_obj_is_kind_of(model_path, rb_path2class("URI::HTTP"))) {
|
||||
VALUE uri_class = rb_const_get(cModel, id_URI);
|
||||
model_path = rb_class_new_instance(1, &model_path, uri_class);
|
||||
}
|
||||
if (rb_respond_to(model_path, id_to_path)) {
|
||||
model_path = rb_funcall(model_path, id_to_path, 0);
|
||||
}
|
||||
|
||||
return model_path;
|
||||
return Data_Wrap_Struct(klass, rb_whisper_mark, rb_whisper_free, rw);
|
||||
}
|
||||
|
||||
/*
|
||||
@ -127,9 +66,27 @@ ruby_whisper_initialize(int argc, VALUE *argv, VALUE self)
|
||||
|
||||
// TODO: we can support init from buffer here too maybe another ruby object to expose
|
||||
rb_scan_args(argc, argv, "01", &whisper_model_file_path);
|
||||
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
|
||||
whisper_model_file_path = ruby_whisper_normalize_model_path(whisper_model_file_path);
|
||||
VALUE pre_converted_models = rb_funcall(cModel, id_pre_converted_models, 0);
|
||||
VALUE pre_converted_model = rb_hash_aref(pre_converted_models, whisper_model_file_path);
|
||||
if (!NIL_P(pre_converted_model)) {
|
||||
whisper_model_file_path = pre_converted_model;
|
||||
}
|
||||
if (TYPE(whisper_model_file_path) == T_STRING) {
|
||||
const char * whisper_model_file_path_str = StringValueCStr(whisper_model_file_path);
|
||||
if (strncmp("http://", whisper_model_file_path_str, 7) == 0 || strncmp("https://", whisper_model_file_path_str, 8) == 0) {
|
||||
VALUE uri_class = rb_const_get(cModel, id_URI);
|
||||
whisper_model_file_path = rb_class_new_instance(1, &whisper_model_file_path, uri_class);
|
||||
}
|
||||
}
|
||||
if (rb_obj_is_kind_of(whisper_model_file_path, rb_path2class("URI::HTTP"))) {
|
||||
VALUE uri_class = rb_const_get(cModel, id_URI);
|
||||
whisper_model_file_path = rb_class_new_instance(1, &whisper_model_file_path, uri_class);
|
||||
}
|
||||
if (rb_respond_to(whisper_model_file_path, id_to_path)) {
|
||||
whisper_model_file_path = rb_funcall(whisper_model_file_path, id_to_path, 0);
|
||||
}
|
||||
if (!rb_respond_to(whisper_model_file_path, id_to_s)) {
|
||||
rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context");
|
||||
}
|
||||
@ -147,7 +104,7 @@ ruby_whisper_initialize(int argc, VALUE *argv, VALUE self)
|
||||
VALUE ruby_whisper_model_n_vocab(VALUE self)
|
||||
{
|
||||
ruby_whisper *rw;
|
||||
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_model_n_vocab(rw->context));
|
||||
}
|
||||
|
||||
@ -158,7 +115,7 @@ VALUE ruby_whisper_model_n_vocab(VALUE self)
|
||||
VALUE ruby_whisper_model_n_audio_ctx(VALUE self)
|
||||
{
|
||||
ruby_whisper *rw;
|
||||
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_model_n_audio_ctx(rw->context));
|
||||
}
|
||||
|
||||
@ -169,7 +126,7 @@ VALUE ruby_whisper_model_n_audio_ctx(VALUE self)
|
||||
VALUE ruby_whisper_model_n_audio_state(VALUE self)
|
||||
{
|
||||
ruby_whisper *rw;
|
||||
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_model_n_audio_state(rw->context));
|
||||
}
|
||||
|
||||
@ -180,7 +137,7 @@ VALUE ruby_whisper_model_n_audio_state(VALUE self)
|
||||
VALUE ruby_whisper_model_n_audio_head(VALUE self)
|
||||
{
|
||||
ruby_whisper *rw;
|
||||
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_model_n_audio_head(rw->context));
|
||||
}
|
||||
|
||||
@ -191,7 +148,7 @@ VALUE ruby_whisper_model_n_audio_head(VALUE self)
|
||||
VALUE ruby_whisper_model_n_audio_layer(VALUE self)
|
||||
{
|
||||
ruby_whisper *rw;
|
||||
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_model_n_audio_layer(rw->context));
|
||||
}
|
||||
|
||||
@ -202,7 +159,7 @@ VALUE ruby_whisper_model_n_audio_layer(VALUE self)
|
||||
VALUE ruby_whisper_model_n_text_ctx(VALUE self)
|
||||
{
|
||||
ruby_whisper *rw;
|
||||
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_model_n_text_ctx(rw->context));
|
||||
}
|
||||
|
||||
@ -213,7 +170,7 @@ VALUE ruby_whisper_model_n_text_ctx(VALUE self)
|
||||
VALUE ruby_whisper_model_n_text_state(VALUE self)
|
||||
{
|
||||
ruby_whisper *rw;
|
||||
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_model_n_text_state(rw->context));
|
||||
}
|
||||
|
||||
@ -224,7 +181,7 @@ VALUE ruby_whisper_model_n_text_state(VALUE self)
|
||||
VALUE ruby_whisper_model_n_text_head(VALUE self)
|
||||
{
|
||||
ruby_whisper *rw;
|
||||
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_model_n_text_head(rw->context));
|
||||
}
|
||||
|
||||
@ -235,7 +192,7 @@ VALUE ruby_whisper_model_n_text_head(VALUE self)
|
||||
VALUE ruby_whisper_model_n_text_layer(VALUE self)
|
||||
{
|
||||
ruby_whisper *rw;
|
||||
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_model_n_text_layer(rw->context));
|
||||
}
|
||||
|
||||
@ -246,7 +203,7 @@ VALUE ruby_whisper_model_n_text_layer(VALUE self)
|
||||
VALUE ruby_whisper_model_n_mels(VALUE self)
|
||||
{
|
||||
ruby_whisper *rw;
|
||||
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_model_n_mels(rw->context));
|
||||
}
|
||||
|
||||
@ -257,7 +214,7 @@ VALUE ruby_whisper_model_n_mels(VALUE self)
|
||||
VALUE ruby_whisper_model_ftype(VALUE self)
|
||||
{
|
||||
ruby_whisper *rw;
|
||||
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_model_ftype(rw->context));
|
||||
}
|
||||
|
||||
@ -268,7 +225,7 @@ VALUE ruby_whisper_model_ftype(VALUE self)
|
||||
VALUE ruby_whisper_model_type(VALUE self)
|
||||
{
|
||||
ruby_whisper *rw;
|
||||
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
return rb_str_new2(whisper_model_type_readable(rw->context));
|
||||
}
|
||||
|
||||
@ -291,9 +248,9 @@ VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self)
|
||||
|
||||
ruby_whisper *rw;
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
VALUE params = argv[0];
|
||||
TypedData_Get_Struct(params, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
Data_Get_Struct(params, ruby_whisper_params, rwp);
|
||||
VALUE samples = argv[1];
|
||||
int n_samples;
|
||||
rb_memory_view_t view;
|
||||
@ -308,20 +265,13 @@ VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self)
|
||||
// Should check when samples.respond_to?(:length)?
|
||||
} else {
|
||||
if (TYPE(samples) == T_ARRAY) {
|
||||
if (RARRAY_LEN(samples) > INT_MAX) {
|
||||
rb_raise(rb_eArgError, "samples are too long");
|
||||
}
|
||||
n_samples = (int)RARRAY_LEN(samples);
|
||||
n_samples = RARRAY_LEN(samples);
|
||||
} else if (memory_view_available_p) {
|
||||
if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) {
|
||||
view.obj = Qnil;
|
||||
rb_raise(rb_eArgError, "unable to get a memory view");
|
||||
}
|
||||
ssize_t n_samples_size = view.byte_size / view.item_size;
|
||||
if (n_samples_size > INT_MAX) {
|
||||
rb_raise(rb_eArgError, "samples are too long");
|
||||
}
|
||||
n_samples = (int)n_samples_size;
|
||||
n_samples = view.byte_size / view.item_size;
|
||||
} else if (rb_respond_to(samples, id_length)) {
|
||||
n_samples = NUM2INT(rb_funcall(samples, id_length, 0));
|
||||
} else {
|
||||
@ -346,7 +296,7 @@ VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self)
|
||||
}
|
||||
}
|
||||
}
|
||||
prepare_transcription(rwp, &self);
|
||||
register_callbacks(rwp, &self);
|
||||
const int result = whisper_full(rw->context, rwp->params, c_samples, n_samples);
|
||||
if (0 == result) {
|
||||
return self;
|
||||
@ -377,9 +327,9 @@ ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self)
|
||||
|
||||
ruby_whisper *rw;
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
VALUE params = argv[0];
|
||||
TypedData_Get_Struct(params, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
Data_Get_Struct(params, ruby_whisper_params, rwp);
|
||||
VALUE samples = argv[1];
|
||||
int n_samples;
|
||||
int n_processors;
|
||||
@ -409,17 +359,10 @@ ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self)
|
||||
view.obj = Qnil;
|
||||
rb_raise(rb_eArgError, "unable to get a memory view");
|
||||
}
|
||||
ssize_t n_samples_size = view.byte_size / view.item_size;
|
||||
if (n_samples_size > INT_MAX) {
|
||||
rb_raise(rb_eArgError, "samples are too long");
|
||||
}
|
||||
n_samples = (int)n_samples_size;
|
||||
n_samples = view.byte_size / view.item_size;
|
||||
} else {
|
||||
if (TYPE(samples) == T_ARRAY) {
|
||||
if (RARRAY_LEN(samples) > INT_MAX) {
|
||||
rb_raise(rb_eArgError, "samples are too long");
|
||||
}
|
||||
n_samples = (int)RARRAY_LEN(samples);
|
||||
n_samples = RARRAY_LEN(samples);
|
||||
} else if (rb_respond_to(samples, id_length)) {
|
||||
n_samples = NUM2INT(rb_funcall(samples, id_length, 0));
|
||||
} else {
|
||||
@ -444,7 +387,7 @@ ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self)
|
||||
}
|
||||
}
|
||||
}
|
||||
prepare_transcription(rwp, &self);
|
||||
register_callbacks(rwp, &self);
|
||||
const int result = whisper_full_parallel(rw->context, rwp->params, c_samples, n_samples, n_processors);
|
||||
if (0 == result) {
|
||||
return self;
|
||||
@ -463,7 +406,7 @@ static VALUE
|
||||
ruby_whisper_full_n_segments(VALUE self)
|
||||
{
|
||||
ruby_whisper *rw;
|
||||
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_full_n_segments(rw->context));
|
||||
}
|
||||
|
||||
@ -477,7 +420,7 @@ static VALUE
|
||||
ruby_whisper_full_lang_id(VALUE self)
|
||||
{
|
||||
ruby_whisper *rw;
|
||||
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_full_lang_id(rw->context));
|
||||
}
|
||||
|
||||
@ -502,10 +445,10 @@ static VALUE
|
||||
ruby_whisper_full_get_segment_t0(VALUE self, VALUE i_segment)
|
||||
{
|
||||
ruby_whisper *rw;
|
||||
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
|
||||
const int64_t t0 = whisper_full_get_segment_t0(rw->context, c_i_segment);
|
||||
return LONG2NUM(t0);
|
||||
return INT2NUM(t0);
|
||||
}
|
||||
|
||||
/*
|
||||
@ -520,10 +463,10 @@ static VALUE
|
||||
ruby_whisper_full_get_segment_t1(VALUE self, VALUE i_segment)
|
||||
{
|
||||
ruby_whisper *rw;
|
||||
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
|
||||
const int64_t t1 = whisper_full_get_segment_t1(rw->context, c_i_segment);
|
||||
return LONG2NUM(t1);
|
||||
return INT2NUM(t1);
|
||||
}
|
||||
|
||||
/*
|
||||
@ -538,7 +481,7 @@ static VALUE
|
||||
ruby_whisper_full_get_segment_speaker_turn_next(VALUE self, VALUE i_segment)
|
||||
{
|
||||
ruby_whisper *rw;
|
||||
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
|
||||
const bool speaker_turn_next = whisper_full_get_segment_speaker_turn_next(rw->context, c_i_segment);
|
||||
return speaker_turn_next ? Qtrue : Qfalse;
|
||||
@ -556,7 +499,7 @@ static VALUE
|
||||
ruby_whisper_full_get_segment_text(VALUE self, VALUE i_segment)
|
||||
{
|
||||
ruby_whisper *rw;
|
||||
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
|
||||
const char * text = whisper_full_get_segment_text(rw->context, c_i_segment);
|
||||
return rb_str_new2(text);
|
||||
@ -570,7 +513,7 @@ static VALUE
|
||||
ruby_whisper_full_get_segment_no_speech_prob(VALUE self, VALUE i_segment)
|
||||
{
|
||||
ruby_whisper *rw;
|
||||
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
|
||||
const float no_speech_prob = whisper_full_get_segment_no_speech_prob(rw->context, c_i_segment);
|
||||
return DBL2NUM(no_speech_prob);
|
||||
@ -581,7 +524,7 @@ ruby_whisper_full_get_segment_no_speech_prob(VALUE self, VALUE i_segment)
|
||||
static VALUE
|
||||
ruby_whisper_full_get_segment(VALUE self, VALUE i_segment)
|
||||
{
|
||||
return rb_whisper_segment_s_new(self, NUM2INT(i_segment));
|
||||
return rb_whisper_segment_initialize(self, NUM2INT(i_segment));
|
||||
}
|
||||
|
||||
/*
|
||||
@ -611,11 +554,11 @@ ruby_whisper_each_segment(VALUE self)
|
||||
}
|
||||
|
||||
ruby_whisper *rw;
|
||||
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
|
||||
const int n_segments = whisper_full_n_segments(rw->context);
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
rb_yield(rb_whisper_segment_s_new(self, i));
|
||||
rb_yield(rb_whisper_segment_initialize(self, i));
|
||||
}
|
||||
|
||||
return self;
|
||||
@ -628,7 +571,7 @@ ruby_whisper_each_segment(VALUE self)
|
||||
static VALUE
|
||||
ruby_whisper_get_model(VALUE self)
|
||||
{
|
||||
return rb_whisper_model_s_new(self);
|
||||
return rb_whisper_model_initialize(self);
|
||||
}
|
||||
|
||||
void
|
||||
@ -636,8 +579,6 @@ init_ruby_whisper_context(VALUE *mWhisper)
|
||||
{
|
||||
cContext = rb_define_class_under(*mWhisper, "Context", rb_cObject);
|
||||
|
||||
transcribe_option_names[0] = id_n_processors;
|
||||
|
||||
rb_define_alloc_func(cContext, ruby_whisper_allocate);
|
||||
rb_define_method(cContext, "initialize", ruby_whisper_initialize, -1);
|
||||
|
||||
@ -664,7 +605,7 @@ init_ruby_whisper_context(VALUE *mWhisper)
|
||||
rb_define_method(cContext, "full", ruby_whisper_full, -1);
|
||||
rb_define_method(cContext, "full_parallel", ruby_whisper_full_parallel, -1);
|
||||
|
||||
// High level
|
||||
// High leve
|
||||
rb_define_method(cContext, "full_get_segment", ruby_whisper_full_get_segment, 1);
|
||||
rb_define_method(cContext, "each_segment", ruby_whisper_each_segment, 0);
|
||||
|
||||
|
@ -1,44 +1,22 @@
|
||||
#include <ruby.h>
|
||||
#include "ruby_whisper.h"
|
||||
|
||||
extern const rb_data_type_t ruby_whisper_type;
|
||||
|
||||
extern VALUE cModel;
|
||||
|
||||
static void rb_whisper_model_mark(void *p) {
|
||||
ruby_whisper_model *rwm = (ruby_whisper_model *)p;
|
||||
if (rwm->context) {
|
||||
rb_gc_mark(rwm->context);
|
||||
}
|
||||
static void rb_whisper_model_mark(ruby_whisper_model *rwm) {
|
||||
rb_gc_mark(rwm->context);
|
||||
}
|
||||
|
||||
static size_t
|
||||
ruby_whisper_model_memsize(const void *p)
|
||||
{
|
||||
const ruby_whisper_model *rwm = (const ruby_whisper_model *)p;
|
||||
size_t size = sizeof(rwm);
|
||||
if (!rwm) {
|
||||
return 0;
|
||||
}
|
||||
return size;
|
||||
}
|
||||
|
||||
static const rb_data_type_t rb_whisper_model_type = {
|
||||
"ruby_whisper_model",
|
||||
{rb_whisper_model_mark, RUBY_DEFAULT_FREE, ruby_whisper_model_memsize,},
|
||||
0, 0,
|
||||
0
|
||||
};
|
||||
|
||||
static VALUE ruby_whisper_model_allocate(VALUE klass) {
|
||||
ruby_whisper_model *rwm;
|
||||
return TypedData_Make_Struct(klass, ruby_whisper_model, &rb_whisper_model_type, rwm);
|
||||
rwm = ALLOC(ruby_whisper_model);
|
||||
return Data_Wrap_Struct(klass, rb_whisper_model_mark, RUBY_DEFAULT_FREE, rwm);
|
||||
}
|
||||
|
||||
VALUE rb_whisper_model_s_new(VALUE context) {
|
||||
VALUE rb_whisper_model_initialize(VALUE context) {
|
||||
ruby_whisper_model *rwm;
|
||||
const VALUE model = ruby_whisper_model_allocate(cModel);
|
||||
TypedData_Get_Struct(model, ruby_whisper_model, &rb_whisper_model_type, rwm);
|
||||
Data_Get_Struct(model, ruby_whisper_model, rwm);
|
||||
rwm->context = context;
|
||||
return model;
|
||||
};
|
||||
@ -51,9 +29,9 @@ static VALUE
|
||||
ruby_whisper_model_n_vocab(VALUE self)
|
||||
{
|
||||
ruby_whisper_model *rwm;
|
||||
TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
|
||||
Data_Get_Struct(self, ruby_whisper_model, rwm);
|
||||
ruby_whisper *rw;
|
||||
TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
|
||||
Data_Get_Struct(rwm->context, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_model_n_vocab(rw->context));
|
||||
}
|
||||
|
||||
@ -65,9 +43,9 @@ static VALUE
|
||||
ruby_whisper_model_n_audio_ctx(VALUE self)
|
||||
{
|
||||
ruby_whisper_model *rwm;
|
||||
TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
|
||||
Data_Get_Struct(self, ruby_whisper_model, rwm);
|
||||
ruby_whisper *rw;
|
||||
TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
|
||||
Data_Get_Struct(rwm->context, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_model_n_audio_ctx(rw->context));
|
||||
}
|
||||
|
||||
@ -79,9 +57,9 @@ static VALUE
|
||||
ruby_whisper_model_n_audio_state(VALUE self)
|
||||
{
|
||||
ruby_whisper_model *rwm;
|
||||
TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
|
||||
Data_Get_Struct(self, ruby_whisper_model, rwm);
|
||||
ruby_whisper *rw;
|
||||
TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
|
||||
Data_Get_Struct(rwm->context, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_model_n_audio_state(rw->context));
|
||||
}
|
||||
|
||||
@ -93,9 +71,9 @@ static VALUE
|
||||
ruby_whisper_model_n_audio_head(VALUE self)
|
||||
{
|
||||
ruby_whisper_model *rwm;
|
||||
TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
|
||||
Data_Get_Struct(self, ruby_whisper_model, rwm);
|
||||
ruby_whisper *rw;
|
||||
TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
|
||||
Data_Get_Struct(rwm->context, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_model_n_audio_head(rw->context));
|
||||
}
|
||||
|
||||
@ -107,9 +85,9 @@ static VALUE
|
||||
ruby_whisper_model_n_audio_layer(VALUE self)
|
||||
{
|
||||
ruby_whisper_model *rwm;
|
||||
TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
|
||||
Data_Get_Struct(self, ruby_whisper_model, rwm);
|
||||
ruby_whisper *rw;
|
||||
TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
|
||||
Data_Get_Struct(rwm->context, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_model_n_audio_layer(rw->context));
|
||||
}
|
||||
|
||||
@ -121,9 +99,9 @@ static VALUE
|
||||
ruby_whisper_model_n_text_ctx(VALUE self)
|
||||
{
|
||||
ruby_whisper_model *rwm;
|
||||
TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
|
||||
Data_Get_Struct(self, ruby_whisper_model, rwm);
|
||||
ruby_whisper *rw;
|
||||
TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
|
||||
Data_Get_Struct(rwm->context, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_model_n_text_ctx(rw->context));
|
||||
}
|
||||
|
||||
@ -135,9 +113,9 @@ static VALUE
|
||||
ruby_whisper_model_n_text_state(VALUE self)
|
||||
{
|
||||
ruby_whisper_model *rwm;
|
||||
TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
|
||||
Data_Get_Struct(self, ruby_whisper_model, rwm);
|
||||
ruby_whisper *rw;
|
||||
TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
|
||||
Data_Get_Struct(rwm->context, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_model_n_text_state(rw->context));
|
||||
}
|
||||
|
||||
@ -149,9 +127,9 @@ static VALUE
|
||||
ruby_whisper_model_n_text_head(VALUE self)
|
||||
{
|
||||
ruby_whisper_model *rwm;
|
||||
TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
|
||||
Data_Get_Struct(self, ruby_whisper_model, rwm);
|
||||
ruby_whisper *rw;
|
||||
TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
|
||||
Data_Get_Struct(rwm->context, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_model_n_text_head(rw->context));
|
||||
}
|
||||
|
||||
@ -163,9 +141,9 @@ static VALUE
|
||||
ruby_whisper_model_n_text_layer(VALUE self)
|
||||
{
|
||||
ruby_whisper_model *rwm;
|
||||
TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
|
||||
Data_Get_Struct(self, ruby_whisper_model, rwm);
|
||||
ruby_whisper *rw;
|
||||
TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
|
||||
Data_Get_Struct(rwm->context, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_model_n_text_layer(rw->context));
|
||||
}
|
||||
|
||||
@ -177,9 +155,9 @@ static VALUE
|
||||
ruby_whisper_model_n_mels(VALUE self)
|
||||
{
|
||||
ruby_whisper_model *rwm;
|
||||
TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
|
||||
Data_Get_Struct(self, ruby_whisper_model, rwm);
|
||||
ruby_whisper *rw;
|
||||
TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
|
||||
Data_Get_Struct(rwm->context, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_model_n_mels(rw->context));
|
||||
}
|
||||
|
||||
@ -191,9 +169,9 @@ static VALUE
|
||||
ruby_whisper_model_ftype(VALUE self)
|
||||
{
|
||||
ruby_whisper_model *rwm;
|
||||
TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
|
||||
Data_Get_Struct(self, ruby_whisper_model, rwm);
|
||||
ruby_whisper *rw;
|
||||
TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
|
||||
Data_Get_Struct(rwm->context, ruby_whisper, rw);
|
||||
return INT2NUM(whisper_model_ftype(rw->context));
|
||||
}
|
||||
|
||||
@ -205,9 +183,9 @@ static VALUE
|
||||
ruby_whisper_model_type(VALUE self)
|
||||
{
|
||||
ruby_whisper_model *rwm;
|
||||
TypedData_Get_Struct(self, ruby_whisper_model, &rb_whisper_model_type, rwm);
|
||||
Data_Get_Struct(self, ruby_whisper_model, rwm);
|
||||
ruby_whisper *rw;
|
||||
TypedData_Get_Struct(rwm->context, ruby_whisper, &ruby_whisper_type, rw);
|
||||
Data_Get_Struct(rwm->context, ruby_whisper, rw);
|
||||
return rb_str_new2(whisper_model_type_readable(rw->context));
|
||||
}
|
||||
|
||||
|
@ -3,7 +3,7 @@
|
||||
|
||||
#define BOOL_PARAMS_SETTER(self, prop, value) \
|
||||
ruby_whisper_params *rwp; \
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); \
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp); \
|
||||
if (value == Qfalse || value == Qnil) { \
|
||||
rwp->params.prop = false; \
|
||||
} else { \
|
||||
@ -13,7 +13,7 @@
|
||||
|
||||
#define BOOL_PARAMS_GETTER(self, prop) \
|
||||
ruby_whisper_params *rwp; \
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); \
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp); \
|
||||
if (rwp->params.prop) { \
|
||||
return Qtrue; \
|
||||
} else { \
|
||||
@ -26,16 +26,13 @@
|
||||
rb_define_method(cParams, #param_name, ruby_whisper_params_get_ ## param_name, 0); \
|
||||
rb_define_method(cParams, #param_name "=", ruby_whisper_params_set_ ## param_name, 1);
|
||||
|
||||
#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 35
|
||||
#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 32
|
||||
|
||||
extern VALUE cParams;
|
||||
extern VALUE cVADParams;
|
||||
|
||||
extern ID id_call;
|
||||
|
||||
extern VALUE ruby_whisper_normalize_model_path(VALUE model_path);
|
||||
extern VALUE rb_whisper_segment_s_new(VALUE context, int index);
|
||||
extern const rb_data_type_t ruby_whisper_vad_params_type;
|
||||
extern VALUE rb_whisper_segment_initialize(VALUE context, int index);
|
||||
|
||||
static ID param_names[RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT];
|
||||
static ID id_language;
|
||||
@ -70,15 +67,10 @@ static ID id_encoder_begin_callback;
|
||||
static ID id_encoder_begin_callback_user_data;
|
||||
static ID id_abort_callback;
|
||||
static ID id_abort_callback_user_data;
|
||||
static ID id_vad;
|
||||
static ID id_vad_model_path;
|
||||
static ID id_vad_params;
|
||||
|
||||
static void
|
||||
rb_whisper_callbcack_container_mark(ruby_whisper_callback_container *rwc)
|
||||
{
|
||||
if (rwc == NULL) return;
|
||||
|
||||
rb_gc_mark(rwc->user_data);
|
||||
rb_gc_mark(rwc->callback);
|
||||
rb_gc_mark(rwc->callbacks);
|
||||
@ -110,7 +102,7 @@ static void new_segment_callback(struct whisper_context *ctx, struct whisper_sta
|
||||
const int n_segments = whisper_full_n_segments_from_state(state);
|
||||
for (int i = n_new; i > 0; i--) {
|
||||
int i_segment = n_segments - i;
|
||||
VALUE segment = rb_whisper_segment_s_new(*container->context, i_segment);
|
||||
VALUE segment = rb_whisper_segment_initialize(*container->context, i_segment);
|
||||
for (int j = 0; j < callbacks_len; j++) {
|
||||
VALUE cb = rb_ary_entry(container->callbacks, j);
|
||||
rb_funcall(cb, id_call, 1, segment);
|
||||
@ -185,7 +177,7 @@ static bool abort_callback(void * user_data) {
|
||||
return false;
|
||||
}
|
||||
|
||||
static void register_callbacks(ruby_whisper_params * rwp, VALUE * context) {
|
||||
void register_callbacks(ruby_whisper_params * rwp, VALUE * context) {
|
||||
if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) {
|
||||
rwp->new_segment_callback_container->context = context;
|
||||
rwp->params.new_segment_callback = new_segment_callback;
|
||||
@ -211,29 +203,13 @@ static void register_callbacks(ruby_whisper_params * rwp, VALUE * context) {
|
||||
}
|
||||
}
|
||||
|
||||
static void set_vad_params(ruby_whisper_params *rwp)
|
||||
{
|
||||
ruby_whisper_vad_params * rwvp;
|
||||
TypedData_Get_Struct(rwp->vad_params, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
|
||||
rwp->params.vad_params = rwvp->params;
|
||||
}
|
||||
|
||||
void
|
||||
prepare_transcription(ruby_whisper_params *rwp, VALUE *context)
|
||||
rb_whisper_params_mark(ruby_whisper_params *rwp)
|
||||
{
|
||||
register_callbacks(rwp, context);
|
||||
set_vad_params(rwp);
|
||||
}
|
||||
|
||||
void
|
||||
rb_whisper_params_mark(void *p)
|
||||
{
|
||||
ruby_whisper_params *rwp = (ruby_whisper_params *)p;
|
||||
rb_whisper_callbcack_container_mark(rwp->new_segment_callback_container);
|
||||
rb_whisper_callbcack_container_mark(rwp->progress_callback_container);
|
||||
rb_whisper_callbcack_container_mark(rwp->encoder_begin_callback_container);
|
||||
rb_whisper_callbcack_container_mark(rwp->abort_callback_container);
|
||||
rb_gc_mark(rwp->vad_params);
|
||||
}
|
||||
|
||||
void
|
||||
@ -242,46 +218,25 @@ ruby_whisper_params_free(ruby_whisper_params *rwp)
|
||||
}
|
||||
|
||||
void
|
||||
rb_whisper_params_free(void *p)
|
||||
rb_whisper_params_free(ruby_whisper_params *rwp)
|
||||
{
|
||||
ruby_whisper_params *rwp = (ruby_whisper_params *)p;
|
||||
// How to free user_data and callback only when not referred to by others?
|
||||
ruby_whisper_params_free(rwp);
|
||||
free(rwp);
|
||||
}
|
||||
|
||||
static size_t
|
||||
ruby_whisper_params_memsize(const void *p)
|
||||
{
|
||||
const ruby_whisper_params *rwp = (const ruby_whisper_params *)p;
|
||||
|
||||
return sizeof(ruby_whisper_params) + sizeof(rwp->params) + sizeof(rwp->vad_params);
|
||||
}
|
||||
|
||||
const rb_data_type_t ruby_whisper_params_type = {
|
||||
"ruby_whisper_params",
|
||||
{
|
||||
rb_whisper_params_mark,
|
||||
rb_whisper_params_free,
|
||||
ruby_whisper_params_memsize,
|
||||
},
|
||||
0, 0,
|
||||
0
|
||||
};
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_params_allocate(VALUE klass)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
VALUE obj = TypedData_Make_Struct(klass, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
rwp = ALLOC(ruby_whisper_params);
|
||||
rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
rwp->diarize = false;
|
||||
rwp->vad_params = TypedData_Wrap_Struct(cVADParams, &ruby_whisper_vad_params_type, (void *)&rwp->params.vad_params);
|
||||
rwp->new_segment_callback_container = rb_whisper_callback_container_allocate();
|
||||
rwp->progress_callback_container = rb_whisper_callback_container_allocate();
|
||||
rwp->encoder_begin_callback_container = rb_whisper_callback_container_allocate();
|
||||
rwp->abort_callback_container = rb_whisper_callback_container_allocate();
|
||||
return obj;
|
||||
return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp);
|
||||
}
|
||||
|
||||
/*
|
||||
@ -294,7 +249,7 @@ static VALUE
|
||||
ruby_whisper_params_set_language(VALUE self, VALUE value)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
if (value == Qfalse || value == Qnil) {
|
||||
rwp->params.language = "auto";
|
||||
} else {
|
||||
@ -310,7 +265,7 @@ static VALUE
|
||||
ruby_whisper_params_get_language(VALUE self)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
if (rwp->params.language) {
|
||||
return rb_str_new2(rwp->params.language);
|
||||
} else {
|
||||
@ -547,7 +502,7 @@ static VALUE
|
||||
ruby_whisper_params_get_initial_prompt(VALUE self)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
return rwp->params.initial_prompt == NULL ? Qnil : rb_str_new2(rwp->params.initial_prompt);
|
||||
}
|
||||
/*
|
||||
@ -558,7 +513,7 @@ static VALUE
|
||||
ruby_whisper_params_set_initial_prompt(VALUE self, VALUE value)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
rwp->params.initial_prompt = StringValueCStr(value);
|
||||
return value;
|
||||
}
|
||||
@ -572,7 +527,7 @@ static VALUE
|
||||
ruby_whisper_params_get_diarize(VALUE self)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
if (rwp->diarize) {
|
||||
return Qtrue;
|
||||
} else {
|
||||
@ -587,7 +542,7 @@ static VALUE
|
||||
ruby_whisper_params_set_diarize(VALUE self, VALUE value)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
if (value == Qfalse || value == Qnil) {
|
||||
rwp->diarize = false;
|
||||
} else {
|
||||
@ -606,7 +561,7 @@ static VALUE
|
||||
ruby_whisper_params_get_offset(VALUE self)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
return INT2NUM(rwp->params.offset_ms);
|
||||
}
|
||||
/*
|
||||
@ -617,7 +572,7 @@ static VALUE
|
||||
ruby_whisper_params_set_offset(VALUE self, VALUE value)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
rwp->params.offset_ms = NUM2INT(value);
|
||||
return value;
|
||||
}
|
||||
@ -631,7 +586,7 @@ static VALUE
|
||||
ruby_whisper_params_get_duration(VALUE self)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
return INT2NUM(rwp->params.duration_ms);
|
||||
}
|
||||
/*
|
||||
@ -642,7 +597,7 @@ static VALUE
|
||||
ruby_whisper_params_set_duration(VALUE self, VALUE value)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
rwp->params.duration_ms = NUM2INT(value);
|
||||
return value;
|
||||
}
|
||||
@ -657,7 +612,7 @@ static VALUE
|
||||
ruby_whisper_params_get_max_text_tokens(VALUE self)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
return INT2NUM(rwp->params.n_max_text_ctx);
|
||||
}
|
||||
/*
|
||||
@ -668,7 +623,7 @@ static VALUE
|
||||
ruby_whisper_params_set_max_text_tokens(VALUE self, VALUE value)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
rwp->params.n_max_text_ctx = NUM2INT(value);
|
||||
return value;
|
||||
}
|
||||
@ -680,7 +635,7 @@ static VALUE
|
||||
ruby_whisper_params_get_temperature(VALUE self)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
return DBL2NUM(rwp->params.temperature);
|
||||
}
|
||||
/*
|
||||
@ -691,7 +646,7 @@ static VALUE
|
||||
ruby_whisper_params_set_temperature(VALUE self, VALUE value)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
rwp->params.temperature = RFLOAT_VALUE(value);
|
||||
return value;
|
||||
}
|
||||
@ -705,7 +660,7 @@ static VALUE
|
||||
ruby_whisper_params_get_max_initial_ts(VALUE self)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
return DBL2NUM(rwp->params.max_initial_ts);
|
||||
}
|
||||
/*
|
||||
@ -716,7 +671,7 @@ static VALUE
|
||||
ruby_whisper_params_set_max_initial_ts(VALUE self, VALUE value)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
rwp->params.max_initial_ts = RFLOAT_VALUE(value);
|
||||
return value;
|
||||
}
|
||||
@ -728,7 +683,7 @@ static VALUE
|
||||
ruby_whisper_params_get_length_penalty(VALUE self)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
return DBL2NUM(rwp->params.length_penalty);
|
||||
}
|
||||
/*
|
||||
@ -739,7 +694,7 @@ static VALUE
|
||||
ruby_whisper_params_set_length_penalty(VALUE self, VALUE value)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
rwp->params.length_penalty = RFLOAT_VALUE(value);
|
||||
return value;
|
||||
}
|
||||
@ -751,7 +706,7 @@ static VALUE
|
||||
ruby_whisper_params_get_temperature_inc(VALUE self)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
return DBL2NUM(rwp->params.temperature_inc);
|
||||
}
|
||||
/*
|
||||
@ -762,7 +717,7 @@ static VALUE
|
||||
ruby_whisper_params_set_temperature_inc(VALUE self, VALUE value)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
rwp->params.temperature_inc = RFLOAT_VALUE(value);
|
||||
return value;
|
||||
}
|
||||
@ -776,7 +731,7 @@ static VALUE
|
||||
ruby_whisper_params_get_entropy_thold(VALUE self)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
return DBL2NUM(rwp->params.entropy_thold);
|
||||
}
|
||||
/*
|
||||
@ -787,7 +742,7 @@ static VALUE
|
||||
ruby_whisper_params_set_entropy_thold(VALUE self, VALUE value)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
rwp->params.entropy_thold = RFLOAT_VALUE(value);
|
||||
return value;
|
||||
}
|
||||
@ -799,7 +754,7 @@ static VALUE
|
||||
ruby_whisper_params_get_logprob_thold(VALUE self)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
return DBL2NUM(rwp->params.logprob_thold);
|
||||
}
|
||||
/*
|
||||
@ -810,7 +765,7 @@ static VALUE
|
||||
ruby_whisper_params_set_logprob_thold(VALUE self, VALUE value)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
rwp->params.logprob_thold = RFLOAT_VALUE(value);
|
||||
return value;
|
||||
}
|
||||
@ -822,7 +777,7 @@ static VALUE
|
||||
ruby_whisper_params_get_no_speech_thold(VALUE self)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
return DBL2NUM(rwp->params.no_speech_thold);
|
||||
}
|
||||
/*
|
||||
@ -833,7 +788,7 @@ static VALUE
|
||||
ruby_whisper_params_set_no_speech_thold(VALUE self, VALUE value)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
rwp->params.no_speech_thold = RFLOAT_VALUE(value);
|
||||
return value;
|
||||
}
|
||||
@ -841,7 +796,7 @@ static VALUE
|
||||
ruby_whisper_params_get_new_segment_callback(VALUE self)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
return rwp->new_segment_callback_container->callback;
|
||||
}
|
||||
/*
|
||||
@ -858,7 +813,7 @@ static VALUE
|
||||
ruby_whisper_params_set_new_segment_callback(VALUE self, VALUE value)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
rwp->new_segment_callback_container->callback = value;
|
||||
return value;
|
||||
}
|
||||
@ -866,7 +821,7 @@ static VALUE
|
||||
ruby_whisper_params_get_new_segment_callback_user_data(VALUE self)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
return rwp->new_segment_callback_container->user_data;
|
||||
}
|
||||
/*
|
||||
@ -879,7 +834,7 @@ static VALUE
|
||||
ruby_whisper_params_set_new_segment_callback_user_data(VALUE self, VALUE value)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
rwp->new_segment_callback_container->user_data = value;
|
||||
return value;
|
||||
}
|
||||
@ -887,7 +842,7 @@ static VALUE
|
||||
ruby_whisper_params_get_progress_callback(VALUE self)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
return rwp->progress_callback_container->callback;
|
||||
}
|
||||
/*
|
||||
@ -906,7 +861,7 @@ static VALUE
|
||||
ruby_whisper_params_set_progress_callback(VALUE self, VALUE value)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
rwp->progress_callback_container->callback = value;
|
||||
return value;
|
||||
}
|
||||
@ -914,7 +869,7 @@ static VALUE
|
||||
ruby_whisper_params_get_progress_callback_user_data(VALUE self)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
return rwp->progress_callback_container->user_data;
|
||||
}
|
||||
/*
|
||||
@ -927,7 +882,7 @@ static VALUE
|
||||
ruby_whisper_params_set_progress_callback_user_data(VALUE self, VALUE value)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
rwp->progress_callback_container->user_data = value;
|
||||
return value;
|
||||
}
|
||||
@ -936,7 +891,7 @@ static VALUE
|
||||
ruby_whisper_params_get_encoder_begin_callback(VALUE self)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
return rwp->encoder_begin_callback_container->callback;
|
||||
}
|
||||
|
||||
@ -954,7 +909,7 @@ static VALUE
|
||||
ruby_whisper_params_set_encoder_begin_callback(VALUE self, VALUE value)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
rwp->encoder_begin_callback_container->callback = value;
|
||||
return value;
|
||||
}
|
||||
@ -963,7 +918,7 @@ static VALUE
|
||||
ruby_whisper_params_get_encoder_begin_callback_user_data(VALUE self)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
return rwp->encoder_begin_callback_container->user_data;
|
||||
}
|
||||
|
||||
@ -977,7 +932,7 @@ static VALUE
|
||||
ruby_whisper_params_set_encoder_begin_callback_user_data(VALUE self, VALUE value)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
rwp->encoder_begin_callback_container->user_data = value;
|
||||
return value;
|
||||
}
|
||||
@ -986,7 +941,7 @@ static VALUE
|
||||
ruby_whisper_params_get_abort_callback(VALUE self)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
return rwp->abort_callback_container->callback;
|
||||
}
|
||||
/*
|
||||
@ -1003,7 +958,7 @@ static VALUE
|
||||
ruby_whisper_params_set_abort_callback(VALUE self, VALUE value)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
rwp->abort_callback_container->callback = value;
|
||||
return value;
|
||||
}
|
||||
@ -1011,7 +966,7 @@ static VALUE
|
||||
ruby_whisper_params_get_abort_callback_user_data(VALUE self)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
return rwp->abort_callback_container->user_data;
|
||||
}
|
||||
/*
|
||||
@ -1024,74 +979,11 @@ static VALUE
|
||||
ruby_whisper_params_set_abort_callback_user_data(VALUE self, VALUE value)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
rwp->abort_callback_container->user_data = value;
|
||||
return value;
|
||||
}
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* vad = use_vad -> use_vad
|
||||
*/
|
||||
static VALUE
|
||||
ruby_whisper_params_get_vad(VALUE self)
|
||||
{
|
||||
BOOL_PARAMS_GETTER(self, vad)
|
||||
}
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_params_set_vad(VALUE self, VALUE value)
|
||||
{
|
||||
BOOL_PARAMS_SETTER(self, vad, value)
|
||||
}
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* vad_model_path = model_path -> model_path
|
||||
*/
|
||||
static VALUE
|
||||
ruby_whisper_params_set_vad_model_path(VALUE self, VALUE value)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
if (NIL_P(value)) {
|
||||
rwp->params.vad_model_path = NULL;
|
||||
return value;
|
||||
}
|
||||
VALUE path = ruby_whisper_normalize_model_path(value);
|
||||
rwp->params.vad_model_path = StringValueCStr(path);
|
||||
return value;
|
||||
}
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_params_get_vad_model_path(VALUE self)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
return rwp->params.vad_model_path == NULL ? Qnil : rb_str_new2(rwp->params.vad_model_path);
|
||||
}
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* vad_params = params -> params
|
||||
*/
|
||||
static VALUE
|
||||
ruby_whisper_params_set_vad_params(VALUE self, VALUE value)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
rwp->vad_params = value;
|
||||
return value;
|
||||
}
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_params_get_vad_params(VALUE self)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
return rwp->vad_params;
|
||||
}
|
||||
|
||||
#define SET_PARAM_IF_SAME(param_name) \
|
||||
if (id == id_ ## param_name) { \
|
||||
ruby_whisper_params_set_ ## param_name(self, value); \
|
||||
@ -1101,6 +993,7 @@ ruby_whisper_params_get_vad_params(VALUE self)
|
||||
static VALUE
|
||||
ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self)
|
||||
{
|
||||
|
||||
VALUE kw_hash;
|
||||
VALUE values[RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT] = {Qundef};
|
||||
VALUE value;
|
||||
@ -1114,7 +1007,7 @@ ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self)
|
||||
}
|
||||
|
||||
rb_get_kwargs(kw_hash, param_names, 0, RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT, values);
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
|
||||
for (i = 0; i < RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT; i++) {
|
||||
id = param_names[i];
|
||||
@ -1157,9 +1050,6 @@ ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self)
|
||||
SET_PARAM_IF_SAME(encoder_begin_callback_user_data)
|
||||
SET_PARAM_IF_SAME(abort_callback)
|
||||
SET_PARAM_IF_SAME(abort_callback_user_data)
|
||||
SET_PARAM_IF_SAME(vad)
|
||||
SET_PARAM_IF_SAME(vad_model_path)
|
||||
SET_PARAM_IF_SAME(vad_params)
|
||||
}
|
||||
}
|
||||
|
||||
@ -1181,10 +1071,10 @@ ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self)
|
||||
static VALUE
|
||||
ruby_whisper_params_on_new_segment(VALUE self)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
ruby_whisper_params *rws;
|
||||
Data_Get_Struct(self, ruby_whisper_params, rws);
|
||||
const VALUE blk = rb_block_proc();
|
||||
rb_ary_push(rwp->new_segment_callback_container->callbacks, blk);
|
||||
rb_ary_push(rws->new_segment_callback_container->callbacks, blk);
|
||||
return Qnil;
|
||||
}
|
||||
|
||||
@ -1201,10 +1091,10 @@ ruby_whisper_params_on_new_segment(VALUE self)
|
||||
static VALUE
|
||||
ruby_whisper_params_on_progress(VALUE self)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
ruby_whisper_params *rws;
|
||||
Data_Get_Struct(self, ruby_whisper_params, rws);
|
||||
const VALUE blk = rb_block_proc();
|
||||
rb_ary_push(rwp->progress_callback_container->callbacks, blk);
|
||||
rb_ary_push(rws->progress_callback_container->callbacks, blk);
|
||||
return Qnil;
|
||||
}
|
||||
|
||||
@ -1221,10 +1111,10 @@ ruby_whisper_params_on_progress(VALUE self)
|
||||
static VALUE
|
||||
ruby_whisper_params_on_encoder_begin(VALUE self)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
ruby_whisper_params *rws;
|
||||
Data_Get_Struct(self, ruby_whisper_params, rws);
|
||||
const VALUE blk = rb_block_proc();
|
||||
rb_ary_push(rwp->encoder_begin_callback_container->callbacks, blk);
|
||||
rb_ary_push(rws->encoder_begin_callback_container->callbacks, blk);
|
||||
return Qnil;
|
||||
}
|
||||
|
||||
@ -1245,10 +1135,10 @@ ruby_whisper_params_on_encoder_begin(VALUE self)
|
||||
static VALUE
|
||||
ruby_whisper_params_abort_on(VALUE self)
|
||||
{
|
||||
ruby_whisper_params *rwp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
ruby_whisper_params *rws;
|
||||
Data_Get_Struct(self, ruby_whisper_params, rws);
|
||||
const VALUE blk = rb_block_proc();
|
||||
rb_ary_push(rwp->abort_callback_container->callbacks, blk);
|
||||
rb_ary_push(rws->abort_callback_container->callbacks, blk);
|
||||
return Qnil;
|
||||
}
|
||||
|
||||
@ -1292,9 +1182,6 @@ init_ruby_whisper_params(VALUE *mWhisper)
|
||||
DEFINE_PARAM(encoder_begin_callback_user_data, 29)
|
||||
DEFINE_PARAM(abort_callback, 30)
|
||||
DEFINE_PARAM(abort_callback_user_data, 31)
|
||||
DEFINE_PARAM(vad, 32)
|
||||
DEFINE_PARAM(vad_model_path, 33)
|
||||
DEFINE_PARAM(vad_params, 34)
|
||||
|
||||
rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0);
|
||||
rb_define_method(cParams, "on_progress", ruby_whisper_params_on_progress, 0);
|
||||
|
@ -1,57 +1,28 @@
|
||||
#include <ruby.h>
|
||||
#include "ruby_whisper.h"
|
||||
|
||||
#define N_KEY_NAMES 5
|
||||
|
||||
static VALUE sym_start_time;
|
||||
static VALUE sym_end_time;
|
||||
static VALUE sym_text;
|
||||
static VALUE sym_no_speech_prob;
|
||||
static VALUE sym_speaker_turn_next;
|
||||
static VALUE key_names;
|
||||
|
||||
extern const rb_data_type_t ruby_whisper_type;
|
||||
|
||||
extern VALUE cSegment;
|
||||
|
||||
static void
|
||||
rb_whisper_segment_mark(void *p)
|
||||
rb_whisper_segment_mark(ruby_whisper_segment *rws)
|
||||
{
|
||||
ruby_whisper_segment *rws = (ruby_whisper_segment *)p;
|
||||
rb_gc_mark(rws->context);
|
||||
}
|
||||
|
||||
static size_t
|
||||
ruby_whisper_segment_memsize(const void *p)
|
||||
{
|
||||
const ruby_whisper_segment *rws = (const ruby_whisper_segment *)p;
|
||||
size_t size = sizeof(rws);
|
||||
if (!rws) {
|
||||
return 0;
|
||||
}
|
||||
return size;
|
||||
}
|
||||
|
||||
static const rb_data_type_t ruby_whisper_segment_type = {
|
||||
"ruby_whisper_segment",
|
||||
{rb_whisper_segment_mark, RUBY_DEFAULT_FREE, ruby_whisper_segment_memsize,},
|
||||
0, 0,
|
||||
0
|
||||
};
|
||||
|
||||
VALUE
|
||||
ruby_whisper_segment_allocate(VALUE klass)
|
||||
{
|
||||
ruby_whisper_segment *rws;
|
||||
return TypedData_Make_Struct(klass, ruby_whisper_segment, &ruby_whisper_segment_type, rws);
|
||||
rws = ALLOC(ruby_whisper_segment);
|
||||
return Data_Wrap_Struct(klass, rb_whisper_segment_mark, RUBY_DEFAULT_FREE, rws);
|
||||
}
|
||||
|
||||
VALUE
|
||||
rb_whisper_segment_s_new(VALUE context, int index)
|
||||
rb_whisper_segment_initialize(VALUE context, int index)
|
||||
{
|
||||
ruby_whisper_segment *rws;
|
||||
const VALUE segment = ruby_whisper_segment_allocate(cSegment);
|
||||
TypedData_Get_Struct(segment, ruby_whisper_segment, &ruby_whisper_segment_type, rws);
|
||||
Data_Get_Struct(segment, ruby_whisper_segment, rws);
|
||||
rws->context = context;
|
||||
rws->index = index;
|
||||
return segment;
|
||||
@ -67,12 +38,12 @@ static VALUE
|
||||
ruby_whisper_segment_get_start_time(VALUE self)
|
||||
{
|
||||
ruby_whisper_segment *rws;
|
||||
TypedData_Get_Struct(self, ruby_whisper_segment, &ruby_whisper_segment_type, rws);
|
||||
Data_Get_Struct(self, ruby_whisper_segment, rws);
|
||||
ruby_whisper *rw;
|
||||
TypedData_Get_Struct(rws->context, ruby_whisper, &ruby_whisper_type, rw);
|
||||
Data_Get_Struct(rws->context, ruby_whisper, rw);
|
||||
const int64_t t0 = whisper_full_get_segment_t0(rw->context, rws->index);
|
||||
// able to multiply 10 without overflow because to_timestamp() in whisper.cpp does it
|
||||
return LONG2NUM(t0 * 10);
|
||||
return INT2NUM(t0 * 10);
|
||||
}
|
||||
|
||||
/*
|
||||
@ -85,12 +56,12 @@ static VALUE
|
||||
ruby_whisper_segment_get_end_time(VALUE self)
|
||||
{
|
||||
ruby_whisper_segment *rws;
|
||||
TypedData_Get_Struct(self, ruby_whisper_segment, &ruby_whisper_segment_type, rws);
|
||||
Data_Get_Struct(self, ruby_whisper_segment, rws);
|
||||
ruby_whisper *rw;
|
||||
TypedData_Get_Struct(rws->context, ruby_whisper, &ruby_whisper_type, rw);
|
||||
Data_Get_Struct(rws->context, ruby_whisper, rw);
|
||||
const int64_t t1 = whisper_full_get_segment_t1(rw->context, rws->index);
|
||||
// able to multiply 10 without overflow because to_timestamp() in whisper.cpp does it
|
||||
return LONG2NUM(t1 * 10);
|
||||
return INT2NUM(t1 * 10);
|
||||
}
|
||||
|
||||
/*
|
||||
@ -103,9 +74,9 @@ static VALUE
|
||||
ruby_whisper_segment_get_speaker_turn_next(VALUE self)
|
||||
{
|
||||
ruby_whisper_segment *rws;
|
||||
TypedData_Get_Struct(self, ruby_whisper_segment, &ruby_whisper_segment_type, rws);
|
||||
Data_Get_Struct(self, ruby_whisper_segment, rws);
|
||||
ruby_whisper *rw;
|
||||
TypedData_Get_Struct(rws->context, ruby_whisper, &ruby_whisper_type, rw);
|
||||
Data_Get_Struct(rws->context, ruby_whisper, rw);
|
||||
return whisper_full_get_segment_speaker_turn_next(rw->context, rws->index) ? Qtrue : Qfalse;
|
||||
}
|
||||
|
||||
@ -117,9 +88,9 @@ static VALUE
|
||||
ruby_whisper_segment_get_text(VALUE self)
|
||||
{
|
||||
ruby_whisper_segment *rws;
|
||||
TypedData_Get_Struct(self, ruby_whisper_segment, &ruby_whisper_segment_type, rws);
|
||||
Data_Get_Struct(self, ruby_whisper_segment, rws);
|
||||
ruby_whisper *rw;
|
||||
TypedData_Get_Struct(rws->context, ruby_whisper, &ruby_whisper_type, rw);
|
||||
Data_Get_Struct(rws->context, ruby_whisper, rw);
|
||||
const char * text = whisper_full_get_segment_text(rw->context, rws->index);
|
||||
return rb_str_new2(text);
|
||||
}
|
||||
@ -132,89 +103,21 @@ static VALUE
|
||||
ruby_whisper_segment_get_no_speech_prob(VALUE self)
|
||||
{
|
||||
ruby_whisper_segment *rws;
|
||||
TypedData_Get_Struct(self, ruby_whisper_segment, &ruby_whisper_segment_type, rws);
|
||||
Data_Get_Struct(self, ruby_whisper_segment, rws);
|
||||
ruby_whisper *rw;
|
||||
TypedData_Get_Struct(rws->context, ruby_whisper, &ruby_whisper_type, rw);
|
||||
Data_Get_Struct(rws->context, ruby_whisper, rw);
|
||||
return DBL2NUM(whisper_full_get_segment_no_speech_prob(rw->context, rws->index));
|
||||
}
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* deconstruct_keys(keys) -> hash
|
||||
*
|
||||
* Possible keys: :start_time, :end_time, :text, :no_speech_prob, :speaker_turn_next
|
||||
*
|
||||
* whisper.each_segment do |segment|
|
||||
* segment => {start_time:, end_time:, text:, no_speech_prob:, speaker_turn_next:}
|
||||
*
|
||||
* puts "[#{start_time} --> #{end_time}] #{text} (no speech prob: #{no_speech_prob}#{speaker_turn_next ? ', speaker turns next' : ''})"
|
||||
* end
|
||||
*/
|
||||
static VALUE
|
||||
ruby_whisper_segment_deconstruct_keys(VALUE self, VALUE keys)
|
||||
{
|
||||
ruby_whisper_segment *rws;
|
||||
TypedData_Get_Struct(self, ruby_whisper_segment, &ruby_whisper_segment_type, rws);
|
||||
ruby_whisper *rw;
|
||||
TypedData_Get_Struct(rws->context, ruby_whisper, &ruby_whisper_type, rw);
|
||||
|
||||
VALUE hash = rb_hash_new();
|
||||
long n_keys;
|
||||
if (NIL_P(keys)) {
|
||||
keys = key_names;
|
||||
n_keys = N_KEY_NAMES;
|
||||
} else {
|
||||
n_keys = RARRAY_LEN(keys);
|
||||
if (n_keys > N_KEY_NAMES) {
|
||||
return hash;
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < n_keys; i++) {
|
||||
VALUE key = rb_ary_entry(keys, i);
|
||||
if (key == sym_start_time) {
|
||||
rb_hash_aset(hash, key, ruby_whisper_segment_get_start_time(self));
|
||||
}
|
||||
if (key == sym_end_time) {
|
||||
rb_hash_aset(hash, key, ruby_whisper_segment_get_end_time(self));
|
||||
}
|
||||
if (key == sym_text) {
|
||||
rb_hash_aset(hash, key, ruby_whisper_segment_get_text(self));
|
||||
}
|
||||
if (key == sym_no_speech_prob) {
|
||||
rb_hash_aset(hash, key, ruby_whisper_segment_get_no_speech_prob(self));
|
||||
}
|
||||
if (key == sym_speaker_turn_next) {
|
||||
rb_hash_aset(hash, key, ruby_whisper_segment_get_speaker_turn_next(self));
|
||||
}
|
||||
}
|
||||
|
||||
return hash;
|
||||
}
|
||||
|
||||
void
|
||||
init_ruby_whisper_segment(VALUE *mWhisper, VALUE *cContext)
|
||||
{
|
||||
cSegment = rb_define_class_under(*mWhisper, "Segment", rb_cObject);
|
||||
|
||||
sym_start_time = ID2SYM(rb_intern("start_time"));
|
||||
sym_end_time = ID2SYM(rb_intern("end_time"));
|
||||
sym_text = ID2SYM(rb_intern("text"));
|
||||
sym_no_speech_prob = ID2SYM(rb_intern("no_speech_prob"));
|
||||
sym_speaker_turn_next = ID2SYM(rb_intern("speaker_turn_next"));
|
||||
key_names = rb_ary_new3(
|
||||
N_KEY_NAMES,
|
||||
sym_start_time,
|
||||
sym_end_time,
|
||||
sym_text,
|
||||
sym_no_speech_prob,
|
||||
sym_speaker_turn_next
|
||||
);
|
||||
|
||||
rb_define_alloc_func(cSegment, ruby_whisper_segment_allocate);
|
||||
rb_define_method(cSegment, "start_time", ruby_whisper_segment_get_start_time, 0);
|
||||
rb_define_method(cSegment, "end_time", ruby_whisper_segment_get_end_time, 0);
|
||||
rb_define_method(cSegment, "speaker_turn_next?", ruby_whisper_segment_get_speaker_turn_next, 0);
|
||||
rb_define_method(cSegment, "speaker_next_turn?", ruby_whisper_segment_get_speaker_turn_next, 0);
|
||||
rb_define_method(cSegment, "text", ruby_whisper_segment_get_text, 0);
|
||||
rb_define_method(cSegment, "no_speech_prob", ruby_whisper_segment_get_no_speech_prob, 0);
|
||||
rb_define_method(cSegment, "deconstruct_keys", ruby_whisper_segment_deconstruct_keys, 1);
|
||||
}
|
||||
|
@ -8,15 +8,11 @@
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
extern const rb_data_type_t ruby_whisper_type;
|
||||
extern const rb_data_type_t ruby_whisper_params_type;
|
||||
|
||||
extern ID id_to_s;
|
||||
extern ID id_call;
|
||||
extern ID transcribe_option_names[1];
|
||||
|
||||
extern void
|
||||
prepare_transcription(ruby_whisper_params * rwp, VALUE * self);
|
||||
register_callbacks(ruby_whisper_params * rwp, VALUE * self);
|
||||
|
||||
/*
|
||||
* transcribe a single file
|
||||
@ -35,16 +31,11 @@ VALUE
|
||||
ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
|
||||
ruby_whisper *rw;
|
||||
ruby_whisper_params *rwp;
|
||||
VALUE wave_file_path, blk, params, kws;
|
||||
VALUE opts[1];
|
||||
VALUE wave_file_path, blk, params;
|
||||
|
||||
rb_scan_args_kw(RB_SCAN_ARGS_LAST_HASH_KEYWORDS, argc, argv, "2:&", &wave_file_path, ¶ms, &kws, &blk);
|
||||
rb_get_kwargs(kws, transcribe_option_names, 0, 1, opts);
|
||||
|
||||
int n_processors = opts[0] == Qundef ? 1 : NUM2INT(opts[0]);
|
||||
|
||||
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
|
||||
TypedData_Get_Struct(params, ruby_whisper_params, &ruby_whisper_params_type, rwp);
|
||||
rb_scan_args(argc, argv, "02&", &wave_file_path, ¶ms, &blk);
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
Data_Get_Struct(params, ruby_whisper_params, rwp);
|
||||
|
||||
if (!rb_respond_to(wave_file_path, id_to_s)) {
|
||||
rb_raise(rb_eRuntimeError, "Expected file path to wave file");
|
||||
@ -70,22 +61,22 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
|
||||
// rwp->params.encoder_begin_callback_user_data = &is_aborted;
|
||||
// }
|
||||
|
||||
prepare_transcription(rwp, &self);
|
||||
register_callbacks(rwp, &self);
|
||||
|
||||
if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), n_processors) != 0) {
|
||||
if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) {
|
||||
fprintf(stderr, "failed to process audio\n");
|
||||
return self;
|
||||
}
|
||||
if (NIL_P(blk)) {
|
||||
return self;
|
||||
}
|
||||
const int n_segments = whisper_full_n_segments(rw->context);
|
||||
VALUE output = rb_str_new2("");
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
const char * text = whisper_full_get_segment_text(rw->context, i);
|
||||
output = rb_str_concat(output, rb_str_new2(text));
|
||||
}
|
||||
rb_funcall(blk, id_call, 1, output);
|
||||
VALUE idCall = id_call;
|
||||
if (blk != Qnil) {
|
||||
rb_funcall(blk, idCall, 1, output);
|
||||
}
|
||||
return self;
|
||||
}
|
||||
#ifdef __cplusplus
|
||||
|
@ -1,288 +0,0 @@
|
||||
#include <ruby.h>
|
||||
#include "ruby_whisper.h"
|
||||
|
||||
#define DEFINE_PARAM(param_name, nth) \
|
||||
id_ ## param_name = rb_intern(#param_name); \
|
||||
param_names[nth] = id_ ## param_name; \
|
||||
rb_define_method(cVADParams, #param_name, ruby_whisper_vad_params_get_ ## param_name, 0); \
|
||||
rb_define_method(cVADParams, #param_name "=", ruby_whisper_vad_params_set_ ## param_name, 1);
|
||||
|
||||
#define NUM_PARAMS 6
|
||||
|
||||
extern VALUE cVADParams;
|
||||
|
||||
static size_t
|
||||
ruby_whisper_vad_params_memsize(const void *p)
|
||||
{
|
||||
const struct ruby_whisper_vad_params *params = p;
|
||||
size_t size = sizeof(params);
|
||||
if (!params) {
|
||||
return 0;
|
||||
}
|
||||
return size;
|
||||
}
|
||||
|
||||
static ID param_names[NUM_PARAMS];
|
||||
static ID id_threshold;
|
||||
static ID id_min_speech_duration_ms;
|
||||
static ID id_min_silence_duration_ms;
|
||||
static ID id_max_speech_duration_s;
|
||||
static ID id_speech_pad_ms;
|
||||
static ID id_samples_overlap;
|
||||
|
||||
const rb_data_type_t ruby_whisper_vad_params_type = {
|
||||
"ruby_whisper_vad_params",
|
||||
{0, 0, ruby_whisper_vad_params_memsize,},
|
||||
0, 0,
|
||||
0
|
||||
};
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_vad_params_s_allocate(VALUE klass)
|
||||
{
|
||||
ruby_whisper_vad_params *rwvp;
|
||||
VALUE obj = TypedData_Make_Struct(klass, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
|
||||
rwvp->params = whisper_vad_default_params();
|
||||
return obj;
|
||||
}
|
||||
|
||||
/*
|
||||
* Probability threshold to consider as speech.
|
||||
*
|
||||
* call-seq:
|
||||
* threshold = th -> th
|
||||
*/
|
||||
static VALUE
|
||||
ruby_whisper_vad_params_set_threshold(VALUE self, VALUE value)
|
||||
{
|
||||
ruby_whisper_vad_params *rwvp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
|
||||
rwvp->params.threshold = RFLOAT_VALUE(value);
|
||||
return value;
|
||||
}
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_vad_params_get_threshold(VALUE self)
|
||||
{
|
||||
ruby_whisper_vad_params *rwvp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
|
||||
return DBL2NUM(rwvp->params.threshold);
|
||||
}
|
||||
|
||||
/*
|
||||
* Min duration for a valid speech segment.
|
||||
*
|
||||
* call-seq:
|
||||
* min_speech_duration_ms = duration_ms -> duration_ms
|
||||
*/
|
||||
static VALUE
|
||||
ruby_whisper_vad_params_set_min_speech_duration_ms(VALUE self, VALUE value)
|
||||
{
|
||||
ruby_whisper_vad_params *rwvp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
|
||||
rwvp->params.min_speech_duration_ms = NUM2INT(value);
|
||||
return value;
|
||||
}
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_vad_params_get_min_speech_duration_ms(VALUE self)
|
||||
{
|
||||
ruby_whisper_vad_params *rwvp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
|
||||
return INT2NUM(rwvp->params.min_speech_duration_ms);
|
||||
}
|
||||
|
||||
/*
|
||||
* Min silence duration to consider speech as ended.
|
||||
*
|
||||
* call-seq:
|
||||
* min_silence_duration_ms = duration_ms -> duration_ms
|
||||
*/
|
||||
static VALUE
|
||||
ruby_whisper_vad_params_set_min_silence_duration_ms(VALUE self, VALUE value)
|
||||
{
|
||||
ruby_whisper_vad_params *rwvp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
|
||||
rwvp->params.min_silence_duration_ms = NUM2INT(value);
|
||||
return value;
|
||||
}
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_vad_params_get_min_silence_duration_ms(VALUE self)
|
||||
{
|
||||
ruby_whisper_vad_params *rwvp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
|
||||
return INT2NUM(rwvp->params.min_silence_duration_ms);
|
||||
}
|
||||
|
||||
/*
|
||||
* Max duration of a speech segment before forcing a new segment.
|
||||
*
|
||||
* call-seq:
|
||||
* max_speech_duration_s = duration_s -> duration_s
|
||||
*/
|
||||
static VALUE
|
||||
ruby_whisper_vad_params_set_max_speech_duration_s(VALUE self, VALUE value)
|
||||
{
|
||||
ruby_whisper_vad_params *rwvp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
|
||||
rwvp->params.max_speech_duration_s = RFLOAT_VALUE(value);
|
||||
return value;
|
||||
}
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_vad_params_get_max_speech_duration_s(VALUE self)
|
||||
{
|
||||
ruby_whisper_vad_params *rwvp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
|
||||
return DBL2NUM(rwvp->params.max_speech_duration_s);
|
||||
}
|
||||
|
||||
/*
|
||||
* Padding added before and after speech segments.
|
||||
*
|
||||
* call-seq:
|
||||
* speech_pad_ms = pad_ms -> pad_ms
|
||||
*/
|
||||
static VALUE
|
||||
ruby_whisper_vad_params_set_speech_pad_ms(VALUE self, VALUE value)
|
||||
{
|
||||
ruby_whisper_vad_params *rwvp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
|
||||
rwvp->params.speech_pad_ms = NUM2INT(value);
|
||||
return value;
|
||||
}
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_vad_params_get_speech_pad_ms(VALUE self)
|
||||
{
|
||||
ruby_whisper_vad_params *rwvp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
|
||||
return INT2NUM(rwvp->params.speech_pad_ms);
|
||||
}
|
||||
|
||||
/*
|
||||
* Overlap in seconds when copying audio samples from speech segment.
|
||||
*
|
||||
* call-seq:
|
||||
* samples_overlap = overlap -> overlap
|
||||
*/
|
||||
static VALUE
|
||||
ruby_whisper_vad_params_set_samples_overlap(VALUE self, VALUE value)
|
||||
{
|
||||
ruby_whisper_vad_params *rwvp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
|
||||
rwvp->params.samples_overlap = RFLOAT_VALUE(value);
|
||||
return value;
|
||||
}
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_vad_params_get_samples_overlap(VALUE self)
|
||||
{
|
||||
ruby_whisper_vad_params *rwvp;
|
||||
TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
|
||||
return DBL2NUM(rwvp->params.samples_overlap);
|
||||
}
|
||||
|
||||
static VALUE
|
||||
ruby_whisper_vad_params_equal(VALUE self, VALUE other)
|
||||
{
|
||||
ruby_whisper_vad_params *rwvp1;
|
||||
ruby_whisper_vad_params *rwvp2;
|
||||
|
||||
if (self == other) {
|
||||
return Qtrue;
|
||||
}
|
||||
|
||||
if (!rb_obj_is_kind_of(other, cVADParams)) {
|
||||
return Qfalse;
|
||||
}
|
||||
|
||||
TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp1);
|
||||
TypedData_Get_Struct(other, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp2);
|
||||
|
||||
if (rwvp1->params.threshold != rwvp2->params.threshold) {
|
||||
return Qfalse;
|
||||
}
|
||||
if (rwvp1->params.min_speech_duration_ms != rwvp2->params.min_speech_duration_ms) {
|
||||
return Qfalse;
|
||||
}
|
||||
if (rwvp1->params.min_silence_duration_ms != rwvp2->params.min_silence_duration_ms) {
|
||||
return Qfalse;
|
||||
}
|
||||
if (rwvp1->params.max_speech_duration_s != rwvp2->params.max_speech_duration_s) {
|
||||
return Qfalse;
|
||||
}
|
||||
if (rwvp1->params.speech_pad_ms != rwvp2->params.speech_pad_ms) {
|
||||
return Qfalse;
|
||||
}
|
||||
if (rwvp1->params.samples_overlap != rwvp2->params.samples_overlap) {
|
||||
return Qfalse;
|
||||
}
|
||||
|
||||
return Qtrue;
|
||||
}
|
||||
|
||||
#define SET_PARAM_IF_SAME(param_name) \
|
||||
if (id == id_ ## param_name) { \
|
||||
ruby_whisper_vad_params_set_ ## param_name(self, value); \
|
||||
continue; \
|
||||
}
|
||||
|
||||
VALUE
|
||||
ruby_whisper_vad_params_initialize(int argc, VALUE *argv, VALUE self)
|
||||
{
|
||||
VALUE kw_hash;
|
||||
VALUE values[NUM_PARAMS] = {Qundef};
|
||||
VALUE value;
|
||||
ruby_whisper_vad_params *rwvp;
|
||||
ID id;
|
||||
int i;
|
||||
|
||||
TypedData_Get_Struct(self, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp);
|
||||
|
||||
rb_scan_args_kw(RB_SCAN_ARGS_KEYWORDS, argc, argv, ":", &kw_hash);
|
||||
if (NIL_P(kw_hash)) {
|
||||
return self;
|
||||
}
|
||||
|
||||
rb_get_kwargs(kw_hash, param_names, 0, NUM_PARAMS, values);
|
||||
|
||||
for (i = 0; i < NUM_PARAMS; i++) {
|
||||
id = param_names[i];
|
||||
value = values[i];
|
||||
if (value == Qundef) {
|
||||
continue;
|
||||
}
|
||||
SET_PARAM_IF_SAME(threshold)
|
||||
SET_PARAM_IF_SAME(min_speech_duration_ms)
|
||||
SET_PARAM_IF_SAME(min_silence_duration_ms)
|
||||
SET_PARAM_IF_SAME(max_speech_duration_s)
|
||||
SET_PARAM_IF_SAME(speech_pad_ms)
|
||||
SET_PARAM_IF_SAME(samples_overlap)
|
||||
}
|
||||
|
||||
return self;
|
||||
}
|
||||
|
||||
#undef SET_PARAM_IF_SAME
|
||||
|
||||
void
|
||||
init_ruby_whisper_vad_params(VALUE *mVAD)
|
||||
{
|
||||
cVADParams = rb_define_class_under(*mVAD, "Params", rb_cObject);
|
||||
rb_define_alloc_func(cVADParams, ruby_whisper_vad_params_s_allocate);
|
||||
rb_define_method(cVADParams, "initialize", ruby_whisper_vad_params_initialize, -1);
|
||||
|
||||
DEFINE_PARAM(threshold, 0)
|
||||
DEFINE_PARAM(min_speech_duration_ms, 1)
|
||||
DEFINE_PARAM(min_silence_duration_ms, 2)
|
||||
DEFINE_PARAM(max_speech_duration_s, 3)
|
||||
DEFINE_PARAM(speech_pad_ms, 4)
|
||||
DEFINE_PARAM(samples_overlap, 5)
|
||||
|
||||
rb_define_method(cVADParams, "==", ruby_whisper_vad_params_equal, 1);
|
||||
}
|
||||
|
||||
#undef DEFINE_PARAM
|
||||
#undef NUM_PARAMS
|
@ -1,10 +1,5 @@
|
||||
require "pathname"
|
||||
|
||||
root = Pathname("..")/".."
|
||||
ignored_dirs = %w[
|
||||
.devops
|
||||
.github
|
||||
ci
|
||||
examples/wchess/wchess.wasm
|
||||
examples/whisper.android
|
||||
examples/whisper.android.java
|
||||
@ -14,7 +9,7 @@ ignored_dirs = %w[
|
||||
models
|
||||
samples
|
||||
scripts
|
||||
].collect {|dir| root/dir}
|
||||
]
|
||||
ignored_files = %w[
|
||||
AUTHORS
|
||||
Makefile
|
||||
@ -22,19 +17,18 @@ ignored_files = %w[
|
||||
README_sycl.md
|
||||
.gitignore
|
||||
.gitmodules
|
||||
.dockerignore
|
||||
whisper.nvim
|
||||
twitch.sh
|
||||
yt-wsp.sh
|
||||
close-issue.yml
|
||||
]
|
||||
|
||||
EXTSOURCES =
|
||||
`git ls-files -z #{root}`.split("\x0")
|
||||
.collect {|file| Pathname(file)}
|
||||
.reject {|file|
|
||||
ignored_dirs.any? {|dir| file.descend.any? {|desc| desc == dir}} ||
|
||||
ignored_files.include?(file.basename.to_path) ||
|
||||
(file.descend.to_a[1] != root && file.descend.to_a[1] != Pathname("..")/"javascript")
|
||||
`git ls-files -z ../..`.split("\x0")
|
||||
.select {|file|
|
||||
basename = File.basename(file)
|
||||
|
||||
ignored_dirs.all? {|dir| !file.start_with?("../../#{dir}")} &&
|
||||
!ignored_files.include?(basename) &&
|
||||
(file.start_with?("../..") || file.start_with?("../javascript")) &&
|
||||
(!file.start_with?("../../.github/") || basename == "bindings-ruby.yml")
|
||||
}
|
||||
.collect(&:to_path)
|
||||
|
@ -1,15 +0,0 @@
|
||||
module Whisper
|
||||
class Context
|
||||
def to_srt
|
||||
each_segment.with_index.reduce("") {|srt, (segment, index)|
|
||||
srt << "#{index + 1}\n#{segment.to_srt_cue}\n"
|
||||
}
|
||||
end
|
||||
|
||||
def to_webvtt
|
||||
each_segment.with_index.reduce("WEBVTT\n\n") {|webvtt, (segment, index)|
|
||||
webvtt << "#{index + 1}\n#{segment.to_webvtt_cue}\n"
|
||||
}
|
||||
end
|
||||
end
|
||||
end
|
@ -130,44 +130,6 @@ module Whisper
|
||||
end
|
||||
end
|
||||
|
||||
class ZipURI < URI
|
||||
def cache
|
||||
zip_path = super
|
||||
dest = unzipped_path
|
||||
return if dest.exist? && dest.mtime >= zip_path.mtime
|
||||
escaping dest do
|
||||
system "unzip", "-q", "-d", zip_path.dirname.to_path, zip_path.to_path, exception: true
|
||||
end
|
||||
zip_path
|
||||
end
|
||||
|
||||
def clear_cache
|
||||
super
|
||||
unzipped_path.rmtree if unzipped_path.exist?
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def unzipped_path
|
||||
cache_path.sub_ext("")
|
||||
end
|
||||
|
||||
def escaping(path)
|
||||
escaped = Pathname("#{path}.removing")
|
||||
if path.exist?
|
||||
escaped.rmtree if escaped.exist?
|
||||
path.rename escaped
|
||||
end
|
||||
yield
|
||||
ensure
|
||||
if path.exist?
|
||||
escaped.rmtree if escaped.exist?
|
||||
else
|
||||
escaped.rename path if escaped.exist?
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
@pre_converted_models = %w[
|
||||
tiny
|
||||
tiny.en
|
||||
@ -203,31 +165,8 @@ module Whisper
|
||||
models[name] = URI.new("https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-#{name}.bin")
|
||||
}
|
||||
|
||||
%w[
|
||||
silero-v5.1.2
|
||||
].each do |name|
|
||||
@pre_converted_models[name] = URI.new("https://huggingface.co/ggml-org/whisper-vad/resolve/main/ggml-#{name}.bin")
|
||||
end
|
||||
|
||||
@coreml_compiled_models = %w[
|
||||
tiny
|
||||
tiny.en
|
||||
base
|
||||
base.en
|
||||
small
|
||||
small.en
|
||||
medium
|
||||
medium.en
|
||||
large-v1
|
||||
large-v2
|
||||
large-v3
|
||||
large-v3-turbo
|
||||
].each_with_object({}) do |name, models|
|
||||
models[@pre_converted_models[name]] = ZipURI.new("https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-#{name}-encoder.mlmodelc.zip")
|
||||
end
|
||||
|
||||
class << self
|
||||
attr_reader :pre_converted_models, :coreml_compiled_models
|
||||
attr_reader :pre_converted_models
|
||||
end
|
||||
end
|
||||
end
|
||||
|
@ -1,58 +0,0 @@
|
||||
module Whisper
|
||||
class Segment
|
||||
SRT_ESCAPES = {
|
||||
"&" => "&",
|
||||
"<" => "<",
|
||||
">" => ">",
|
||||
}
|
||||
SRT_ESCAPES_RE = Regexp.union(SRT_ESCAPES.keys)
|
||||
private_constant :SRT_ESCAPES, :SRT_ESCAPES_RE
|
||||
|
||||
def to_srt_cue
|
||||
"#{srt_start_time} --> #{srt_end_time}\n#{srt_text}\n"
|
||||
end
|
||||
|
||||
def to_webvtt_cue
|
||||
"#{webvtt_start_time} --> #{webvtt_end_time}\n#{webvtt_text}\n"
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def time_to_a(time)
|
||||
sec, decimal_part = time.divmod(1000)
|
||||
min, sec = sec.divmod(60)
|
||||
hour, min = min.divmod(60)
|
||||
[hour, min, sec, decimal_part]
|
||||
end
|
||||
|
||||
def srt_time(time)
|
||||
"%02d:%02d:%02d,%03d" % time_to_a(time)
|
||||
end
|
||||
|
||||
def srt_start_time
|
||||
srt_time(start_time)
|
||||
end
|
||||
|
||||
def srt_end_time
|
||||
srt_time(end_time)
|
||||
end
|
||||
|
||||
def srt_text
|
||||
text.gsub(SRT_ESCAPES_RE, SRT_ESCAPES)
|
||||
end
|
||||
|
||||
def webvtt_time(time)
|
||||
"%02d:%02d:%02d.%03d" % time_to_a(time)
|
||||
end
|
||||
|
||||
def webvtt_start_time
|
||||
webvtt_time(start_time)
|
||||
end
|
||||
|
||||
def webvtt_end_time
|
||||
webvtt_time(end_time)
|
||||
end
|
||||
|
||||
alias webvtt_text srt_text
|
||||
end
|
||||
end
|
@ -10,7 +10,6 @@ module Whisper
|
||||
type encoder_begin_callback = ^(Whisper::Context, void, Object user_data) -> void
|
||||
type abort_callback = ^(Whisper::Context, void, Object user_data) -> boolish
|
||||
|
||||
VERSION: String
|
||||
LOG_LEVEL_NONE: Integer
|
||||
LOG_LEVEL_INFO: Integer
|
||||
LOG_LEVEL_WARN: Integer
|
||||
@ -23,22 +22,21 @@ module Whisper
|
||||
def self.lang_str: (Integer id) -> String
|
||||
def self.lang_str_full: (Integer id) -> String
|
||||
def self.log_set: (log_callback, Object? user_data) -> log_callback
|
||||
def self.system_info_str: () -> String
|
||||
|
||||
class Context
|
||||
def self.new: (String | path | ::URI::HTTP) -> instance
|
||||
def self.new: (path | ::URI::HTTP) -> instance
|
||||
|
||||
# transcribe a single file
|
||||
# can emit to a block results
|
||||
#
|
||||
# params = Whisper::Params.new
|
||||
# params.duration = 60_000
|
||||
# whisper.transcribe "path/to/audio.wav", params do |text|
|
||||
# puts text
|
||||
# end
|
||||
# params = Whisper::Params.new
|
||||
# params.duration = 60_000
|
||||
# whisper.transcribe "path/to/audio.wav", params do |text|
|
||||
# puts text
|
||||
# end
|
||||
#
|
||||
def transcribe: (string, Params, ?n_processors: Integer) -> self
|
||||
| (string, Params, ?n_processors: Integer) { (String) -> void } -> self
|
||||
def transcribe: (string, Params) -> self
|
||||
| (string, Params) { (String) -> void } -> self
|
||||
|
||||
def model_n_vocab: () -> Integer
|
||||
def model_n_audio_ctx: () -> Integer
|
||||
@ -51,16 +49,16 @@ module Whisper
|
||||
|
||||
# Yields each Whisper::Segment:
|
||||
#
|
||||
# whisper.transcribe("path/to/audio.wav", params)
|
||||
# whisper.each_segment do |segment|
|
||||
# puts segment.text
|
||||
# end
|
||||
# whisper.transcribe("path/to/audio.wav", params)
|
||||
# whisper.each_segment do |segment|
|
||||
# puts segment.text
|
||||
# end
|
||||
#
|
||||
# Returns an Enumerator if no block given:
|
||||
#
|
||||
# whisper.transcribe("path/to/audio.wav", params)
|
||||
# enum = whisper.each_segment
|
||||
# enum.to_a # => [#<Whisper::Segment>, ...]
|
||||
# whisper.transcribe("path/to/audio.wav", params)
|
||||
# enum = whisper.each_segment
|
||||
# enum.to_a # => [#<Whisper::Segment>, ...]
|
||||
#
|
||||
def each_segment: { (Segment) -> void } -> void
|
||||
| () -> Enumerator[Segment]
|
||||
@ -75,25 +73,25 @@ module Whisper
|
||||
|
||||
# Start time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds).
|
||||
#
|
||||
# full_get_segment_t0(3) # => 1668 (16680 ms)
|
||||
# full_get_segment_t0(3) # => 1668 (16680 ms)
|
||||
#
|
||||
def full_get_segment_t0: (Integer) -> Integer
|
||||
|
||||
# End time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds).
|
||||
#
|
||||
# full_get_segment_t1(3) # => 1668 (16680 ms)
|
||||
# full_get_segment_t1(3) # => 1668 (16680 ms)
|
||||
#
|
||||
def full_get_segment_t1: (Integer) -> Integer
|
||||
|
||||
# Whether the next segment indexed by +segment_index+ is predicated as a speaker turn.
|
||||
#
|
||||
# full_get_segment_speacker_turn_next(3) # => true
|
||||
# full_get_segment_speacker_turn_next(3) # => true
|
||||
#
|
||||
def full_get_segment_speaker_turn_next: (Integer) -> (true | false)
|
||||
|
||||
# Text of a segment indexed by +segment_index+.
|
||||
#
|
||||
# full_get_segment_text(3) # => "ask not what your country can do for you, ..."
|
||||
# full_get_segment_text(3) # => "ask not what your country can do for you, ..."
|
||||
#
|
||||
def full_get_segment_text: (Integer) -> String
|
||||
|
||||
@ -117,9 +115,6 @@ module Whisper
|
||||
def full_parallel: (Params, Array[Float], ?Integer n_samples) -> self
|
||||
| (Params, _Samples, ?Integer n_samples) -> self
|
||||
| (Params, _Samples, ?Integer? n_samples, Integer n_processors) -> self
|
||||
|
||||
def to_srt: () -> String
|
||||
def to_webvtt: () -> String
|
||||
end
|
||||
|
||||
class Params
|
||||
@ -155,10 +150,7 @@ module Whisper
|
||||
?encoder_begin_callback: encoder_begin_callback,
|
||||
?encoder_begin_callback_user_data: Object,
|
||||
?abort_callback: abort_callback,
|
||||
?abort_callback_user_data: Object,
|
||||
?vad: boolish,
|
||||
?vad_model_path: path | URI,
|
||||
?vad_params: Whisper::VAD::Params
|
||||
?abort_callback_user_data: Object
|
||||
) -> instance
|
||||
|
||||
# params.language = "auto" | "en", etc...
|
||||
@ -286,9 +278,9 @@ module Whisper
|
||||
|
||||
# Sets new segment callback, called for every newly generated text segment.
|
||||
#
|
||||
# params.new_segment_callback = ->(context, _, n_new, user_data) {
|
||||
# # ...
|
||||
# }
|
||||
# params.new_segment_callback = ->(context, _, n_new, user_data) {
|
||||
# # ...
|
||||
# }
|
||||
#
|
||||
def new_segment_callback=: (new_segment_callback) -> new_segment_callback
|
||||
def new_segment_callback: () -> (new_segment_callback | nil)
|
||||
@ -301,9 +293,9 @@ module Whisper
|
||||
|
||||
# Sets progress callback, called on each progress update.
|
||||
#
|
||||
# params.new_segment_callback = ->(context, _, progress, user_data) {
|
||||
# # ...
|
||||
# }
|
||||
# params.new_segment_callback = ->(context, _, progress, user_data) {
|
||||
# # ...
|
||||
# }
|
||||
#
|
||||
# +progress+ is an Integer between 0 and 100.
|
||||
#
|
||||
@ -331,9 +323,9 @@ module Whisper
|
||||
|
||||
# Sets abort callback, called to check if the process should be aborted.
|
||||
#
|
||||
# params.abort_callback = ->(user_data) {
|
||||
# # ...
|
||||
# }
|
||||
# params.abort_callback = ->(user_data) {
|
||||
# # ...
|
||||
# }
|
||||
#
|
||||
#
|
||||
def abort_callback=: (abort_callback) -> abort_callback
|
||||
@ -346,25 +338,11 @@ module Whisper
|
||||
|
||||
def abort_callback_user_data: () -> Object
|
||||
|
||||
# Enable VAD
|
||||
#
|
||||
def vad=: (boolish) -> boolish
|
||||
|
||||
def vad: () -> (true | false)
|
||||
|
||||
# Path to the VAD model
|
||||
def vad_model_path=: (path | URI | nil) -> (path | URI | nil)
|
||||
|
||||
def vad_model_path: () -> (String | nil)
|
||||
|
||||
def vad_params=: (Whisper::VAD::Params) -> Whisper::VAD::Params
|
||||
def vad_params: () -> (Whisper::VAD::Params)
|
||||
|
||||
# Hook called on new segment. Yields each Whisper::Segment.
|
||||
#
|
||||
# whisper.on_new_segment do |segment|
|
||||
# # ...
|
||||
# end
|
||||
# whisper.on_new_segment do |segment|
|
||||
# # ...
|
||||
# end
|
||||
#
|
||||
def on_new_segment: { (Segment) -> void } -> void
|
||||
|
||||
@ -378,20 +356,19 @@ module Whisper
|
||||
|
||||
# Call block to determine whether abort or not. Return +true+ when you want to abort.
|
||||
#
|
||||
# params.abort_on do
|
||||
# if some_condition
|
||||
# true # abort
|
||||
# else
|
||||
# false # continue
|
||||
# end
|
||||
# params.abort_on do
|
||||
# if some_condition
|
||||
# true # abort
|
||||
# else
|
||||
# false # continue
|
||||
# end
|
||||
# end
|
||||
#
|
||||
def abort_on: { (Object user_data) -> boolish } -> void
|
||||
end
|
||||
|
||||
class Model
|
||||
def self.pre_converted_models: () -> Hash[String, Model::URI]
|
||||
def self.coreml_compiled_models: () -> Hash[Model::URI, Model::ZipURI]
|
||||
def self.new: () -> instance
|
||||
def n_vocab: () -> Integer
|
||||
def n_audio_ctx: () -> Integer
|
||||
@ -411,22 +388,9 @@ module Whisper
|
||||
def to_path: -> String
|
||||
def clear_cache: -> void
|
||||
end
|
||||
|
||||
class ZipURI < URI
|
||||
def cache: () -> Pathname
|
||||
def clear_cache: () -> void
|
||||
end
|
||||
end
|
||||
|
||||
class Segment
|
||||
type deconstructed_keys = {
|
||||
start_time: (Integer | nil),
|
||||
end_time: (Integer | nil),
|
||||
text: (String | nil),
|
||||
no_speech_prob: (Float | nil),
|
||||
speaker_turn_next: (true | false | nil)
|
||||
}
|
||||
|
||||
# Start time in milliseconds.
|
||||
#
|
||||
def start_time: () -> Integer
|
||||
@ -436,70 +400,10 @@ module Whisper
|
||||
def end_time: () -> Integer
|
||||
|
||||
# Whether the next segment is predicted as a speaker turn.
|
||||
def speaker_turn_next?: () -> (true | false)
|
||||
def speaker_next_turn?: () -> (true | false)
|
||||
|
||||
def text: () -> String
|
||||
def no_speech_prob: () -> Float
|
||||
def to_srt_cue: () -> String
|
||||
def to_webvtt_cue: () -> String
|
||||
|
||||
# Possible keys: :start_time, :end_time, :text, :no_speech_prob, :speaker_turn_next
|
||||
#
|
||||
# whisper.each_segment do |segment|
|
||||
# segment => {start_time:, end_time:, text:, no_speech_prob:, speaker_turn_next:}
|
||||
#
|
||||
# puts "[#{start_time} --> #{end_time}] #{text} (no speech prob: #{no_speech_prob}#{speaker_turn_next ? ', speaker turns next' : ''})"
|
||||
# end
|
||||
def deconstruct_keys: (Array[:start_time | :end_time | :text | :no_speech_prob | :speaker_turn_next] | nil) -> deconstructed_keys
|
||||
end
|
||||
|
||||
module VAD
|
||||
class Params
|
||||
def self.new: (
|
||||
?threshold: Float,
|
||||
?min_speech_duration_ms: Integer,
|
||||
?min_silence_duration_ms: Integer,
|
||||
?max_speech_duration_s: Float,
|
||||
?speech_pad_ms: Integer,
|
||||
?samples_overlap: Float
|
||||
) -> instance
|
||||
|
||||
# Probability threshold to consider as speech.
|
||||
#
|
||||
def threshold=: (Float) -> Float
|
||||
|
||||
def threshold: () -> Float
|
||||
|
||||
# Min duration for a valid speech segment.
|
||||
#
|
||||
def min_speech_duration_ms=: (Integer) -> Integer
|
||||
|
||||
def min_speech_duration_ms: () -> Integer
|
||||
|
||||
# Min silence duration to consider speech as ended.
|
||||
#
|
||||
def min_silence_duration_ms=: (Integer) -> Integer
|
||||
|
||||
def min_silence_duration_ms: () -> Integer
|
||||
|
||||
# Max duration of a speech segment before forcing a new segment.
|
||||
def max_speech_duration_s=: (Float) -> Float
|
||||
|
||||
def max_speech_duration_s: () -> Float
|
||||
|
||||
# Padding added before and after speech segments.
|
||||
#
|
||||
def speech_pad_ms=: (Integer) -> Integer
|
||||
|
||||
def speech_pad_ms: () -> Integer
|
||||
|
||||
# Overlap in seconds when copying audio samples from speech segment.
|
||||
#
|
||||
def samples_overlap=: (Float) -> Float
|
||||
|
||||
def samples_overlap: () -> Float
|
||||
def ==: (Params) -> (true | false)
|
||||
end
|
||||
end
|
||||
|
||||
class Error < StandardError
|
||||
|
@ -1,51 +0,0 @@
|
||||
require_relative "helper"
|
||||
require 'tempfile'
|
||||
require 'tmpdir'
|
||||
require 'shellwords'
|
||||
|
||||
class TestPackage < TestBase
|
||||
def test_build
|
||||
Tempfile.create do |file|
|
||||
assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path.shellescape, exception: true)
|
||||
assert file.size > 0
|
||||
assert_path_exist file.to_path
|
||||
end
|
||||
end
|
||||
|
||||
sub_test_case "Building binary on installation" do
|
||||
def setup
|
||||
system "rake", "build", exception: true
|
||||
end
|
||||
|
||||
def test_install
|
||||
gemspec = Gem::Specification.load("whispercpp.gemspec")
|
||||
Dir.mktmpdir do |dir|
|
||||
system "gem", "install", "--install-dir", dir.shellescape, "--no-document", "pkg/#{gemspec.file_name.shellescape}", exception: true
|
||||
assert_installed dir, gemspec.version
|
||||
end
|
||||
end
|
||||
|
||||
def test_install_with_coreml
|
||||
omit_unless RUBY_PLATFORM.match?(/darwin/) do
|
||||
gemspec = Gem::Specification.load("whispercpp.gemspec")
|
||||
Dir.mktmpdir do |dir|
|
||||
system "gem", "install", "--install-dir", dir.shellescape, "--no-document", "pkg/#{gemspec.file_name.shellescape}", "--", "--enable-whisper-coreml", exception: true
|
||||
assert_installed dir, gemspec.version
|
||||
libdir = File.join(dir, "gems", "#{gemspec.name}-#{gemspec.version}", "lib")
|
||||
assert_nothing_raised do
|
||||
system "ruby", "-I", libdir, "-r", "whisper", "-e", "Whisper::Context.new('tiny')", exception: true
|
||||
end
|
||||
assert_match(/COREML = 1/, `ruby -I #{libdir.shellescape} -r whisper -e 'puts Whisper.system_info_str'`)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def assert_installed(dir, version)
|
||||
assert_path_exist File.join(dir, "gems/whispercpp-#{version}/lib", "whisper.#{RbConfig::CONFIG["DLEXT"]}")
|
||||
assert_path_exist File.join(dir, "gems/whispercpp-#{version}/LICENSE")
|
||||
assert_path_not_exist File.join(dir, "gems/whispercpp-#{version}/ext/build")
|
||||
end
|
||||
end
|
||||
end
|
@ -1,146 +0,0 @@
|
||||
require_relative "helper"
|
||||
|
||||
class TestSegment < TestBase
|
||||
def test_iteration
|
||||
whisper.each_segment do |segment|
|
||||
assert_instance_of Whisper::Segment, segment
|
||||
end
|
||||
end
|
||||
|
||||
def test_enumerator
|
||||
enum = whisper.each_segment
|
||||
assert_instance_of Enumerator, enum
|
||||
enum.to_a.each_with_index do |segment, index|
|
||||
assert_instance_of Whisper::Segment, segment
|
||||
assert_kind_of Integer, index
|
||||
end
|
||||
end
|
||||
|
||||
def test_start_time
|
||||
i = 0
|
||||
whisper.each_segment do |segment|
|
||||
assert_equal 0, segment.start_time if i == 0
|
||||
i += 1
|
||||
end
|
||||
end
|
||||
|
||||
def test_end_time
|
||||
i = 0
|
||||
whisper.each_segment do |segment|
|
||||
assert_equal whisper.full_get_segment_t1(i) * 10, segment.end_time
|
||||
i += 1
|
||||
end
|
||||
end
|
||||
|
||||
def test_no_speech_prob
|
||||
no_speech_prob = nil
|
||||
whisper.each_segment do |segment|
|
||||
no_speech_prob = segment.no_speech_prob
|
||||
end
|
||||
assert no_speech_prob > 0.0
|
||||
end
|
||||
|
||||
def test_on_new_segment
|
||||
params = Whisper::Params.new
|
||||
seg = nil
|
||||
index = 0
|
||||
params.on_new_segment do |segment|
|
||||
assert_instance_of Whisper::Segment, segment
|
||||
if index == 0
|
||||
seg = segment
|
||||
assert_equal 0, segment.start_time
|
||||
assert_match(/ask not what your country can do for you, ask what you can do for your country/, segment.text)
|
||||
end
|
||||
index += 1
|
||||
end
|
||||
whisper.transcribe(AUDIO, params)
|
||||
assert_equal 0, seg.start_time
|
||||
assert_match(/ask not what your country can do for you, ask what you can do for your country/, seg.text)
|
||||
end
|
||||
|
||||
def test_on_new_segment_twice
|
||||
params = Whisper::Params.new
|
||||
seg = nil
|
||||
params.on_new_segment do |segment|
|
||||
seg = segment
|
||||
return
|
||||
end
|
||||
params.on_new_segment do |segment|
|
||||
assert_same seg, segment
|
||||
return
|
||||
end
|
||||
whisper.transcribe(AUDIO, params)
|
||||
end
|
||||
|
||||
def test_transcription_after_segment_retrieved
|
||||
params = Whisper::Params.new
|
||||
segment = whisper.each_segment.first
|
||||
assert_match(/ask not what your country can do for you, ask what you can do for your country/, segment.text)
|
||||
|
||||
whisper.transcribe(AUDIO, Whisper::Params.new(offset: 5000))
|
||||
assert_not_match(/ask not what your country can do for you, ask what you can do for your country/, segment.text)
|
||||
assert_match(/what you can do for your country/i, segment.text)
|
||||
end
|
||||
|
||||
def test_pattern_matching
|
||||
segment = whisper.each_segment.first
|
||||
segment => {start_time:, end_time:, text:, no_speech_prob:, speaker_turn_next:}
|
||||
|
||||
assert_equal segment.start_time, start_time
|
||||
assert_equal segment.end_time, end_time
|
||||
assert_equal segment.text, text
|
||||
assert_equal segment.no_speech_prob, no_speech_prob
|
||||
assert_equal segment.speaker_turn_next?, speaker_turn_next
|
||||
end
|
||||
|
||||
def test_pattern_matching_partial
|
||||
segment = whisper.each_segment.first
|
||||
segment => {start_time:, end_time:, text:}
|
||||
|
||||
assert_equal segment.start_time, start_time
|
||||
assert_equal segment.end_time, end_time
|
||||
assert_equal segment.text, text
|
||||
end
|
||||
|
||||
def test_deconstruct_keys
|
||||
segment = whisper.each_segment.first
|
||||
expected = {
|
||||
start_time: segment.start_time,
|
||||
end_time: segment.end_time,
|
||||
text: segment.text,
|
||||
no_speech_prob: segment.no_speech_prob,
|
||||
speaker_turn_next: segment.speaker_turn_next?
|
||||
}
|
||||
assert_equal expected, segment.deconstruct_keys([:start_time, :end_time, :text, :no_speech_prob, :speaker_turn_next])
|
||||
end
|
||||
|
||||
def test_deconstruct_keys_non_existent
|
||||
omit "Undefined behavior"
|
||||
|
||||
segment = whisper.each_segment.first
|
||||
|
||||
assert_equal({}, segment.deconstruct_keys([:non_existent]))
|
||||
end
|
||||
|
||||
def test_deconstruct_keys_too_many_keys
|
||||
omit "Undefined behavior"
|
||||
|
||||
segment = whisper.each_segment.first
|
||||
|
||||
assert_equal({}, segment.deconstruct_keys([:start_time, :end_time, :text, :no_speech_prob, :speaker_turn_next, :extra_key]))
|
||||
end
|
||||
|
||||
def test_deconstruct_keys_includes_non_existent_keys_not_too_many
|
||||
omit "Undefined behavior"
|
||||
|
||||
segment = whisper.each_segment.first
|
||||
|
||||
expected = {
|
||||
start_time: segment.start_time,
|
||||
end_time: segment.end_time,
|
||||
text: segment.text,
|
||||
no_speech_prob: segment.no_speech_prob
|
||||
}
|
||||
assert_equal(expected, segment.deconstruct_keys([:start_time, :end_time, :text, :no_speech_prob, :non_existent]))
|
||||
end
|
||||
end
|
@ -1,19 +0,0 @@
|
||||
require_relative "helper"
|
||||
|
||||
class TestVAD < TestBase
|
||||
def setup
|
||||
@whisper = Whisper::Context.new("base.en")
|
||||
vad_params = Whisper::VAD::Params.new
|
||||
@params = Whisper::Params.new(
|
||||
vad: true,
|
||||
vad_model_path: "silero-v5.1.2",
|
||||
vad_params:
|
||||
)
|
||||
end
|
||||
|
||||
def test_transcribe
|
||||
@whisper.transcribe(TestBase::AUDIO, @params) do |text|
|
||||
assert_match(/ask not what your country can do for you[,.] ask what you can do for your country/i, text)
|
||||
end
|
||||
end
|
||||
end
|
@ -1,103 +0,0 @@
|
||||
require_relative "helper"
|
||||
|
||||
class TestVADParams < TestBase
|
||||
PARAM_NAMES = [
|
||||
:threshold,
|
||||
:min_speech_duration_ms,
|
||||
:min_silence_duration_ms,
|
||||
:max_speech_duration_s,
|
||||
:speech_pad_ms,
|
||||
:samples_overlap
|
||||
]
|
||||
|
||||
def setup
|
||||
@params = Whisper::VAD::Params.new
|
||||
end
|
||||
|
||||
def test_new
|
||||
params = Whisper::VAD::Params.new
|
||||
assert_kind_of Whisper::VAD::Params, params
|
||||
end
|
||||
|
||||
def test_threshold
|
||||
assert_in_delta @params.threshold, 0.5
|
||||
@params.threshold = 0.7
|
||||
assert_in_delta @params.threshold, 0.7
|
||||
end
|
||||
|
||||
def test_min_speech_duration
|
||||
pend
|
||||
end
|
||||
|
||||
def test_min_speech_duration_ms
|
||||
assert_equal 250, @params.min_speech_duration_ms
|
||||
@params.min_speech_duration_ms = 500
|
||||
assert_equal 500, @params.min_speech_duration_ms
|
||||
end
|
||||
|
||||
def test_min_silence_duration_ms
|
||||
assert_equal 100, @params.min_silence_duration_ms
|
||||
@params.min_silence_duration_ms = 200
|
||||
assert_equal 200, @params.min_silence_duration_ms
|
||||
end
|
||||
|
||||
def test_max_speech_duration
|
||||
pend
|
||||
end
|
||||
|
||||
def test_max_speech_duration_s
|
||||
assert @params.max_speech_duration_s >= 10e37 # Defaults to FLT_MAX
|
||||
@params.max_speech_duration_s = 60.0
|
||||
assert_equal 60.0, @params.max_speech_duration_s
|
||||
end
|
||||
|
||||
def test_speech_pad_ms
|
||||
assert_equal 30, @params.speech_pad_ms
|
||||
@params.speech_pad_ms = 50
|
||||
assert_equal 50, @params.speech_pad_ms
|
||||
end
|
||||
|
||||
def test_samples_overlap
|
||||
assert_in_delta @params.samples_overlap, 0.1
|
||||
@params.samples_overlap = 0.5
|
||||
assert_in_delta @params.samples_overlap, 0.5
|
||||
end
|
||||
|
||||
def test_equal
|
||||
assert_equal @params, Whisper::VAD::Params.new
|
||||
end
|
||||
|
||||
def test_new_with_kw_args
|
||||
params = Whisper::VAD::Params.new(threshold: 0.7)
|
||||
assert_in_delta params.threshold, 0.7
|
||||
assert_equal 250, params.min_speech_duration_ms
|
||||
end
|
||||
|
||||
def test_new_with_kw_args_non_existent
|
||||
assert_raise ArgumentError do
|
||||
Whisper::VAD::Params.new(non_existent: "value")
|
||||
end
|
||||
end
|
||||
|
||||
data(PARAM_NAMES.collect {|param| [param, param]}.to_h)
|
||||
def test_new_with_kw_args_default_values(param)
|
||||
default_value = @params.send(param)
|
||||
value = default_value + 1
|
||||
params = Whisper::VAD::Params.new(param => value)
|
||||
if Float === value
|
||||
assert_in_delta value, params.send(param)
|
||||
else
|
||||
assert_equal value, params.send(param)
|
||||
end
|
||||
|
||||
PARAM_NAMES.reject {|name| name == param}.each do |name|
|
||||
expected = @params.send(name)
|
||||
actual = params.send(name)
|
||||
if Float === expected
|
||||
assert_in_delta expected, actual
|
||||
else
|
||||
assert_equal expected, actual
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
@ -3,7 +3,7 @@ require "whisper"
|
||||
require_relative "jfk_reader/jfk_reader"
|
||||
|
||||
class TestBase < Test::Unit::TestCase
|
||||
AUDIO = File.join(__dir__, "fixtures", "jfk.wav")
|
||||
AUDIO = File.join(__dir__, "..", "..", "..", "samples", "jfk.wav")
|
||||
|
||||
class << self
|
||||
def whisper
|
||||
@ -21,4 +21,15 @@ class TestBase < Test::Unit::TestCase
|
||||
def whisper
|
||||
self.class.whisper
|
||||
end
|
||||
|
||||
module BuildOptions
|
||||
load "ext/options.rb", self
|
||||
Options.include self
|
||||
|
||||
def enable_config(name)
|
||||
end
|
||||
|
||||
def arg_config(name)
|
||||
end
|
||||
end
|
||||
end
|
@ -106,13 +106,4 @@ class TestModel < TestBase
|
||||
assert_equal 1, model.ftype
|
||||
assert_equal "base", model.type
|
||||
end
|
||||
|
||||
def test_coreml_model_auto_download
|
||||
uri = Whisper::Model.coreml_compiled_models[Whisper::Model.pre_converted_models["tiny"]]
|
||||
model_path = Pathname(uri.to_path).sub_ext("")
|
||||
model_path.rmtree if model_path.exist?
|
||||
|
||||
uri.cache
|
||||
assert_path_exist model_path
|
||||
end
|
||||
end
|
46
bindings/ruby/tests/test_package.rb
Normal file
46
bindings/ruby/tests/test_package.rb
Normal file
@ -0,0 +1,46 @@
|
||||
require_relative "helper"
|
||||
require 'tempfile'
|
||||
require 'tmpdir'
|
||||
require 'shellwords'
|
||||
|
||||
class TestPackage < TestBase
|
||||
def test_build
|
||||
Tempfile.create do |file|
|
||||
assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path.shellescape, exception: true)
|
||||
assert file.size > 0
|
||||
assert_path_exist file.to_path
|
||||
end
|
||||
end
|
||||
|
||||
sub_test_case "Building binary on installation" do
|
||||
def setup
|
||||
system "rake", "build", exception: true
|
||||
end
|
||||
|
||||
def test_install
|
||||
match_data = `rake -Tbuild`.match(/(whispercpp-(.+)\.gem)/)
|
||||
filename = match_data[1]
|
||||
version = match_data[2]
|
||||
Dir.mktmpdir do |dir|
|
||||
system "gem", "install", "--install-dir", dir.shellescape, "--no-document", "pkg/#{filename.shellescape}", exception: true
|
||||
assert_installed dir, version
|
||||
end
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def assert_installed(dir, version)
|
||||
assert_path_exist File.join(dir, "gems/whispercpp-#{version}/lib", "whisper.#{RbConfig::CONFIG["DLEXT"]}")
|
||||
assert_path_exist File.join(dir, "gems/whispercpp-#{version}/LICENSE")
|
||||
assert_path_not_exist File.join(dir, "gems/whispercpp-#{version}/ext/build")
|
||||
end
|
||||
end
|
||||
|
||||
def test_build_options
|
||||
options = BuildOptions::Options.new
|
||||
assert_empty options.missing_options
|
||||
unless ENV["CI"]
|
||||
assert_empty options.extra_options
|
||||
end
|
||||
end
|
||||
end
|
@ -32,9 +32,6 @@ class TestParams < TestBase
|
||||
:progress_callback_user_data,
|
||||
:abort_callback,
|
||||
:abort_callback_user_data,
|
||||
:vad,
|
||||
:vad_model_path,
|
||||
:vad_params,
|
||||
]
|
||||
|
||||
def setup
|
||||
@ -194,50 +191,6 @@ class TestParams < TestBase
|
||||
assert_in_delta 0.2, @params.no_speech_thold
|
||||
end
|
||||
|
||||
def test_vad
|
||||
assert_false @params.vad
|
||||
@params.vad = true
|
||||
assert_true @params.vad
|
||||
end
|
||||
|
||||
def test_vad_model_path
|
||||
assert_nil @params.vad_model_path
|
||||
@params.vad_model_path = "silero-v5.1.2"
|
||||
assert_equal Whisper::Model.pre_converted_models["silero-v5.1.2"].to_path, @params.vad_model_path
|
||||
end
|
||||
|
||||
def test_vad_model_path_with_nil
|
||||
@params.vad_model_path = "silero-v5.1.2"
|
||||
@params.vad_model_path = nil
|
||||
assert_nil @params.vad_model_path
|
||||
end
|
||||
|
||||
def test_vad_model_path_with_invalid
|
||||
assert_raise TypeError do
|
||||
@params.vad_model_path = Object.new
|
||||
end
|
||||
end
|
||||
|
||||
def test_vad_model_path_with_URI_string
|
||||
@params.vad_model_path = "https://huggingface.co/ggml-org/whisper-vad/resolve/main/ggml-silero-v5.1.2.bin"
|
||||
assert_equal @params.vad_model_path, Whisper::Model.pre_converted_models["silero-v5.1.2"].to_path
|
||||
end
|
||||
|
||||
def test_vad_model_path_with_URI
|
||||
@params.vad_model_path = URI("https://huggingface.co/ggml-org/whisper-vad/resolve/main/ggml-silero-v5.1.2.bin")
|
||||
assert_equal @params.vad_model_path, Whisper::Model.pre_converted_models["silero-v5.1.2"].to_path
|
||||
end
|
||||
|
||||
def test_vad_params
|
||||
assert_kind_of Whisper::VAD::Params, @params.vad_params
|
||||
default_params = @params.vad_params
|
||||
assert_same default_params, @params.vad_params
|
||||
assert_equal 0.5, default_params.threshold
|
||||
new_params = Whisper::VAD::Params.new
|
||||
@params.vad_params = new_params
|
||||
assert_same new_params, @params.vad_params
|
||||
end
|
||||
|
||||
def test_new_with_kw_args
|
||||
params = Whisper::Params.new(language: "es")
|
||||
assert_equal "es", params.language
|
||||
@ -272,10 +225,6 @@ class TestParams < TestBase
|
||||
proc {}
|
||||
in [/_user_data\Z/, *]
|
||||
Object.new
|
||||
in [:vad_model_path, *]
|
||||
Whisper::Model.pre_converted_models["silero-v5.1.2"].to_path
|
||||
in [:vad_params, *]
|
||||
Whisper::VAD::Params.new
|
||||
end
|
||||
params = Whisper::Params.new(param => value)
|
||||
if Float === value
|
74
bindings/ruby/tests/test_segment.rb
Normal file
74
bindings/ruby/tests/test_segment.rb
Normal file
@ -0,0 +1,74 @@
|
||||
require_relative "helper"
|
||||
|
||||
class TestSegment < TestBase
|
||||
def test_iteration
|
||||
whisper.each_segment do |segment|
|
||||
assert_instance_of Whisper::Segment, segment
|
||||
end
|
||||
end
|
||||
|
||||
def test_enumerator
|
||||
enum = whisper.each_segment
|
||||
assert_instance_of Enumerator, enum
|
||||
enum.to_a.each_with_index do |segment, index|
|
||||
assert_instance_of Whisper::Segment, segment
|
||||
assert_kind_of Integer, index
|
||||
end
|
||||
end
|
||||
|
||||
def test_start_time
|
||||
i = 0
|
||||
whisper.each_segment do |segment|
|
||||
assert_equal 0, segment.start_time if i == 0
|
||||
i += 1
|
||||
end
|
||||
end
|
||||
|
||||
def test_end_time
|
||||
i = 0
|
||||
whisper.each_segment do |segment|
|
||||
assert_equal whisper.full_get_segment_t1(i) * 10, segment.end_time
|
||||
i += 1
|
||||
end
|
||||
end
|
||||
|
||||
def test_no_speech_prob
|
||||
no_speech_prob = nil
|
||||
whisper.each_segment do |segment|
|
||||
no_speech_prob = segment.no_speech_prob
|
||||
end
|
||||
assert no_speech_prob > 0.0
|
||||
end
|
||||
|
||||
def test_on_new_segment
|
||||
params = Whisper::Params.new
|
||||
seg = nil
|
||||
index = 0
|
||||
params.on_new_segment do |segment|
|
||||
assert_instance_of Whisper::Segment, segment
|
||||
if index == 0
|
||||
seg = segment
|
||||
assert_equal 0, segment.start_time
|
||||
assert_match(/ask not what your country can do for you, ask what you can do for your country/, segment.text)
|
||||
end
|
||||
index += 1
|
||||
end
|
||||
whisper.transcribe(AUDIO, params)
|
||||
assert_equal 0, seg.start_time
|
||||
assert_match(/ask not what your country can do for you, ask what you can do for your country/, seg.text)
|
||||
end
|
||||
|
||||
def test_on_new_segment_twice
|
||||
params = Whisper::Params.new
|
||||
seg = nil
|
||||
params.on_new_segment do |segment|
|
||||
seg = segment
|
||||
return
|
||||
end
|
||||
params.on_new_segment do |segment|
|
||||
assert_same seg, segment
|
||||
return
|
||||
end
|
||||
whisper.transcribe(AUDIO, params)
|
||||
end
|
||||
end
|
@ -20,24 +20,6 @@ class TestWhisper < TestBase
|
||||
}
|
||||
end
|
||||
|
||||
def test_transcribe_non_parallel
|
||||
@whisper = Whisper::Context.new("base.en")
|
||||
params = Whisper::Params.new
|
||||
|
||||
@whisper.transcribe(AUDIO, params, n_processors: 1) {|text|
|
||||
assert_match(/ask not what your country can do for you, ask what you can do for your country/, text)
|
||||
}
|
||||
end
|
||||
|
||||
def test_transcribe_n_processors
|
||||
@whisper = Whisper::Context.new("base.en")
|
||||
params = Whisper::Params.new
|
||||
|
||||
@whisper.transcribe(AUDIO, params, n_processors: 4) {|text|
|
||||
assert_match(/ask not what your country can do for you[,.] ask what you can do for your country/i, text)
|
||||
}
|
||||
end
|
||||
|
||||
sub_test_case "After transcription" do
|
||||
def test_full_n_segments
|
||||
assert_equal 1, whisper.full_n_segments
|
||||
@ -112,14 +94,6 @@ class TestWhisper < TestBase
|
||||
end
|
||||
end
|
||||
|
||||
def test_system_info_str
|
||||
assert_match(/\AWHISPER : COREML = \d | OPENVINO = \d |/, Whisper.system_info_str)
|
||||
end
|
||||
|
||||
def test_version
|
||||
assert_kind_of String, Whisper::VERSION
|
||||
end
|
||||
|
||||
def test_log_set
|
||||
user_data = Object.new
|
||||
logs = []
|
||||
@ -249,48 +223,4 @@ class TestWhisper < TestBase
|
||||
assert_match(/for your country/i, text)
|
||||
end
|
||||
end
|
||||
|
||||
def test_to_srt
|
||||
whisper = Whisper::Context.new("base.en")
|
||||
whisper.transcribe AUDIO, @params
|
||||
|
||||
lines = whisper.to_srt.lines
|
||||
assert_match(/\A\d+\n/, lines[0])
|
||||
assert_match(/\d{2}:\d{2}:\d{2},\d{3} --> \d{2}:\d{2}:\d{2},\d{3}\n/, lines[1])
|
||||
assert_match(/ask not what your country can do for you, ask what you can do for your country/, lines[2])
|
||||
end
|
||||
|
||||
def test_to_webvtt
|
||||
whisper = Whisper::Context.new("base.en")
|
||||
whisper.transcribe AUDIO, @params
|
||||
|
||||
lines = whisper.to_webvtt.lines
|
||||
assert_equal "WEBVTT\n", lines[0]
|
||||
assert_equal "\n", lines[1]
|
||||
assert_match(/\A\d+\n/, lines[2])
|
||||
assert_match(/\d{2}:\d{2}:\d{2}\.\d{3} --> \d{2}:\d{2}:\d{2}\.\d{3}\n/, lines[3])
|
||||
assert_match(/ask not what your country can do for you, ask what you can do for your country/, lines[4])
|
||||
end
|
||||
|
||||
sub_test_case "Format needs escape" do
|
||||
def setup
|
||||
@whisper = Whisper::Context.new("base.en")
|
||||
@whisper.transcribe AUDIO, Whisper::Params.new
|
||||
segment = @whisper.each_segment.first
|
||||
segment.define_singleton_method :text do
|
||||
"& so my fellow Americans --> ask not what your country can do for you <-- ask what you can do for your country."
|
||||
end
|
||||
@whisper.define_singleton_method :each_segment do
|
||||
Enumerator.new(3) {|yielder| 3.times {yielder << segment}}
|
||||
end
|
||||
end
|
||||
|
||||
def test_to_srt_escape
|
||||
assert_equal "& so my fellow Americans --> ask not what your country can do for you <-- ask what you can do for your country.\n", @whisper.to_srt.lines[2]
|
||||
end
|
||||
|
||||
def test_to_webvtt_escape
|
||||
assert_equal "& so my fellow Americans --> ask not what your country can do for you <-- ask what you can do for your country.\n", @whisper.to_webvtt.lines[4]
|
||||
end
|
||||
end
|
||||
end
|
@ -3,7 +3,8 @@ require_relative "extsources"
|
||||
Gem::Specification.new do |s|
|
||||
s.name = "whispercpp"
|
||||
s.authors = ["Georgi Gerganov", "Todd A. Fisher"]
|
||||
s.version = '1.3.3'
|
||||
s.version = '1.3.2'
|
||||
s.date = '2025-05-01'
|
||||
s.description = %q{High-performance inference of OpenAI's Whisper automatic speech recognition (ASR) model via Ruby}
|
||||
s.email = 'todd.fisher@gmail.com'
|
||||
s.extra_rdoc_files = ['LICENSE', 'README.md']
|
||||
@ -20,7 +21,7 @@ Gem::Specification.new do |s|
|
||||
}
|
||||
|
||||
s.summary = %q{Ruby whisper.cpp bindings}
|
||||
s.test_files = s.files.select {|file| file.start_with? "test/"}
|
||||
s.test_files = s.files.select {|file| file.start_with? "tests/"}
|
||||
|
||||
s.extensions << 'ext/extconf.rb'
|
||||
s.required_ruby_version = '>= 3.1.0'
|
||||
|
@ -105,7 +105,6 @@ else()
|
||||
add_subdirectory(bench)
|
||||
add_subdirectory(server)
|
||||
add_subdirectory(quantize)
|
||||
add_subdirectory(vad-speech-segments)
|
||||
if (WHISPER_SDL2)
|
||||
add_subdirectory(stream)
|
||||
add_subdirectory(command)
|
||||
|
@ -1,10 +1,8 @@
|
||||
# whisper.cpp Node.js addon
|
||||
# addon
|
||||
|
||||
This is an addon demo that can **perform whisper model reasoning in `node` and `electron` environments**, based on [cmake-js](https://github.com/cmake-js/cmake-js).
|
||||
It can be used as a reference for using the whisper.cpp project in other node projects.
|
||||
|
||||
This addon now supports **Voice Activity Detection (VAD)** for improved transcription performance.
|
||||
|
||||
## Install
|
||||
|
||||
```shell
|
||||
@ -28,88 +26,12 @@ For Electron addon and cmake-js options, you can see [cmake-js](https://github.c
|
||||
|
||||
## Run
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```shell
|
||||
cd examples/addon.node
|
||||
|
||||
node index.js --language='language' --model='model-path' --fname_inp='file-path'
|
||||
```
|
||||
|
||||
### VAD (Voice Activity Detection) Usage
|
||||
Because this is a simple Demo, only the above parameters are set in the node environment.
|
||||
|
||||
Run the VAD example with performance comparison:
|
||||
|
||||
```shell
|
||||
node vad-example.js
|
||||
```
|
||||
|
||||
## Voice Activity Detection (VAD) Support
|
||||
|
||||
VAD can significantly improve transcription performance by only processing speech segments, which is especially beneficial for audio files with long periods of silence.
|
||||
|
||||
### VAD Model Setup
|
||||
|
||||
Before using VAD, download a VAD model:
|
||||
|
||||
```shell
|
||||
# From the whisper.cpp root directory
|
||||
./models/download-vad-model.sh silero-v5.1.2
|
||||
```
|
||||
|
||||
### VAD Parameters
|
||||
|
||||
All VAD parameters are optional and have sensible defaults:
|
||||
|
||||
- `vad`: Enable VAD (default: false)
|
||||
- `vad_model`: Path to VAD model file (required when VAD enabled)
|
||||
- `vad_threshold`: Speech detection threshold 0.0-1.0 (default: 0.5)
|
||||
- `vad_min_speech_duration_ms`: Min speech duration in ms (default: 250)
|
||||
- `vad_min_silence_duration_ms`: Min silence duration in ms (default: 100)
|
||||
- `vad_max_speech_duration_s`: Max speech duration in seconds (default: FLT_MAX)
|
||||
- `vad_speech_pad_ms`: Speech padding in ms (default: 30)
|
||||
- `vad_samples_overlap`: Sample overlap 0.0-1.0 (default: 0.1)
|
||||
|
||||
### JavaScript API Example
|
||||
|
||||
```javascript
|
||||
const path = require("path");
|
||||
const { whisper } = require(path.join(__dirname, "../../build/Release/addon.node"));
|
||||
const { promisify } = require("util");
|
||||
|
||||
const whisperAsync = promisify(whisper);
|
||||
|
||||
// With VAD enabled
|
||||
const vadParams = {
|
||||
language: "en",
|
||||
model: path.join(__dirname, "../../models/ggml-base.en.bin"),
|
||||
fname_inp: path.join(__dirname, "../../samples/jfk.wav"),
|
||||
vad: true,
|
||||
vad_model: path.join(__dirname, "../../models/ggml-silero-v5.1.2.bin"),
|
||||
vad_threshold: 0.5,
|
||||
progress_callback: (progress) => console.log(`Progress: ${progress}%`)
|
||||
};
|
||||
|
||||
whisperAsync(vadParams).then(result => console.log(result));
|
||||
```
|
||||
|
||||
## Supported Parameters
|
||||
|
||||
Both traditional whisper.cpp parameters and new VAD parameters are supported:
|
||||
|
||||
- `language`: Language code (e.g., "en", "es", "fr")
|
||||
- `model`: Path to whisper model file
|
||||
- `fname_inp`: Path to input audio file
|
||||
- `use_gpu`: Enable GPU acceleration (default: true)
|
||||
- `flash_attn`: Enable flash attention (default: false)
|
||||
- `no_prints`: Disable console output (default: false)
|
||||
- `no_timestamps`: Disable timestamps (default: false)
|
||||
- `detect_language`: Auto-detect language (default: false)
|
||||
- `audio_ctx`: Audio context size (default: 0)
|
||||
- `max_len`: Maximum segment length (default: 0)
|
||||
- `max_context`: Maximum context size (default: -1)
|
||||
- `prompt`: Initial prompt for decoder
|
||||
- `comma_in_time`: Use comma in timestamps (default: true)
|
||||
- `print_progress`: Print progress info (default: false)
|
||||
- `progress_callback`: Progress callback function
|
||||
- VAD parameters (see above section)
|
||||
Other parameters can also be specified in the node environment.
|
||||
|
@ -1,133 +1,37 @@
|
||||
const { join } = require('path');
|
||||
const { whisper } = require('../../../build/Release/addon.node');
|
||||
const { promisify } = require('util');
|
||||
const path = require("path");
|
||||
const { whisper } = require(path.join(
|
||||
__dirname,
|
||||
"../../../build/Release/addon.node"
|
||||
));
|
||||
const { promisify } = require("util");
|
||||
|
||||
const whisperAsync = promisify(whisper);
|
||||
|
||||
const commonParams = {
|
||||
language: 'en',
|
||||
model: join(__dirname, '../../../models/ggml-base.en.bin'),
|
||||
fname_inp: join(__dirname, '../../../samples/jfk.wav'),
|
||||
const whisperParamsMock = {
|
||||
language: "en",
|
||||
model: path.join(__dirname, "../../../models/ggml-base.en.bin"),
|
||||
fname_inp: path.join(__dirname, "../../../samples/jfk.wav"),
|
||||
use_gpu: true,
|
||||
flash_attn: false,
|
||||
no_prints: true,
|
||||
comma_in_time: false,
|
||||
translate: true,
|
||||
no_timestamps: false,
|
||||
detect_language: false,
|
||||
audio_ctx: 0,
|
||||
max_len: 0
|
||||
max_len: 0,
|
||||
prompt: "",
|
||||
print_progress: false,
|
||||
progress_callback: (progress) => {
|
||||
console.log(`Progress: ${progress}`);
|
||||
},
|
||||
max_context: -1
|
||||
};
|
||||
|
||||
describe('Whisper.cpp Node.js addon with VAD support', () => {
|
||||
test('Basic whisper transcription without VAD', async () => {
|
||||
const params = {
|
||||
...commonParams,
|
||||
vad: false
|
||||
};
|
||||
describe("Run whisper.node", () => {
|
||||
test("it should receive a non-empty value", async () => {
|
||||
let result = await whisperAsync(whisperParamsMock);
|
||||
|
||||
const result = await whisperAsync(params);
|
||||
|
||||
expect(typeof result).toBe('object');
|
||||
expect(Array.isArray(result.transcription)).toBe(true);
|
||||
expect(result.transcription.length).toBeGreaterThan(0);
|
||||
|
||||
// Check that we got some transcription text
|
||||
const text = result.transcription.map(segment => segment[2]).join(' ');
|
||||
expect(text.length).toBeGreaterThan(0);
|
||||
expect(text.toLowerCase()).toContain('ask not');
|
||||
}, 30000);
|
||||
|
||||
test('VAD parameters validation', async () => {
|
||||
// Test with invalid VAD model - should return empty transcription
|
||||
const invalidParams = {
|
||||
...commonParams,
|
||||
vad: true,
|
||||
vad_model: 'non-existent-model.bin',
|
||||
vad_threshold: 0.5
|
||||
};
|
||||
|
||||
// This should handle the error gracefully and return empty transcription
|
||||
const result = await whisperAsync(invalidParams);
|
||||
expect(typeof result).toBe('object');
|
||||
expect(Array.isArray(result.transcription)).toBe(true);
|
||||
// When VAD model doesn't exist, it should return empty transcription
|
||||
expect(result.transcription.length).toBe(0);
|
||||
}, 10000);
|
||||
|
||||
test('VAD parameter parsing', async () => {
|
||||
// Test that VAD parameters are properly parsed (even if VAD model doesn't exist)
|
||||
const vadParams = {
|
||||
...commonParams,
|
||||
vad: false, // Disabled so no model required
|
||||
vad_threshold: 0.7,
|
||||
vad_min_speech_duration_ms: 300,
|
||||
vad_min_silence_duration_ms: 150,
|
||||
vad_max_speech_duration_s: 45.0,
|
||||
vad_speech_pad_ms: 50,
|
||||
vad_samples_overlap: 0.15
|
||||
};
|
||||
|
||||
const result = await whisperAsync(vadParams);
|
||||
|
||||
expect(typeof result).toBe('object');
|
||||
expect(Array.isArray(result.transcription)).toBe(true);
|
||||
}, 30000);
|
||||
|
||||
test('Progress callback with VAD disabled', async () => {
|
||||
let progressCalled = false;
|
||||
let lastProgress = 0;
|
||||
|
||||
const params = {
|
||||
...commonParams,
|
||||
vad: false,
|
||||
progress_callback: (progress) => {
|
||||
progressCalled = true;
|
||||
lastProgress = progress;
|
||||
expect(progress).toBeGreaterThanOrEqual(0);
|
||||
expect(progress).toBeLessThanOrEqual(100);
|
||||
}
|
||||
};
|
||||
|
||||
const result = await whisperAsync(params);
|
||||
|
||||
expect(progressCalled).toBe(true);
|
||||
expect(lastProgress).toBe(100);
|
||||
expect(typeof result).toBe('object');
|
||||
}, 30000);
|
||||
|
||||
test('Language detection without VAD', async () => {
|
||||
const params = {
|
||||
...commonParams,
|
||||
vad: false,
|
||||
detect_language: true,
|
||||
language: 'auto'
|
||||
};
|
||||
|
||||
const result = await whisperAsync(params);
|
||||
|
||||
expect(typeof result).toBe('object');
|
||||
expect(typeof result.language).toBe('string');
|
||||
expect(result.language.length).toBeGreaterThan(0);
|
||||
}, 30000);
|
||||
|
||||
test('Basic transcription with all VAD parameters set', async () => {
|
||||
// Test with VAD disabled but all parameters set to ensure no crashes
|
||||
const params = {
|
||||
...commonParams,
|
||||
vad: false, // Disabled so it works without VAD model
|
||||
vad_model: '', // Empty model path
|
||||
vad_threshold: 0.6,
|
||||
vad_min_speech_duration_ms: 200,
|
||||
vad_min_silence_duration_ms: 80,
|
||||
vad_max_speech_duration_s: 25.0,
|
||||
vad_speech_pad_ms: 40,
|
||||
vad_samples_overlap: 0.08
|
||||
};
|
||||
|
||||
const result = await whisperAsync(params);
|
||||
|
||||
expect(typeof result).toBe('object');
|
||||
expect(Array.isArray(result.transcription)).toBe(true);
|
||||
expect(result.transcription.length).toBeGreaterThan(0);
|
||||
}, 30000);
|
||||
expect(result.length).toBeGreaterThan(0);
|
||||
}, 10000);
|
||||
});
|
||||
|
||||
|
@ -9,7 +9,6 @@
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <cfloat>
|
||||
|
||||
struct whisper_params {
|
||||
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
||||
@ -39,7 +38,6 @@ struct whisper_params {
|
||||
bool print_progress = false;
|
||||
bool no_timestamps = false;
|
||||
bool no_prints = false;
|
||||
bool detect_language= false;
|
||||
bool use_gpu = true;
|
||||
bool flash_attn = false;
|
||||
bool comma_in_time = true;
|
||||
@ -52,16 +50,6 @@ struct whisper_params {
|
||||
std::vector<std::string> fname_out = {};
|
||||
|
||||
std::vector<float> pcmf32 = {}; // mono-channel F32 PCM
|
||||
|
||||
// Voice Activity Detection (VAD) parameters
|
||||
bool vad = false;
|
||||
std::string vad_model = "";
|
||||
float vad_threshold = 0.5f;
|
||||
int vad_min_speech_duration_ms = 250;
|
||||
int vad_min_silence_duration_ms = 100;
|
||||
float vad_max_speech_duration_s = FLT_MAX;
|
||||
int vad_speech_pad_ms = 30;
|
||||
float vad_samples_overlap = 0.1f;
|
||||
};
|
||||
|
||||
struct whisper_print_user_data {
|
||||
@ -94,7 +82,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
|
||||
t1 = whisper_full_get_segment_t1(ctx, i);
|
||||
}
|
||||
|
||||
if (!params.no_timestamps && !params.no_prints) {
|
||||
if (!params.no_timestamps) {
|
||||
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
|
||||
}
|
||||
|
||||
@ -125,14 +113,12 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
|
||||
|
||||
// colorful print bug
|
||||
//
|
||||
if (!params.no_prints) {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
printf("%s%s", speaker.c_str(), text);
|
||||
}
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
printf("%s%s", speaker.c_str(), text);
|
||||
|
||||
|
||||
// with timestamps or speakers: each segment on new line
|
||||
if ((!params.no_timestamps || params.diarize) && !params.no_prints) {
|
||||
if (!params.no_timestamps || params.diarize) {
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
@ -142,11 +128,6 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
|
||||
|
||||
void cb_log_disable(enum ggml_log_level, const char *, void *) {}
|
||||
|
||||
struct whisper_result {
|
||||
std::vector<std::vector<std::string>> segments;
|
||||
std::string language;
|
||||
};
|
||||
|
||||
class ProgressWorker : public Napi::AsyncWorker {
|
||||
public:
|
||||
ProgressWorker(Napi::Function& callback, whisper_params params, Napi::Function progress_callback, Napi::Env env)
|
||||
@ -177,27 +158,15 @@ class ProgressWorker : public Napi::AsyncWorker {
|
||||
|
||||
void OnOK() override {
|
||||
Napi::HandleScope scope(Env());
|
||||
|
||||
if (params.detect_language) {
|
||||
Napi::Object resultObj = Napi::Object::New(Env());
|
||||
resultObj.Set("language", Napi::String::New(Env(), result.language));
|
||||
Callback().Call({Env().Null(), resultObj});
|
||||
}
|
||||
|
||||
Napi::Object returnObj = Napi::Object::New(Env());
|
||||
if (!result.language.empty()) {
|
||||
returnObj.Set("language", Napi::String::New(Env(), result.language));
|
||||
}
|
||||
Napi::Array transcriptionArray = Napi::Array::New(Env(), result.segments.size());
|
||||
for (uint64_t i = 0; i < result.segments.size(); ++i) {
|
||||
Napi::Object res = Napi::Array::New(Env(), result.size());
|
||||
for (uint64_t i = 0; i < result.size(); ++i) {
|
||||
Napi::Object tmp = Napi::Array::New(Env(), 3);
|
||||
for (uint64_t j = 0; j < 3; ++j) {
|
||||
tmp[j] = Napi::String::New(Env(), result.segments[i][j]);
|
||||
tmp[j] = Napi::String::New(Env(), result[i][j]);
|
||||
}
|
||||
transcriptionArray[i] = tmp;
|
||||
}
|
||||
returnObj.Set("transcription", transcriptionArray);
|
||||
Callback().Call({Env().Null(), returnObj});
|
||||
res[i] = tmp;
|
||||
}
|
||||
Callback().Call({Env().Null(), res});
|
||||
}
|
||||
|
||||
// Progress callback function - using thread-safe function
|
||||
@ -214,12 +183,12 @@ class ProgressWorker : public Napi::AsyncWorker {
|
||||
|
||||
private:
|
||||
whisper_params params;
|
||||
whisper_result result;
|
||||
std::vector<std::vector<std::string>> result;
|
||||
Napi::Env env;
|
||||
Napi::ThreadSafeFunction tsfn;
|
||||
|
||||
// Custom run function with progress callback support
|
||||
int run_with_progress(whisper_params ¶ms, whisper_result & result) {
|
||||
int run_with_progress(whisper_params ¶ms, std::vector<std::vector<std::string>> &result) {
|
||||
if (params.no_prints) {
|
||||
whisper_log_set(cb_log_disable, NULL);
|
||||
}
|
||||
@ -308,8 +277,7 @@ class ProgressWorker : public Napi::AsyncWorker {
|
||||
wparams.print_timestamps = !params.no_timestamps;
|
||||
wparams.print_special = params.print_special;
|
||||
wparams.translate = params.translate;
|
||||
wparams.language = params.detect_language ? "auto" : params.language.c_str();
|
||||
wparams.detect_language = params.detect_language;
|
||||
wparams.language = params.language.c_str();
|
||||
wparams.n_threads = params.n_threads;
|
||||
wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
|
||||
wparams.offset_ms = params.offset_t_ms;
|
||||
@ -344,38 +312,34 @@ class ProgressWorker : public Napi::AsyncWorker {
|
||||
};
|
||||
wparams.progress_callback_user_data = this;
|
||||
|
||||
// Set VAD parameters
|
||||
wparams.vad = params.vad;
|
||||
wparams.vad_model_path = params.vad_model.c_str();
|
||||
// Abort mechanism example
|
||||
{
|
||||
static bool is_aborted = false; // Note: this should be atomic to avoid data races
|
||||
|
||||
wparams.vad_params.threshold = params.vad_threshold;
|
||||
wparams.vad_params.min_speech_duration_ms = params.vad_min_speech_duration_ms;
|
||||
wparams.vad_params.min_silence_duration_ms = params.vad_min_silence_duration_ms;
|
||||
wparams.vad_params.max_speech_duration_s = params.vad_max_speech_duration_s;
|
||||
wparams.vad_params.speech_pad_ms = params.vad_speech_pad_ms;
|
||||
wparams.vad_params.samples_overlap = params.vad_samples_overlap;
|
||||
wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
|
||||
bool is_aborted = *(bool*)user_data;
|
||||
return !is_aborted;
|
||||
};
|
||||
wparams.encoder_begin_callback_user_data = &is_aborted;
|
||||
}
|
||||
|
||||
if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
|
||||
fprintf(stderr, "failed to process audio\n");
|
||||
return 10;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (params.detect_language || params.language == "auto") {
|
||||
result.language = whisper_lang_str(whisper_full_lang_id(ctx));
|
||||
}
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
result.segments.resize(n_segments);
|
||||
|
||||
result.resize(n_segments);
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
||||
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
||||
|
||||
result.segments[i].emplace_back(to_timestamp(t0, params.comma_in_time));
|
||||
result.segments[i].emplace_back(to_timestamp(t1, params.comma_in_time));
|
||||
result.segments[i].emplace_back(text);
|
||||
result[i].emplace_back(to_timestamp(t0, params.comma_in_time));
|
||||
result[i].emplace_back(to_timestamp(t1, params.comma_in_time));
|
||||
result[i].emplace_back(text);
|
||||
}
|
||||
|
||||
whisper_print_timings(ctx);
|
||||
@ -396,46 +360,13 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
|
||||
std::string language = whisper_params.Get("language").As<Napi::String>();
|
||||
std::string model = whisper_params.Get("model").As<Napi::String>();
|
||||
std::string input = whisper_params.Get("fname_inp").As<Napi::String>();
|
||||
|
||||
bool use_gpu = true;
|
||||
if (whisper_params.Has("use_gpu") && whisper_params.Get("use_gpu").IsBoolean()) {
|
||||
use_gpu = whisper_params.Get("use_gpu").As<Napi::Boolean>();
|
||||
}
|
||||
|
||||
bool flash_attn = false;
|
||||
if (whisper_params.Has("flash_attn") && whisper_params.Get("flash_attn").IsBoolean()) {
|
||||
flash_attn = whisper_params.Get("flash_attn").As<Napi::Boolean>();
|
||||
}
|
||||
|
||||
bool no_prints = false;
|
||||
if (whisper_params.Has("no_prints") && whisper_params.Get("no_prints").IsBoolean()) {
|
||||
no_prints = whisper_params.Get("no_prints").As<Napi::Boolean>();
|
||||
}
|
||||
|
||||
bool no_timestamps = false;
|
||||
if (whisper_params.Has("no_timestamps") && whisper_params.Get("no_timestamps").IsBoolean()) {
|
||||
no_timestamps = whisper_params.Get("no_timestamps").As<Napi::Boolean>();
|
||||
}
|
||||
|
||||
bool detect_language = false;
|
||||
if (whisper_params.Has("detect_language") && whisper_params.Get("detect_language").IsBoolean()) {
|
||||
detect_language = whisper_params.Get("detect_language").As<Napi::Boolean>();
|
||||
}
|
||||
|
||||
int32_t audio_ctx = 0;
|
||||
if (whisper_params.Has("audio_ctx") && whisper_params.Get("audio_ctx").IsNumber()) {
|
||||
audio_ctx = whisper_params.Get("audio_ctx").As<Napi::Number>();
|
||||
}
|
||||
|
||||
bool comma_in_time = true;
|
||||
if (whisper_params.Has("comma_in_time") && whisper_params.Get("comma_in_time").IsBoolean()) {
|
||||
comma_in_time = whisper_params.Get("comma_in_time").As<Napi::Boolean>();
|
||||
}
|
||||
|
||||
int32_t max_len = 0;
|
||||
if (whisper_params.Has("max_len") && whisper_params.Get("max_len").IsNumber()) {
|
||||
max_len = whisper_params.Get("max_len").As<Napi::Number>();
|
||||
}
|
||||
bool use_gpu = whisper_params.Get("use_gpu").As<Napi::Boolean>();
|
||||
bool flash_attn = whisper_params.Get("flash_attn").As<Napi::Boolean>();
|
||||
bool no_prints = whisper_params.Get("no_prints").As<Napi::Boolean>();
|
||||
bool no_timestamps = whisper_params.Get("no_timestamps").As<Napi::Boolean>();
|
||||
int32_t audio_ctx = whisper_params.Get("audio_ctx").As<Napi::Number>();
|
||||
bool comma_in_time = whisper_params.Get("comma_in_time").As<Napi::Boolean>();
|
||||
int32_t max_len = whisper_params.Get("max_len").As<Napi::Number>();
|
||||
|
||||
// Add support for max_context
|
||||
int32_t max_context = -1;
|
||||
@ -451,7 +382,7 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
|
||||
|
||||
// Add support for print_progress
|
||||
bool print_progress = false;
|
||||
if (whisper_params.Has("print_progress") && whisper_params.Get("print_progress").IsBoolean()) {
|
||||
if (whisper_params.Has("print_progress")) {
|
||||
print_progress = whisper_params.Get("print_progress").As<Napi::Boolean>();
|
||||
}
|
||||
// Add support for progress_callback
|
||||
@ -460,47 +391,6 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
|
||||
progress_callback = whisper_params.Get("progress_callback").As<Napi::Function>();
|
||||
}
|
||||
|
||||
// Add support for VAD parameters
|
||||
bool vad = false;
|
||||
if (whisper_params.Has("vad") && whisper_params.Get("vad").IsBoolean()) {
|
||||
vad = whisper_params.Get("vad").As<Napi::Boolean>();
|
||||
}
|
||||
|
||||
std::string vad_model = "";
|
||||
if (whisper_params.Has("vad_model") && whisper_params.Get("vad_model").IsString()) {
|
||||
vad_model = whisper_params.Get("vad_model").As<Napi::String>();
|
||||
}
|
||||
|
||||
float vad_threshold = 0.5f;
|
||||
if (whisper_params.Has("vad_threshold") && whisper_params.Get("vad_threshold").IsNumber()) {
|
||||
vad_threshold = whisper_params.Get("vad_threshold").As<Napi::Number>();
|
||||
}
|
||||
|
||||
int vad_min_speech_duration_ms = 250;
|
||||
if (whisper_params.Has("vad_min_speech_duration_ms") && whisper_params.Get("vad_min_speech_duration_ms").IsNumber()) {
|
||||
vad_min_speech_duration_ms = whisper_params.Get("vad_min_speech_duration_ms").As<Napi::Number>();
|
||||
}
|
||||
|
||||
int vad_min_silence_duration_ms = 100;
|
||||
if (whisper_params.Has("vad_min_silence_duration_ms") && whisper_params.Get("vad_min_silence_duration_ms").IsNumber()) {
|
||||
vad_min_silence_duration_ms = whisper_params.Get("vad_min_silence_duration_ms").As<Napi::Number>();
|
||||
}
|
||||
|
||||
float vad_max_speech_duration_s = FLT_MAX;
|
||||
if (whisper_params.Has("vad_max_speech_duration_s") && whisper_params.Get("vad_max_speech_duration_s").IsNumber()) {
|
||||
vad_max_speech_duration_s = whisper_params.Get("vad_max_speech_duration_s").As<Napi::Number>();
|
||||
}
|
||||
|
||||
int vad_speech_pad_ms = 30;
|
||||
if (whisper_params.Has("vad_speech_pad_ms") && whisper_params.Get("vad_speech_pad_ms").IsNumber()) {
|
||||
vad_speech_pad_ms = whisper_params.Get("vad_speech_pad_ms").As<Napi::Number>();
|
||||
}
|
||||
|
||||
float vad_samples_overlap = 0.1f;
|
||||
if (whisper_params.Has("vad_samples_overlap") && whisper_params.Get("vad_samples_overlap").IsNumber()) {
|
||||
vad_samples_overlap = whisper_params.Get("vad_samples_overlap").As<Napi::Number>();
|
||||
}
|
||||
|
||||
Napi::Value pcmf32Value = whisper_params.Get("pcmf32");
|
||||
std::vector<float> pcmf32_vec;
|
||||
if (pcmf32Value.IsTypedArray()) {
|
||||
@ -526,17 +416,6 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
|
||||
params.max_context = max_context;
|
||||
params.print_progress = print_progress;
|
||||
params.prompt = prompt;
|
||||
params.detect_language = detect_language;
|
||||
|
||||
// Set VAD parameters
|
||||
params.vad = vad;
|
||||
params.vad_model = vad_model;
|
||||
params.vad_threshold = vad_threshold;
|
||||
params.vad_min_speech_duration_ms = vad_min_speech_duration_ms;
|
||||
params.vad_min_silence_duration_ms = vad_min_silence_duration_ms;
|
||||
params.vad_max_speech_duration_s = vad_max_speech_duration_s;
|
||||
params.vad_speech_pad_ms = vad_speech_pad_ms;
|
||||
params.vad_samples_overlap = vad_samples_overlap;
|
||||
|
||||
Napi::Function callback = info[1].As<Napi::Function>();
|
||||
// Create a new Worker class with progress callback support
|
||||
|
@ -17,7 +17,6 @@ const whisperParams = {
|
||||
comma_in_time: false,
|
||||
translate: true,
|
||||
no_timestamps: false,
|
||||
detect_language: false,
|
||||
audio_ctx: 0,
|
||||
max_len: 0,
|
||||
progress_callback: (progress) => {
|
||||
@ -32,8 +31,6 @@ const params = Object.fromEntries(
|
||||
const [key, value] = item.slice(2).split("=");
|
||||
if (key === "audio_ctx") {
|
||||
whisperParams[key] = parseInt(value);
|
||||
} else if (key === "detect_language") {
|
||||
whisperParams[key] = value === "true";
|
||||
} else {
|
||||
whisperParams[key] = value;
|
||||
}
|
||||
|
@ -1,132 +0,0 @@
|
||||
const path = require("path");
|
||||
const { whisper } = require(path.join(
|
||||
__dirname,
|
||||
"../../build/Release/addon.node"
|
||||
));
|
||||
const { promisify } = require("util");
|
||||
|
||||
const whisperAsync = promisify(whisper);
|
||||
|
||||
// Example with VAD enabled
|
||||
const vadParams = {
|
||||
language: "en",
|
||||
model: path.join(__dirname, "../../models/ggml-base.en.bin"),
|
||||
fname_inp: path.join(__dirname, "../../samples/jfk.wav"),
|
||||
use_gpu: true,
|
||||
flash_attn: false,
|
||||
no_prints: false,
|
||||
comma_in_time: true,
|
||||
translate: false,
|
||||
no_timestamps: false,
|
||||
detect_language: false,
|
||||
audio_ctx: 0,
|
||||
max_len: 0,
|
||||
// VAD parameters
|
||||
vad: true,
|
||||
vad_model: path.join(__dirname, "../../models/ggml-silero-v5.1.2.bin"), // You need to download this model
|
||||
vad_threshold: 0.5,
|
||||
vad_min_speech_duration_ms: 250,
|
||||
vad_min_silence_duration_ms: 100,
|
||||
vad_max_speech_duration_s: 30.0,
|
||||
vad_speech_pad_ms: 30,
|
||||
vad_samples_overlap: 0.1,
|
||||
progress_callback: (progress) => {
|
||||
console.log(`VAD Transcription progress: ${progress}%`);
|
||||
}
|
||||
};
|
||||
|
||||
// Example without VAD (traditional approach)
|
||||
const traditionalParams = {
|
||||
language: "en",
|
||||
model: path.join(__dirname, "../../models/ggml-base.en.bin"),
|
||||
fname_inp: path.join(__dirname, "../../samples/jfk.wav"),
|
||||
use_gpu: true,
|
||||
flash_attn: false,
|
||||
no_prints: false,
|
||||
comma_in_time: true,
|
||||
translate: false,
|
||||
no_timestamps: false,
|
||||
detect_language: false,
|
||||
audio_ctx: 0,
|
||||
max_len: 0,
|
||||
vad: false, // Explicitly disable VAD
|
||||
progress_callback: (progress) => {
|
||||
console.log(`Traditional transcription progress: ${progress}%`);
|
||||
}
|
||||
};
|
||||
|
||||
async function runVADExample() {
|
||||
try {
|
||||
console.log("=== Whisper.cpp Node.js VAD Example ===\n");
|
||||
|
||||
// Check if VAD model exists
|
||||
const fs = require('fs');
|
||||
if (!fs.existsSync(vadParams.vad_model)) {
|
||||
console.log("⚠️ VAD model not found. Please download the VAD model first:");
|
||||
console.log(" ./models/download-vad-model.sh silero-v5.1.2");
|
||||
console.log(" Or run: python models/convert-silero-vad-to-ggml.py");
|
||||
console.log("\n Falling back to traditional transcription without VAD...\n");
|
||||
|
||||
// Run without VAD
|
||||
console.log("🎵 Running traditional transcription...");
|
||||
const traditionalResult = await whisperAsync(traditionalParams);
|
||||
console.log("\n📝 Traditional transcription result:");
|
||||
console.log(traditionalResult);
|
||||
return;
|
||||
}
|
||||
|
||||
console.log("🎵 Running transcription with VAD enabled...");
|
||||
console.log("VAD Parameters:");
|
||||
console.log(` - Threshold: ${vadParams.vad_threshold}`);
|
||||
console.log(` - Min speech duration: ${vadParams.vad_min_speech_duration_ms}ms`);
|
||||
console.log(` - Min silence duration: ${vadParams.vad_min_silence_duration_ms}ms`);
|
||||
console.log(` - Max speech duration: ${vadParams.vad_max_speech_duration_s}s`);
|
||||
console.log(` - Speech padding: ${vadParams.vad_speech_pad_ms}ms`);
|
||||
console.log(` - Samples overlap: ${vadParams.vad_samples_overlap}\n`);
|
||||
|
||||
const startTime = Date.now();
|
||||
const vadResult = await whisperAsync(vadParams);
|
||||
const vadDuration = Date.now() - startTime;
|
||||
|
||||
console.log("\n✅ VAD transcription completed!");
|
||||
console.log(`⏱️ Processing time: ${vadDuration}ms`);
|
||||
console.log("\n📝 VAD transcription result:");
|
||||
console.log(vadResult);
|
||||
|
||||
// Compare with traditional approach
|
||||
console.log("\n🔄 Running traditional transcription for comparison...");
|
||||
const traditionalStartTime = Date.now();
|
||||
const traditionalResult = await whisperAsync(traditionalParams);
|
||||
const traditionalDuration = Date.now() - traditionalStartTime;
|
||||
|
||||
console.log("\n✅ Traditional transcription completed!");
|
||||
console.log(`⏱️ Processing time: ${traditionalDuration}ms`);
|
||||
console.log("\n📝 Traditional transcription result:");
|
||||
console.log(traditionalResult);
|
||||
|
||||
// Performance comparison
|
||||
console.log("\n📊 Performance Comparison:");
|
||||
console.log(`VAD: ${vadDuration}ms`);
|
||||
console.log(`Traditional: ${traditionalDuration}ms`);
|
||||
const speedup = traditionalDuration / vadDuration;
|
||||
if (speedup > 1) {
|
||||
console.log(`🚀 VAD is ${speedup.toFixed(2)}x faster!`);
|
||||
} else {
|
||||
console.log(`ℹ️ Traditional approach was ${(1/speedup).toFixed(2)}x faster in this case.`);
|
||||
}
|
||||
|
||||
} catch (error) {
|
||||
console.error("❌ Error during transcription:", error);
|
||||
}
|
||||
}
|
||||
|
||||
// Run the example
|
||||
if (require.main === module) {
|
||||
runVADExample();
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
runVADExample,
|
||||
vadParams,
|
||||
traditionalParams
|
||||
};
|
@ -35,7 +35,7 @@ set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
|
||||
-s INITIAL_MEMORY=2000MB \
|
||||
-s TOTAL_MEMORY=2000MB \
|
||||
-s FORCE_FILESYSTEM=1 \
|
||||
-s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap', 'HEAPU8']\" \
|
||||
-s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \
|
||||
${EXTRA_FLAGS} \
|
||||
")
|
||||
|
||||
|
@ -28,10 +28,5 @@ to the server's HTTP path:
|
||||
```
|
||||
# copy the produced page to your HTTP path
|
||||
cp bin/bench.wasm/* /path/to/html/
|
||||
cp bin/libbench.js /path/to/html/
|
||||
cp bin/libbench.worker.js /path/to/html/
|
||||
```
|
||||
|
||||
> 📝 **Note:** As of Emscripten 3.1.58 (April 2024), separate worker.js files are no
|
||||
> longer generated and the worker is embedded in the main JS file. So the worker
|
||||
> file will not be geneated for versions later than `3.1.58`.
|
||||
|
@ -66,12 +66,13 @@ static int whisper_bench_full(const whisper_params & params) {
|
||||
cparams.use_gpu = params.use_gpu;
|
||||
cparams.flash_attn = params.flash_attn;
|
||||
|
||||
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
|
||||
|
||||
{
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", params.n_threads, std::thread::hardware_concurrency(), whisper_print_system_info());
|
||||
}
|
||||
|
||||
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
|
||||
if (ctx == nullptr) {
|
||||
fprintf(stderr, "error: failed to initialize whisper context\n");
|
||||
return 2;
|
||||
@ -155,8 +156,6 @@ static int whisper_bench_full(const whisper_params & params) {
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
ggml_backend_load_all();
|
||||
|
||||
whisper_params params;
|
||||
|
||||
if (whisper_params_parse(argc, argv, params) == false) {
|
||||
|
@ -11,7 +11,6 @@
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <cstring>
|
||||
#include <cfloat>
|
||||
|
||||
#if defined(_WIN32)
|
||||
#ifndef NOMINMAX
|
||||
@ -70,7 +69,6 @@ struct whisper_params {
|
||||
bool no_prints = false;
|
||||
bool print_special = false;
|
||||
bool print_colors = false;
|
||||
bool print_confidence= false;
|
||||
bool print_progress = false;
|
||||
bool no_timestamps = false;
|
||||
bool log_score = false;
|
||||
@ -99,16 +97,6 @@ struct whisper_params {
|
||||
std::vector<std::string> fname_out = {};
|
||||
|
||||
grammar_parser::parse_state grammar_parsed;
|
||||
|
||||
// Voice Activity Detection (VAD) parameters
|
||||
bool vad = false;
|
||||
std::string vad_model = "";
|
||||
float vad_threshold = 0.5f;
|
||||
int vad_min_speech_duration_ms = 250;
|
||||
int vad_min_silence_duration_ms = 100;
|
||||
float vad_max_speech_duration_s = FLT_MAX;
|
||||
int vad_speech_pad_ms = 30;
|
||||
float vad_samples_overlap = 0.1f;
|
||||
};
|
||||
|
||||
static void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
||||
@ -180,7 +168,6 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
|
||||
else if (arg == "-np" || arg == "--no-prints") { params.no_prints = true; }
|
||||
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
||||
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
|
||||
else if ( arg == "--print-confidence"){ params.print_confidence= true; }
|
||||
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
|
||||
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
|
||||
else if (arg == "-l" || arg == "--language") { params.language = whisper_param_turn_lowercase(ARGV_NEXT); }
|
||||
@ -198,15 +185,6 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
|
||||
else if ( arg == "--grammar") { params.grammar = ARGV_NEXT; }
|
||||
else if ( arg == "--grammar-rule") { params.grammar_rule = ARGV_NEXT; }
|
||||
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(ARGV_NEXT); }
|
||||
// Voice Activity Detection (VAD)
|
||||
else if ( arg == "--vad") { params.vad = true; }
|
||||
else if (arg == "-vm" || arg == "--vad-model") { params.vad_model = ARGV_NEXT; }
|
||||
else if (arg == "-vt" || arg == "--vad-threshold") { params.vad_threshold = std::stof(ARGV_NEXT); }
|
||||
else if (arg == "-vspd" || arg == "--vad-min-speech-duration-ms") { params.vad_min_speech_duration_ms = std::stoi(ARGV_NEXT); }
|
||||
else if (arg == "-vsd" || arg == "--vad-min-silence-duration-ms") { params.vad_min_speech_duration_ms = std::stoi(ARGV_NEXT); }
|
||||
else if (arg == "-vmsd" || arg == "--vad-max-speech-duration-s") { params.vad_max_speech_duration_s = std::stof(ARGV_NEXT); }
|
||||
else if (arg == "-vp" || arg == "--vad-speech-pad-ms") { params.vad_speech_pad_ms = std::stoi(ARGV_NEXT); }
|
||||
else if (arg == "-vo" || arg == "--vad-samples-overlap") { params.vad_samples_overlap = std::stof(ARGV_NEXT); }
|
||||
else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
@ -259,7 +237,6 @@ static void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params
|
||||
fprintf(stderr, " -np, --no-prints [%-7s] do not print anything other than the results\n", params.no_prints ? "true" : "false");
|
||||
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
||||
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
|
||||
fprintf(stderr, " --print-confidence [%-7s] print confidence\n", params.print_confidence ? "true" : "false");
|
||||
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
|
||||
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false");
|
||||
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
|
||||
@ -277,18 +254,6 @@ static void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params
|
||||
fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
|
||||
fprintf(stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str());
|
||||
fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
|
||||
// Voice Activity Detection (VAD) parameters
|
||||
fprintf(stderr, "\nVoice Activity Detection (VAD) options:\n");
|
||||
fprintf(stderr, " --vad [%-7s] enable Voice Activity Detection (VAD)\n", params.vad ? "true" : "false");
|
||||
fprintf(stderr, " -vm FNAME, --vad-model FNAME [%-7s] VAD model path\n", params.vad_model.c_str());
|
||||
fprintf(stderr, " -vt N, --vad-threshold N [%-7.2f] VAD threshold for speech recognition\n", params.vad_threshold);
|
||||
fprintf(stderr, " -vspd N, --vad-min-speech-duration-ms N [%-7d] VAD min speech duration (0.0-1.0)\n", params.vad_min_speech_duration_ms);
|
||||
fprintf(stderr, " -vsd N, --vad-min-silence-duration-ms N [%-7d] VAD min silence duration (to split segments)\n", params.vad_min_silence_duration_ms);
|
||||
fprintf(stderr, " -vmsd N, --vad-max-speech-duration-s N [%-7s] VAD max speech duration (auto-split longer)\n", params.vad_max_speech_duration_s == FLT_MAX ?
|
||||
std::string("FLT_MAX").c_str() :
|
||||
std::to_string(params.vad_max_speech_duration_s).c_str());
|
||||
fprintf(stderr, " -vp N, --vad-speech-pad-ms N [%-7d] VAD speech padding (extend segments)\n", params.vad_speech_pad_ms);
|
||||
fprintf(stderr, " -vo N, --vad-samples-overlap N [%-7.2f] VAD samples overlap (seconds between segments)\n", params.vad_samples_overlap);
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
@ -389,26 +354,6 @@ static void whisper_print_segment_callback(struct whisper_context * ctx, struct
|
||||
|
||||
printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
|
||||
}
|
||||
} else if (params.print_confidence) {
|
||||
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
|
||||
if (params.print_special == false) {
|
||||
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
|
||||
if (id >= whisper_token_eot(ctx)) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
const char * text = whisper_full_get_token_text(ctx, i, j);
|
||||
const float p = whisper_full_get_token_p (ctx, i, j);
|
||||
|
||||
int style_idx = 2; // High confidence - dim
|
||||
if (p < 0.33) {
|
||||
style_idx = 0; // Low confidence - inverse (highlighted)
|
||||
} else if (p < 0.66) {
|
||||
style_idx = 1; // Medium confidence - underlined
|
||||
}
|
||||
printf("%s%s%s%s", speaker.c_str(), k_styles[style_idx].c_str(), text, "\033[0m");
|
||||
}
|
||||
} else {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
|
||||
@ -909,8 +854,6 @@ static void output_lrc(struct whisper_context * ctx, std::ofstream & fout, const
|
||||
static void cb_log_disable(enum ggml_log_level , const char * , void * ) { }
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
ggml_backend_load_all();
|
||||
|
||||
#if defined(_WIN32)
|
||||
// Set the console output code page to UTF-8, while command line arguments
|
||||
// are still encoded in the system's code page. In this way, we can print
|
||||
@ -990,6 +933,7 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
// whisper init
|
||||
|
||||
struct whisper_context_params cparams = whisper_context_default_params();
|
||||
|
||||
cparams.use_gpu = params.use_gpu;
|
||||
@ -1137,11 +1081,6 @@ int main(int argc, char ** argv) {
|
||||
params.tinydiarize ? "tdrz = 1, " : "",
|
||||
params.no_timestamps ? 0 : 1);
|
||||
|
||||
if (params.print_colors) {
|
||||
fprintf(stderr, "%s: color scheme: red (low confidence), yellow (medium), green (high confidence)\n", __func__);
|
||||
} else if (params.print_confidence) {
|
||||
fprintf(stderr, "%s: confidence: highlighted (low confidence), underlined (medium), dim (high confidence)\n", __func__);
|
||||
}
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
@ -1192,16 +1131,6 @@ int main(int argc, char ** argv) {
|
||||
|
||||
wparams.suppress_nst = params.suppress_nst;
|
||||
|
||||
wparams.vad = params.vad;
|
||||
wparams.vad_model_path = params.vad_model.c_str();
|
||||
|
||||
wparams.vad_params.threshold = params.vad_threshold;
|
||||
wparams.vad_params.min_speech_duration_ms = params.vad_min_speech_duration_ms;
|
||||
wparams.vad_params.min_silence_duration_ms = params.vad_min_silence_duration_ms;
|
||||
wparams.vad_params.max_speech_duration_s = params.vad_max_speech_duration_s;
|
||||
wparams.vad_params.speech_pad_ms = params.vad_speech_pad_ms;
|
||||
wparams.vad_params.samples_overlap = params.vad_samples_overlap;
|
||||
|
||||
whisper_print_user_data user_data = { ¶ms, &pcmf32s, 0 };
|
||||
|
||||
const auto & grammar_parsed = params.grammar_parsed;
|
||||
|
@ -36,7 +36,7 @@ set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
|
||||
-s INITIAL_MEMORY=1024MB \
|
||||
-s TOTAL_MEMORY=1024MB \
|
||||
-s FORCE_FILESYSTEM=1 \
|
||||
-s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap', 'HEAPU8']\" \
|
||||
-s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \
|
||||
${EXTRA_FLAGS} \
|
||||
")
|
||||
|
||||
|
@ -28,10 +28,5 @@ To run the example in a different server, you need to copy the following files
|
||||
to the server's HTTP path:
|
||||
```
|
||||
cp bin/command.wasm/* /path/to/html/
|
||||
cp bin/libcommand.js /path/to/html/
|
||||
cp bin/libcommand.worker.js /path/to/html/
|
||||
```
|
||||
|
||||
> 📝 **Note:** As of Emscripten 3.1.58 (April 2024), separate worker.js files are no
|
||||
> longer generated and the worker is embedded in the main JS file. So the worker
|
||||
> file will not be geneated for versions later than `3.1.58`.
|
||||
|
@ -251,7 +251,7 @@ static std::vector<std::string> get_words(const std::string &txt) {
|
||||
|
||||
// command-list mode
|
||||
// guide the transcription to match the most likely command from a provided list
|
||||
static int process_command_list(struct whisper_context * ctx, audio_async &audio, const whisper_params ¶ms, std::ofstream &fout) {
|
||||
static int process_command_list(struct whisper_context * ctx, audio_async &audio, const whisper_params ¶ms) {
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s: guided mode\n", __func__);
|
||||
|
||||
@ -444,16 +444,12 @@ static int process_command_list(struct whisper_context * ctx, audio_async &audio
|
||||
|
||||
const float prob = probs_id[0].first;
|
||||
const int index = probs_id[0].second;
|
||||
const char * best_command = allowed_commands[index].c_str();
|
||||
|
||||
fprintf(stdout, "\n");
|
||||
fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__,
|
||||
"\033[1m", best_command, "\033[0m", prob,
|
||||
"\033[1m", allowed_commands[index].c_str(), "\033[0m", prob,
|
||||
(int) std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count());
|
||||
fprintf(stdout, "\n");
|
||||
if (fout.is_open()) {
|
||||
fout << best_command << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -466,7 +462,7 @@ static int process_command_list(struct whisper_context * ctx, audio_async &audio
|
||||
|
||||
// always-prompt mode
|
||||
// transcribe the voice into text after valid prompt
|
||||
static int always_prompt_transcription(struct whisper_context * ctx, audio_async & audio, const whisper_params & params, std::ofstream & fout) {
|
||||
static int always_prompt_transcription(struct whisper_context * ctx, audio_async & audio, const whisper_params & params) {
|
||||
bool is_running = true;
|
||||
bool ask_prompt = true;
|
||||
|
||||
@ -532,9 +528,6 @@ static int always_prompt_transcription(struct whisper_context * ctx, audio_async
|
||||
|
||||
if ((sim > 0.7f) && (command.size() > 0)) {
|
||||
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
|
||||
if (fout.is_open()) {
|
||||
fout << command << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
fprintf(stdout, "\n");
|
||||
@ -549,7 +542,7 @@ static int always_prompt_transcription(struct whisper_context * ctx, audio_async
|
||||
|
||||
// general-purpose mode
|
||||
// freely transcribe the voice into text
|
||||
static int process_general_transcription(struct whisper_context * ctx, audio_async & audio, const whisper_params & params, std::ofstream & fout) {
|
||||
static int process_general_transcription(struct whisper_context * ctx, audio_async & audio, const whisper_params & params) {
|
||||
bool is_running = true;
|
||||
bool have_prompt = false;
|
||||
bool ask_prompt = true;
|
||||
@ -669,10 +662,8 @@ static int process_general_transcription(struct whisper_context * ctx, audio_asy
|
||||
} else {
|
||||
// cut the prompt from the decoded text
|
||||
const std::string command = ::trim(txt.substr(best_len));
|
||||
|
||||
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
|
||||
if (fout.is_open()) {
|
||||
fout << command << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
fprintf(stdout, "\n");
|
||||
@ -687,8 +678,6 @@ static int process_general_transcription(struct whisper_context * ctx, audio_asy
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
ggml_backend_load_all();
|
||||
|
||||
whisper_params params;
|
||||
|
||||
if (whisper_params_parse(argc, argv, params) == false) {
|
||||
@ -709,10 +698,6 @@ int main(int argc, char ** argv) {
|
||||
cparams.flash_attn = params.flash_attn;
|
||||
|
||||
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
|
||||
if (ctx == nullptr) {
|
||||
fprintf(stderr, "error: failed to initialize whisper context\n");
|
||||
return 2;
|
||||
}
|
||||
|
||||
// print some info about the processing
|
||||
{
|
||||
@ -772,22 +757,13 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
}
|
||||
|
||||
std::ofstream fout;
|
||||
if (params.fname_out.length() > 0) {
|
||||
fout.open(params.fname_out);
|
||||
if (!fout.is_open()) {
|
||||
fprintf(stderr, "%s: failed to open output file '%s'!\n", __func__, params.fname_out.c_str());
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
if (ret_val == 0) {
|
||||
if (!params.commands.empty()) {
|
||||
ret_val = process_command_list(ctx, audio, params, fout);
|
||||
ret_val = process_command_list(ctx, audio, params);
|
||||
} else if (!params.prompt.empty() && params.grammar_parsed.rules.empty()) {
|
||||
ret_val = always_prompt_transcription(ctx, audio, params, fout);
|
||||
ret_val = always_prompt_transcription(ctx, audio, params);
|
||||
} else {
|
||||
ret_val = process_general_transcription(ctx, audio, params, fout);
|
||||
ret_val = process_general_transcription(ctx, audio, params);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -112,20 +112,13 @@ bool read_audio_data(const std::string & fname, std::vector<float>& pcmf32, std:
|
||||
}
|
||||
|
||||
if (stereo) {
|
||||
std::vector<float> stereo_data = pcmf32;
|
||||
pcmf32.resize(frame_count);
|
||||
|
||||
for (uint64_t i = 0; i < frame_count; i++) {
|
||||
pcmf32[i] = (stereo_data[2*i] + stereo_data[2*i + 1]);
|
||||
}
|
||||
|
||||
pcmf32s.resize(2);
|
||||
pcmf32s[0].resize(frame_count);
|
||||
pcmf32s[1].resize(frame_count);
|
||||
for (uint64_t i = 0; i < frame_count; i++) {
|
||||
pcmf32s[0][i] = stereo_data[2*i];
|
||||
pcmf32s[1][i] = stereo_data[2*i + 1];
|
||||
}
|
||||
pcmf32s.resize(2);
|
||||
pcmf32s[0].resize(frame_count);
|
||||
pcmf32s[1].resize(frame_count);
|
||||
for (uint64_t i = 0; i < frame_count; i++) {
|
||||
pcmf32s[0][i] = pcmf32[2*i];
|
||||
pcmf32s[1][i] = pcmf32[2*i + 1];
|
||||
}
|
||||
}
|
||||
|
||||
ma_decoder_uninit(&decoder);
|
||||
|
@ -283,7 +283,7 @@ static std::string set_xterm256_foreground(int r, int g, int b) {
|
||||
}
|
||||
|
||||
// Lowest is red, middle is yellow, highest is green. Color scheme from
|
||||
// Paul Tol; it is colorblind friendly https://sronpersonalpages.nl/~pault
|
||||
// Paul Tol; it is colorblind friendly https://personal.sron.nl/~pault/
|
||||
const std::vector<std::string> k_colors = {
|
||||
set_xterm256_foreground(220, 5, 12),
|
||||
set_xterm256_foreground(232, 96, 28),
|
||||
@ -294,26 +294,6 @@ const std::vector<std::string> k_colors = {
|
||||
set_xterm256_foreground( 78, 178, 101),
|
||||
};
|
||||
|
||||
// ANSI formatting codes
|
||||
static std::string set_inverse() {
|
||||
return "\033[7m";
|
||||
}
|
||||
|
||||
static std::string set_underline() {
|
||||
return "\033[4m";
|
||||
}
|
||||
|
||||
static std::string set_dim() {
|
||||
return "\033[2m";
|
||||
}
|
||||
|
||||
// Style scheme for different confidence levels
|
||||
const std::vector<std::string> k_styles = {
|
||||
set_inverse(), // Low confidence - inverse (highlighted)
|
||||
set_underline(), // Medium confidence - underlined
|
||||
set_dim(), // High confidence - dim
|
||||
};
|
||||
|
||||
//
|
||||
// Other utils
|
||||
//
|
||||
|
@ -424,8 +424,6 @@ static void process_loop(struct whisper_context * ctx, audio_async &audio, const
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
ggml_backend_load_all();
|
||||
|
||||
whisper_params params;
|
||||
if (whisper_params_parse(argc, argv, params) == false) {
|
||||
return 1;
|
||||
|
@ -1,5 +1,4 @@
|
||||
#include "ggml.h"
|
||||
#include "ggml-backend.h"
|
||||
|
||||
#include "common.h"
|
||||
#include "common-ggml.h"
|
||||
@ -177,8 +176,6 @@ static bool whisper_model_quantize(const std::string & fname_inp, const std::str
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
ggml_backend_load_all();
|
||||
|
||||
if (argc != 4) {
|
||||
fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type\n", argv[0]);
|
||||
ggml_print_ftypes(stderr);
|
||||
|
@ -1,6 +1,3 @@
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
|
||||
set(TARGET whisper-server)
|
||||
add_executable(${TARGET} server.cpp httplib.h)
|
||||
|
||||
|
@ -23,7 +23,6 @@ options:
|
||||
-sow, --split-on-word [false ] split on word rather than on token
|
||||
-bo N, --best-of N [2 ] number of best candidates to keep
|
||||
-bs N, --beam-size N [-1 ] beam size for beam search
|
||||
-ac N, --audio-ctx N [0 ] audio context size (0 - all)
|
||||
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
|
||||
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
|
||||
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
|
||||
@ -42,28 +41,9 @@ options:
|
||||
--prompt PROMPT [ ] initial prompt
|
||||
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
|
||||
-oved D, --ov-e-device DNAME [CPU ] the OpenVINO device used for encode inference
|
||||
-dtw MODEL --dtw MODEL [ ] compute token-level timestamps
|
||||
--host HOST, [127.0.0.1] Hostname/ip-adress for the server
|
||||
--port PORT, [8080 ] Port number for the server
|
||||
--public PATH, [examples/server/public] Path to the public folder
|
||||
--request-path PATH, [ ] Request path for all requests
|
||||
--inference-path PATH, [/inference] Inference path for all requests
|
||||
--convert, [false ] Convert audio to WAV, requires ffmpeg on the server
|
||||
-sns, --suppress-nst [false ] suppress non-speech tokens
|
||||
-nth N, --no-speech-thold N [0.60 ] no speech threshold
|
||||
-nc, --no-context [false ] do not use previous audio context
|
||||
-ng, --no-gpu [false ] do not use gpu
|
||||
-fa, --flash-attn [false ] flash attention
|
||||
|
||||
Voice Activity Detection (VAD) options:
|
||||
--vad [false ] enable Voice Activity Detection (VAD)
|
||||
-vm FNAME, --vad-model FNAME [ ] VAD model path
|
||||
-vt N, --vad-threshold N [0.50 ] VAD threshold for speech recognition
|
||||
-vspd N, --vad-min-speech-duration-ms N [250 ] VAD min speech duration (0.0-1.0)
|
||||
-vsd N, --vad-min-silence-duration-ms N [100 ] VAD min silence duration (to split segments)
|
||||
-vmsd N, --vad-max-speech-duration-s N [FLT_MAX] VAD max speech duration (auto-split longer)
|
||||
-vp N, --vad-speech-pad-ms N [30 ] VAD speech padding (extend segments)
|
||||
-vo N, --vad-samples-overlap N [0.10 ] VAD samples overlap (seconds between segments)
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
@ -87,35 +67,3 @@ curl 127.0.0.1:8080/load \
|
||||
-H "Content-Type: multipart/form-data" \
|
||||
-F model="<path-to-model-file>"
|
||||
```
|
||||
|
||||
## Load testing with k6
|
||||
|
||||
> **Note:** Install [k6](https://k6.io/docs/get-started/installation/) before running the benchmark script.
|
||||
|
||||
You can benchmark the Whisper server using the provided bench.js script with [k6](https://k6.io/). This script sends concurrent multipart requests to the /inference endpoint and is fully configurable via environment variables.
|
||||
|
||||
**Example usage:**
|
||||
|
||||
```
|
||||
k6 run bench.js \
|
||||
--env FILE_PATH=/absolute/path/to/samples/jfk.wav \
|
||||
--env BASE_URL=http://127.0.0.1:8080 \
|
||||
--env ENDPOINT=/inference \
|
||||
--env CONCURRENCY=4 \
|
||||
--env TEMPERATURE=0.0 \
|
||||
--env TEMPERATURE_INC=0.2 \
|
||||
--env RESPONSE_FORMAT=json
|
||||
```
|
||||
|
||||
**Environment variables:**
|
||||
- `FILE_PATH`: Path to the audio file to send (must be absolute or relative to the k6 working directory)
|
||||
- `BASE_URL`: Server base URL (default: `http://127.0.0.1:8080`)
|
||||
- `ENDPOINT`: API endpoint (default: `/inference`)
|
||||
- `CONCURRENCY`: Number of concurrent requests (default: 4)
|
||||
- `TEMPERATURE`: Decoding temperature (default: 0.0)
|
||||
- `TEMPERATURE_INC`: Temperature increment (default: 0.2)
|
||||
- `RESPONSE_FORMAT`: Response format (default: `json`)
|
||||
|
||||
**Note:**
|
||||
- The server must be running and accessible at the specified `BASE_URL` and `ENDPOINT`.
|
||||
- The script is located in the same directory as this README: `bench.js`.
|
||||
|
@ -1,29 +0,0 @@
|
||||
import http from 'k6/http'
|
||||
import { check } from 'k6'
|
||||
|
||||
export let options = {
|
||||
vus: parseInt(__ENV.CONCURRENCY) || 4,
|
||||
iterations: parseInt(__ENV.CONCURRENCY) || 4,
|
||||
}
|
||||
|
||||
const filePath = __ENV.FILE_PATH
|
||||
const baseURL = __ENV.BASE_URL || 'http://127.0.0.1:8080'
|
||||
const endpoint = __ENV.ENDPOINT || '/inference'
|
||||
const temperature = __ENV.TEMPERATURE || '0.0'
|
||||
const temperatureInc = __ENV.TEMPERATURE_INC || '0.2'
|
||||
const responseFormat = __ENV.RESPONSE_FORMAT || 'json'
|
||||
|
||||
// Read the file ONCE at init time
|
||||
const fileBin = open(filePath, 'b')
|
||||
|
||||
export default function () {
|
||||
const payload = {
|
||||
file: http.file(fileBin, filePath),
|
||||
temperature: temperature,
|
||||
temperature_inc: temperatureInc,
|
||||
response_format: responseFormat,
|
||||
}
|
||||
|
||||
const res = http.post(`${baseURL}${endpoint}`, payload)
|
||||
check(res, { 'status is 200': r => r.status === 200 })
|
||||
}
|
@ -5,7 +5,6 @@
|
||||
#include "httplib.h"
|
||||
#include "json.hpp"
|
||||
|
||||
#include <cfloat>
|
||||
#include <chrono>
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
@ -14,23 +13,10 @@
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <csignal>
|
||||
#include <atomic>
|
||||
#include <functional>
|
||||
#include <cstdlib>
|
||||
#if defined (_WIN32)
|
||||
#include <windows.h>
|
||||
#endif
|
||||
|
||||
using namespace httplib;
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
enum server_state {
|
||||
SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet
|
||||
SERVER_STATE_READY, // Server is ready and model is loaded
|
||||
};
|
||||
|
||||
namespace {
|
||||
|
||||
// output formats
|
||||
@ -40,20 +26,6 @@ const std::string srt_format = "srt";
|
||||
const std::string vjson_format = "verbose_json";
|
||||
const std::string vtt_format = "vtt";
|
||||
|
||||
std::function<void(int)> shutdown_handler;
|
||||
std::atomic_flag is_terminating = ATOMIC_FLAG_INIT;
|
||||
|
||||
inline void signal_handler(int signal) {
|
||||
if (is_terminating.test_and_set()) {
|
||||
// in case it hangs, we can force terminate the server by hitting Ctrl+C twice
|
||||
// this is for better developer experience, we can remove when the server is stable enough
|
||||
fprintf(stderr, "Received second interrupt, terminating immediately.\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
shutdown_handler(signal);
|
||||
}
|
||||
|
||||
struct server_params
|
||||
{
|
||||
std::string hostname = "127.0.0.1";
|
||||
@ -118,16 +90,6 @@ struct whisper_params {
|
||||
std::string openvino_encode_device = "CPU";
|
||||
|
||||
std::string dtw = "";
|
||||
|
||||
// Voice Activity Detection (VAD) parameters
|
||||
bool vad = false;
|
||||
std::string vad_model = "";
|
||||
float vad_threshold = 0.5f;
|
||||
int vad_min_speech_duration_ms = 250;
|
||||
int vad_min_silence_duration_ms = 100;
|
||||
float vad_max_speech_duration_s = FLT_MAX;
|
||||
int vad_speech_pad_ms = 30;
|
||||
float vad_samples_overlap = 0.1f;
|
||||
};
|
||||
|
||||
void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params, const server_params& sparams) {
|
||||
@ -177,19 +139,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||
fprintf(stderr, " -nth N, --no-speech-thold N [%-7.2f] no speech threshold\n", params.no_speech_thold);
|
||||
fprintf(stderr, " -nc, --no-context [%-7s] do not use previous audio context\n", params.no_context ? "true" : "false");
|
||||
fprintf(stderr, " -ng, --no-gpu [%-7s] do not use gpu\n", params.use_gpu ? "false" : "true");
|
||||
fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false");
|
||||
// Voice Activity Detection (VAD) parameters
|
||||
fprintf(stderr, "\nVoice Activity Detection (VAD) options:\n");
|
||||
fprintf(stderr, " --vad [%-7s] enable Voice Activity Detection (VAD)\n", params.vad ? "true" : "false");
|
||||
fprintf(stderr, " -vm FNAME, --vad-model FNAME [%-7s] VAD model path\n", params.vad_model.c_str());
|
||||
fprintf(stderr, " -vt N, --vad-threshold N [%-7.2f] VAD threshold for speech recognition\n", params.vad_threshold);
|
||||
fprintf(stderr, " -vspd N, --vad-min-speech-duration-ms N [%-7d] VAD min speech duration (0.0-1.0)\n", params.vad_min_speech_duration_ms);
|
||||
fprintf(stderr, " -vsd N, --vad-min-silence-duration-ms N [%-7d] VAD min silence duration (to split segments)\n", params.vad_min_silence_duration_ms);
|
||||
fprintf(stderr, " -vmsd N, --vad-max-speech-duration-s N [%-7s] VAD max speech duration (auto-split longer)\n", params.vad_max_speech_duration_s == FLT_MAX ?
|
||||
std::string("FLT_MAX").c_str() :
|
||||
std::to_string(params.vad_max_speech_duration_s).c_str());
|
||||
fprintf(stderr, " -vp N, --vad-speech-pad-ms N [%-7d] VAD speech padding (extend segments)\n", params.vad_speech_pad_ms);
|
||||
fprintf(stderr, " -vo N, --vad-samples-overlap N [%-7.2f] VAD samples overlap (seconds between segments)\n", params.vad_samples_overlap);
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
@ -245,16 +194,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve
|
||||
else if ( arg == "--request-path") { sparams.request_path = argv[++i]; }
|
||||
else if ( arg == "--inference-path") { sparams.inference_path = argv[++i]; }
|
||||
else if ( arg == "--convert") { sparams.ffmpeg_converter = true; }
|
||||
|
||||
// Voice Activity Detection (VAD)
|
||||
else if ( arg == "--vad") { params.vad = true; }
|
||||
else if (arg == "-vm" || arg == "--vad-model") { params.vad_model = argv[++i]; }
|
||||
else if (arg == "-vt" || arg == "--vad-threshold") { params.vad_threshold = std::stof(argv[++i]); }
|
||||
else if (arg == "-vspd" || arg == "--vad-min-speech-duration-ms") { params.vad_min_speech_duration_ms = std::stoi(argv[++i]); }
|
||||
else if (arg == "-vsd" || arg == "--vad-min-silence-duration-ms") { params.vad_min_speech_duration_ms = std::stoi(argv[++i]); }
|
||||
else if (arg == "-vmsd" || arg == "--vad-max-speech-duration-s") { params.vad_max_speech_duration_s = std::stof(argv[++i]); }
|
||||
else if (arg == "-vp" || arg == "--vad-speech-pad-ms") { params.vad_speech_pad_ms = std::stoi(argv[++i]); }
|
||||
else if (arg == "-vo" || arg == "--vad-samples-overlap") { params.vad_samples_overlap = std::stof(argv[++i]); }
|
||||
else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
whisper_print_usage(argc, argv, params, sparams);
|
||||
@ -571,41 +510,11 @@ void get_req_parameters(const Request & req, whisper_params & params)
|
||||
{
|
||||
params.no_context = parse_str_to_bool(req.get_file_value("no_context").content);
|
||||
}
|
||||
if (req.has_file("vad"))
|
||||
{
|
||||
params.vad = parse_str_to_bool(req.get_file_value("vad").content);
|
||||
}
|
||||
if (req.has_file("vad_threshold"))
|
||||
{
|
||||
params.vad_threshold = std::stof(req.get_file_value("vad_threshold").content);
|
||||
}
|
||||
if (req.has_file("vad_min_speech_duration_ms"))
|
||||
{
|
||||
params.vad_min_speech_duration_ms = std::stof(req.get_file_value("vad_min_speech_duration_ms").content);
|
||||
}
|
||||
if (req.has_file("vad_min_silence_duration_ms"))
|
||||
{
|
||||
params.vad_min_silence_duration_ms = std::stof(req.get_file_value("vad_min_silence_duration_ms").content);
|
||||
}
|
||||
if (req.has_file("vad_max_speech_duration_s"))
|
||||
{
|
||||
params.vad_max_speech_duration_s = std::stof(req.get_file_value("vad_max_speech_duration_s").content);
|
||||
}
|
||||
if (req.has_file("vad_speech_pad_ms"))
|
||||
{
|
||||
params.vad_speech_pad_ms = std::stoi(req.get_file_value("vad_speech_pad_ms").content);
|
||||
}
|
||||
if (req.has_file("vad_samples_overlap"))
|
||||
{
|
||||
params.vad_samples_overlap = std::stof(req.get_file_value("vad_samples_overlap").content);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
ggml_backend_load_all();
|
||||
|
||||
whisper_params params;
|
||||
server_params sparams;
|
||||
|
||||
@ -674,19 +583,13 @@ int main(int argc, char ** argv) {
|
||||
if (params.dtw == "large.v3") {
|
||||
cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V3;
|
||||
}
|
||||
if (params.dtw == "large.v3.turbo") {
|
||||
cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V3_TURBO;
|
||||
}
|
||||
|
||||
|
||||
if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE) {
|
||||
fprintf(stderr, "error: unknown DTW preset '%s'\n", params.dtw.c_str());
|
||||
return 3;
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<httplib::Server> svr = std::make_unique<httplib::Server>();
|
||||
std::atomic<server_state> state{SERVER_STATE_LOADING_MODEL};
|
||||
|
||||
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
|
||||
|
||||
if (ctx == nullptr) {
|
||||
@ -696,10 +599,9 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured
|
||||
whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);
|
||||
state.store(SERVER_STATE_READY);
|
||||
|
||||
|
||||
svr->set_default_headers({{"Server", "whisper.cpp"},
|
||||
Server svr;
|
||||
svr.set_default_headers({{"Server", "whisper.cpp"},
|
||||
{"Access-Control-Allow-Origin", "*"},
|
||||
{"Access-Control-Allow-Headers", "content-type, authorization"}});
|
||||
|
||||
@ -778,15 +680,15 @@ int main(int argc, char ** argv) {
|
||||
whisper_params default_params = params;
|
||||
|
||||
// this is only called if no index.html is found in the public --path
|
||||
svr->Get(sparams.request_path + "/", [&](const Request &, Response &res){
|
||||
svr.Get(sparams.request_path + "/", [&default_content](const Request &, Response &res){
|
||||
res.set_content(default_content, "text/html");
|
||||
return false;
|
||||
});
|
||||
|
||||
svr->Options(sparams.request_path + sparams.inference_path, [&](const Request &, Response &){
|
||||
svr.Options(sparams.request_path + sparams.inference_path, [&](const Request &, Response &){
|
||||
});
|
||||
|
||||
svr->Post(sparams.request_path + sparams.inference_path, [&](const Request &req, Response &res){
|
||||
svr.Post(sparams.request_path + sparams.inference_path, [&](const Request &req, Response &res){
|
||||
// acquire whisper model mutex lock
|
||||
std::lock_guard<std::mutex> lock(whisper_mutex);
|
||||
|
||||
@ -924,16 +826,6 @@ int main(int argc, char ** argv) {
|
||||
|
||||
wparams.suppress_nst = params.suppress_nst;
|
||||
|
||||
wparams.vad = params.vad;
|
||||
wparams.vad_model_path = params.vad_model.c_str();
|
||||
|
||||
wparams.vad_params.threshold = params.vad_threshold;
|
||||
wparams.vad_params.min_speech_duration_ms = params.vad_min_speech_duration_ms;
|
||||
wparams.vad_params.min_silence_duration_ms = params.vad_min_silence_duration_ms;
|
||||
wparams.vad_params.max_speech_duration_s = params.vad_max_speech_duration_s;
|
||||
wparams.vad_params.speech_pad_ms = params.vad_speech_pad_ms;
|
||||
wparams.vad_params.samples_overlap = params.vad_samples_overlap;
|
||||
|
||||
whisper_print_user_data user_data = { ¶ms, &pcmf32s, 0 };
|
||||
|
||||
// this callback is called on each new segment
|
||||
@ -1102,9 +994,8 @@ int main(int argc, char ** argv) {
|
||||
// reset params to their defaults
|
||||
params = default_params;
|
||||
});
|
||||
svr->Post(sparams.request_path + "/load", [&](const Request &req, Response &res){
|
||||
svr.Post(sparams.request_path + "/load", [&](const Request &req, Response &res){
|
||||
std::lock_guard<std::mutex> lock(whisper_mutex);
|
||||
state.store(SERVER_STATE_LOADING_MODEL);
|
||||
if (!req.has_file("model"))
|
||||
{
|
||||
fprintf(stderr, "error: no 'model' field in the request\n");
|
||||
@ -1136,25 +1027,18 @@ int main(int argc, char ** argv) {
|
||||
// initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured
|
||||
whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);
|
||||
|
||||
state.store(SERVER_STATE_READY);
|
||||
const std::string success = "Load was successful!";
|
||||
res.set_content(success, "application/text");
|
||||
|
||||
// check if the model is in the file system
|
||||
});
|
||||
|
||||
svr->Get(sparams.request_path + "/health", [&](const Request &, Response &res){
|
||||
server_state current_state = state.load();
|
||||
if (current_state == SERVER_STATE_READY) {
|
||||
const std::string health_response = "{\"status\":\"ok\"}";
|
||||
res.set_content(health_response, "application/json");
|
||||
} else {
|
||||
res.set_content("{\"status\":\"loading model\"}", "application/json");
|
||||
res.status = 503;
|
||||
}
|
||||
svr.Get(sparams.request_path + "/health", [&](const Request &, Response &res){
|
||||
const std::string health_response = "{\"status\":\"ok\"}";
|
||||
res.set_content(health_response, "application/json");
|
||||
});
|
||||
|
||||
svr->set_exception_handler([](const Request &, Response &res, std::exception_ptr ep) {
|
||||
svr.set_exception_handler([](const Request &, Response &res, std::exception_ptr ep) {
|
||||
const char fmt[] = "500 Internal Server Error\n%s";
|
||||
char buf[BUFSIZ];
|
||||
try {
|
||||
@ -1168,7 +1052,7 @@ int main(int argc, char ** argv) {
|
||||
res.status = 500;
|
||||
});
|
||||
|
||||
svr->set_error_handler([](const Request &req, Response &res) {
|
||||
svr.set_error_handler([](const Request &req, Response &res) {
|
||||
if (res.status == 400) {
|
||||
res.set_content("Invalid request", "text/plain");
|
||||
} else if (res.status != 500) {
|
||||
@ -1178,10 +1062,10 @@ int main(int argc, char ** argv) {
|
||||
});
|
||||
|
||||
// set timeouts and change hostname and port
|
||||
svr->set_read_timeout(sparams.read_timeout);
|
||||
svr->set_write_timeout(sparams.write_timeout);
|
||||
svr.set_read_timeout(sparams.read_timeout);
|
||||
svr.set_write_timeout(sparams.write_timeout);
|
||||
|
||||
if (!svr->bind_to_port(sparams.hostname, sparams.port))
|
||||
if (!svr.bind_to_port(sparams.hostname, sparams.port))
|
||||
{
|
||||
fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n",
|
||||
sparams.hostname.c_str(), sparams.port);
|
||||
@ -1189,50 +1073,18 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
// Set the base directory for serving static files
|
||||
svr->set_base_dir(sparams.public_path);
|
||||
svr.set_base_dir(sparams.public_path);
|
||||
|
||||
// to make it ctrl+clickable:
|
||||
printf("\nwhisper server listening at http://%s:%d\n\n", sparams.hostname.c_str(), sparams.port);
|
||||
|
||||
shutdown_handler = [&](int signal) {
|
||||
printf("\nCaught signal %d, shutting down gracefully...\n", signal);
|
||||
if (svr) {
|
||||
svr->stop();
|
||||
}
|
||||
};
|
||||
if (!svr.listen_after_bind())
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
||||
struct sigaction sigint_action;
|
||||
sigint_action.sa_handler = signal_handler;
|
||||
sigemptyset (&sigint_action.sa_mask);
|
||||
sigint_action.sa_flags = 0;
|
||||
sigaction(SIGINT, &sigint_action, NULL);
|
||||
sigaction(SIGTERM, &sigint_action, NULL);
|
||||
#elif defined (_WIN32)
|
||||
auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
|
||||
return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;
|
||||
};
|
||||
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
|
||||
#endif
|
||||
|
||||
// clean up function, to be called before exit
|
||||
auto clean_up = [&]() {
|
||||
whisper_print_timings(ctx);
|
||||
whisper_free(ctx);
|
||||
};
|
||||
|
||||
std::thread t([&] {
|
||||
if (!svr->listen_after_bind()) {
|
||||
fprintf(stderr, "error: server listen failed\n");
|
||||
}
|
||||
});
|
||||
|
||||
svr->wait_until_ready();
|
||||
|
||||
t.join();
|
||||
|
||||
|
||||
clean_up();
|
||||
whisper_print_timings(ctx);
|
||||
whisper_free(ctx);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
@ -35,7 +35,7 @@ set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
|
||||
-s INITIAL_MEMORY=1024MB \
|
||||
-s TOTAL_MEMORY=1024MB \
|
||||
-s FORCE_FILESYSTEM=1 \
|
||||
-s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap', 'HEAPU8']\" \
|
||||
-s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \
|
||||
${EXTRA_FLAGS} \
|
||||
")
|
||||
|
||||
|
@ -26,10 +26,5 @@ to the server's HTTP path:
|
||||
```
|
||||
# copy the produced page to your HTTP path
|
||||
cp bin/stream.wasm/* /path/to/html/
|
||||
cp bin/libstream.js /path/to/html/
|
||||
cp bin/libstream.worker.js /path/to/html/
|
||||
```
|
||||
|
||||
> 📝 **Note:** As of Emscripten 3.1.58 (April 2024), separate worker.js files are no
|
||||
> longer generated and the worker is embedded in the main JS file. So the worker
|
||||
> file will not be geneated for versions later than `3.1.58`.
|
||||
|
@ -116,8 +116,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
ggml_backend_load_all();
|
||||
|
||||
whisper_params params;
|
||||
|
||||
if (whisper_params_parse(argc, argv, params) == false) {
|
||||
@ -163,10 +161,6 @@ int main(int argc, char ** argv) {
|
||||
cparams.flash_attn = params.flash_attn;
|
||||
|
||||
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
|
||||
if (ctx == nullptr) {
|
||||
fprintf(stderr, "error: failed to initialize whisper context\n");
|
||||
return 2;
|
||||
}
|
||||
|
||||
std::vector<float> pcmf32 (n_samples_30s, 0.0f);
|
||||
std::vector<float> pcmf32_old;
|
||||
|
@ -16,14 +16,10 @@ if (WHISPER_SDL2)
|
||||
llama-hparams.cpp
|
||||
llama-impl.cpp
|
||||
llama-io.cpp
|
||||
llama-kv-cache-unified.cpp
|
||||
llama-kv-cache-unified-iswa.cpp
|
||||
llama-memory-recurrent.cpp
|
||||
llama-memory-hybrid.cpp
|
||||
llama-kv-cache.cpp
|
||||
llama-memory.cpp
|
||||
llama-mmap.cpp
|
||||
llama-model-loader.cpp
|
||||
llama-model-saver.cpp
|
||||
llama-model.cpp
|
||||
llama-quant.cpp
|
||||
llama-sampling.cpp
|
||||
|
@ -253,9 +253,6 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
|
||||
std::vector<ggml_backend_buffer_type_t> buft_extra;
|
||||
{
|
||||
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
||||
if (!cpu_dev) {
|
||||
throw std::runtime_error(format("%s: no CPU backend found", __func__));
|
||||
}
|
||||
auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
|
||||
|
||||
auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
|
||||
@ -294,9 +291,6 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
|
||||
LLAMA_LOG_WARN("%s: lora for '%s' cannot use buft '%s', fallback to CPU\n", __func__, model_tensor->name, ggml_backend_buft_name(buft));
|
||||
|
||||
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
||||
if (!cpu_dev) {
|
||||
throw std::runtime_error(format("%s: no CPU backend found", __func__));
|
||||
}
|
||||
buft = ggml_backend_dev_buffer_type(cpu_dev);
|
||||
|
||||
break;
|
||||
|
@ -20,7 +20,6 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_BERT, "bert" },
|
||||
{ LLM_ARCH_NOMIC_BERT, "nomic-bert" },
|
||||
{ LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" },
|
||||
{ LLM_ARCH_NEO_BERT, "neo-bert" },
|
||||
{ LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" },
|
||||
{ LLM_ARCH_BLOOM, "bloom" },
|
||||
{ LLM_ARCH_STABLELM, "stablelm" },
|
||||
@ -42,7 +41,6 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_GEMMA, "gemma" },
|
||||
{ LLM_ARCH_GEMMA2, "gemma2" },
|
||||
{ LLM_ARCH_GEMMA3, "gemma3" },
|
||||
{ LLM_ARCH_GEMMA3N, "gemma3n" },
|
||||
{ LLM_ARCH_STARCODER2, "starcoder2" },
|
||||
{ LLM_ARCH_MAMBA, "mamba" },
|
||||
{ LLM_ARCH_XVERSE, "xverse" },
|
||||
@ -74,9 +72,6 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
|
||||
{ LLM_ARCH_PLM, "plm" },
|
||||
{ LLM_ARCH_BAILINGMOE, "bailingmoe" },
|
||||
{ LLM_ARCH_DOTS1, "dots1" },
|
||||
{ LLM_ARCH_ARCEE, "arcee" },
|
||||
{ LLM_ARCH_ERNIE4_5, "ernie4_5" },
|
||||
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
||||
};
|
||||
|
||||
@ -149,7 +144,6 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
|
||||
{ LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
|
||||
{ LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
|
||||
{ LLM_KV_ATTENTION_LAYER_INDICES, "%s.attention.layer_indices" },
|
||||
|
||||
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
|
||||
{ LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
|
||||
@ -180,8 +174,6 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_CONVNEXT_EMBEDDING_LENGTH, "%s.convnext.embedding_length" },
|
||||
{ LLM_KV_CONVNEXT_BLOCK_COUNT, "%s.convnext.block_count" },
|
||||
|
||||
{ LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" },
|
||||
|
||||
{ LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
|
||||
{ LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" },
|
||||
{ LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" },
|
||||
@ -200,13 +192,13 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_TOKENIZER_MASK_ID, "tokenizer.ggml.mask_token_id" },
|
||||
{ LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" },
|
||||
{ LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" },
|
||||
{ LLM_KV_TOKENIZER_ADD_SEP, "tokenizer.ggml.add_sep_token" },
|
||||
{ LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" },
|
||||
{ LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, "tokenizer.ggml.remove_extra_whitespaces" },
|
||||
{ LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" },
|
||||
{ LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" },
|
||||
{ LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" },
|
||||
{ LLM_KV_TOKENIZER_CHAT_TEMPLATE, "tokenizer.chat_template" },
|
||||
{ LLM_KV_TOKENIZER_CHAT_TEMPLATE_N, "tokenizer.chat_template.%s" },
|
||||
{ LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" },
|
||||
{ LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" },
|
||||
{ LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" },
|
||||
@ -250,24 +242,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_ARCEE,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_LLAMA4,
|
||||
{
|
||||
@ -474,7 +448,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_TOKEN_TYPES, "token_types" },
|
||||
{ LLM_TENSOR_POS_EMBD, "position_embd" },
|
||||
{ LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
|
||||
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
@ -519,21 +492,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_NEO_BERT,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" },
|
||||
{ LLM_TENSOR_CLS, "cls" },
|
||||
{ LLM_TENSOR_CLS_OUT, "cls.output" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_JINA_BERT_V2,
|
||||
{
|
||||
@ -934,42 +892,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_GEMMA3N,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
|
||||
{ LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "per_layer_token_embd" },
|
||||
{ LLM_TENSOR_PER_LAYER_MODEL_PROJ, "per_layer_model_proj" },
|
||||
{ LLM_TENSOR_PER_LAYER_PROJ_NORM, "per_layer_proj_norm" },
|
||||
{ LLM_TENSOR_ALTUP_UNEMBD_PROJ, "altup_unembd_proj" },
|
||||
{ LLM_TENSOR_ALTUP_PROJ, "altup_proj" },
|
||||
{ LLM_TENSOR_PER_LAYER_INP_GATE, "blk.%d.inp_gate" },
|
||||
{ LLM_TENSOR_PER_LAYER_PROJ, "blk.%d.proj" },
|
||||
{ LLM_TENSOR_PER_LAYER_POST_NORM, "blk.%d.post_norm" },
|
||||
{ LLM_TENSOR_ALTUP_CORRECT_COEF, "blk.%d.altup_correct_coef" },
|
||||
{ LLM_TENSOR_ALTUP_CORRECT_SCALE, "blk.%d.altup_correct_scale" },
|
||||
{ LLM_TENSOR_ALTUP_PREDICT_COEF, "blk.%d.altup_predict_coef" },
|
||||
{ LLM_TENSOR_ALTUP_ROUTER, "blk.%d.altup_router" },
|
||||
{ LLM_TENSOR_ALTUP_ROUTER_NORM, "blk.%d.altup_router_norm" },
|
||||
{ LLM_TENSOR_LAUREL_L, "blk.%d.laurel_l" },
|
||||
{ LLM_TENSOR_LAUREL_R, "blk.%d.laurel_r" },
|
||||
{ LLM_TENSOR_LAUREL_POST_NORM, "blk.%d.laurel_post_norm" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_STARCODER2,
|
||||
{
|
||||
@ -1559,9 +1481,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
|
||||
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
|
||||
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
||||
},
|
||||
},
|
||||
{
|
||||
@ -1631,51 +1550,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_DOTS1,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
|
||||
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
|
||||
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
|
||||
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
||||
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
|
||||
}
|
||||
},
|
||||
{
|
||||
LLM_ARCH_ERNIE4_5,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_UNKNOWN,
|
||||
{
|
||||
@ -1804,23 +1678,6 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
{LLM_TENSOR_FFN_GATE_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
|
||||
{LLM_TENSOR_FFN_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
|
||||
{LLM_TENSOR_FFN_EXP_PROBS_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
// altup / laurel (gemma 3n)
|
||||
{LLM_TENSOR_PER_LAYER_TOKEN_EMBD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
|
||||
{LLM_TENSOR_PER_LAYER_MODEL_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_PER_LAYER_PROJ_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_ALTUP_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ALTUP_UNEMBD_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_PER_LAYER_INP_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_PER_LAYER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_PER_LAYER_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_ALTUP_CORRECT_COEF, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ALTUP_CORRECT_SCALE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_ALTUP_PREDICT_COEF, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ALTUP_ROUTER, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ALTUP_ROUTER_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_LAUREL_L, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_LAUREL_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_LAUREL_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
// this tensor is loaded for T5, but never used
|
||||
{LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
|
||||
{LLM_TENSOR_CONV1D, {LLM_TENSOR_LAYER_INPUT, GGML_OP_IM2COL}},
|
||||
@ -1844,14 +1701,8 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}
|
||||
|
||||
std::string LLM_KV::operator()(llm_kv kv) const {
|
||||
std::string name = ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
|
||||
|
||||
if (suffix != nullptr) {
|
||||
name += ".";
|
||||
name += suffix;
|
||||
}
|
||||
|
||||
return name;
|
||||
return suffix ? ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch), suffix)
|
||||
: ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
|
||||
}
|
||||
|
||||
std::string LLM_TN_IMPL::str() const {
|
||||
@ -1890,25 +1741,3 @@ llm_arch llm_arch_from_string(const std::string & name) {
|
||||
const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) {
|
||||
return LLM_TENSOR_INFOS.at(tensor);
|
||||
}
|
||||
|
||||
bool llm_arch_is_recurrent(const llm_arch & arch) {
|
||||
switch (arch) {
|
||||
case LLM_ARCH_MAMBA:
|
||||
case LLM_ARCH_RWKV6:
|
||||
case LLM_ARCH_RWKV6QWEN2:
|
||||
case LLM_ARCH_RWKV7:
|
||||
case LLM_ARCH_ARWKV7:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool llm_arch_is_hybrid(const llm_arch & arch) {
|
||||
// TODO: There are currently no hybrid models! Once there are, this will be
|
||||
// the place to identify them
|
||||
switch (arch) {
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -24,7 +24,6 @@ enum llm_arch {
|
||||
LLM_ARCH_BERT,
|
||||
LLM_ARCH_NOMIC_BERT,
|
||||
LLM_ARCH_NOMIC_BERT_MOE,
|
||||
LLM_ARCH_NEO_BERT,
|
||||
LLM_ARCH_JINA_BERT_V2,
|
||||
LLM_ARCH_BLOOM,
|
||||
LLM_ARCH_STABLELM,
|
||||
@ -46,7 +45,6 @@ enum llm_arch {
|
||||
LLM_ARCH_GEMMA,
|
||||
LLM_ARCH_GEMMA2,
|
||||
LLM_ARCH_GEMMA3,
|
||||
LLM_ARCH_GEMMA3N,
|
||||
LLM_ARCH_STARCODER2,
|
||||
LLM_ARCH_MAMBA,
|
||||
LLM_ARCH_XVERSE,
|
||||
@ -78,9 +76,6 @@ enum llm_arch {
|
||||
LLM_ARCH_WAVTOKENIZER_DEC,
|
||||
LLM_ARCH_PLM,
|
||||
LLM_ARCH_BAILINGMOE,
|
||||
LLM_ARCH_DOTS1,
|
||||
LLM_ARCH_ARCEE,
|
||||
LLM_ARCH_ERNIE4_5,
|
||||
LLM_ARCH_UNKNOWN,
|
||||
};
|
||||
|
||||
@ -153,7 +148,6 @@ enum llm_kv {
|
||||
LLM_KV_ATTENTION_SCALE,
|
||||
LLM_KV_ATTENTION_KEY_LENGTH_MLA,
|
||||
LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
|
||||
LLM_KV_ATTENTION_LAYER_INDICES,
|
||||
|
||||
LLM_KV_ROPE_DIMENSION_COUNT,
|
||||
LLM_KV_ROPE_DIMENSION_SECTIONS,
|
||||
@ -196,13 +190,13 @@ enum llm_kv {
|
||||
LLM_KV_TOKENIZER_MASK_ID,
|
||||
LLM_KV_TOKENIZER_ADD_BOS,
|
||||
LLM_KV_TOKENIZER_ADD_EOS,
|
||||
LLM_KV_TOKENIZER_ADD_SEP,
|
||||
LLM_KV_TOKENIZER_ADD_PREFIX,
|
||||
LLM_KV_TOKENIZER_REMOVE_EXTRA_WS,
|
||||
LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP,
|
||||
LLM_KV_TOKENIZER_HF_JSON,
|
||||
LLM_KV_TOKENIZER_RWKV,
|
||||
LLM_KV_TOKENIZER_CHAT_TEMPLATE,
|
||||
LLM_KV_TOKENIZER_CHAT_TEMPLATE_N,
|
||||
LLM_KV_TOKENIZER_FIM_PRE_ID,
|
||||
LLM_KV_TOKENIZER_FIM_SUF_ID,
|
||||
LLM_KV_TOKENIZER_FIM_MID_ID,
|
||||
@ -219,8 +213,6 @@ enum llm_kv {
|
||||
LLM_KV_CONVNEXT_EMBEDDING_LENGTH,
|
||||
LLM_KV_CONVNEXT_BLOCK_COUNT,
|
||||
|
||||
LLM_KV_CLASSIFIER_OUTPUT_LABELS,
|
||||
|
||||
// deprecated:
|
||||
LLM_KV_TOKENIZER_PREFIX_ID,
|
||||
LLM_KV_TOKENIZER_SUFFIX_ID,
|
||||
@ -271,22 +263,6 @@ enum llm_tensor {
|
||||
LLM_TENSOR_LAYER_OUT_NORM,
|
||||
LLM_TENSOR_POST_ATTN_NORM,
|
||||
LLM_TENSOR_POST_MLP_NORM,
|
||||
LLM_TENSOR_PER_LAYER_TOKEN_EMBD, // gemma3n
|
||||
LLM_TENSOR_PER_LAYER_MODEL_PROJ, // gemma3n
|
||||
LLM_TENSOR_PER_LAYER_INP_GATE, // gemma3n
|
||||
LLM_TENSOR_PER_LAYER_PROJ, // gemma3n
|
||||
LLM_TENSOR_PER_LAYER_PROJ_NORM, // gemma3n
|
||||
LLM_TENSOR_PER_LAYER_POST_NORM, // gemma3n
|
||||
LLM_TENSOR_ALTUP_PROJ, // gemma3n
|
||||
LLM_TENSOR_ALTUP_UNEMBD_PROJ, // gemma3n
|
||||
LLM_TENSOR_ALTUP_CORRECT_COEF, // gemma3n
|
||||
LLM_TENSOR_ALTUP_CORRECT_SCALE, // gemma3n
|
||||
LLM_TENSOR_ALTUP_PREDICT_COEF, // gemma3n
|
||||
LLM_TENSOR_ALTUP_ROUTER, // gemma3n
|
||||
LLM_TENSOR_ALTUP_ROUTER_NORM, // gemma3n
|
||||
LLM_TENSOR_LAUREL_L, // gemma3n
|
||||
LLM_TENSOR_LAUREL_R, // gemma3n
|
||||
LLM_TENSOR_LAUREL_POST_NORM, // gemma3n
|
||||
LLM_TENSOR_SSM_IN,
|
||||
LLM_TENSOR_SSM_CONV1D,
|
||||
LLM_TENSOR_SSM_X,
|
||||
@ -459,6 +435,3 @@ const char * llm_arch_name(llm_arch arch);
|
||||
llm_arch llm_arch_from_string(const std::string & name);
|
||||
|
||||
const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor);
|
||||
|
||||
bool llm_arch_is_recurrent(const llm_arch & arch);
|
||||
bool llm_arch_is_hybrid (const llm_arch & arch);
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -2,146 +2,87 @@
|
||||
|
||||
#include "llama.h"
|
||||
|
||||
#include "llama-cparams.h"
|
||||
|
||||
#include <array>
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <bitset>
|
||||
#include <unordered_map>
|
||||
|
||||
// keep this struct lightweight
|
||||
// it points to data in `llama_batch_allocr`
|
||||
// very similar to llama_batch,
|
||||
// but has more metadata about sequences
|
||||
struct llama_ubatch {
|
||||
bool equal_seqs;
|
||||
// TODO: whole_seqs for embeddings?
|
||||
|
||||
uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
|
||||
uint32_t n_seq_tokens; // tokens per sequence set
|
||||
uint32_t n_seqs; // sequence sets in the ubatch
|
||||
uint32_t n_seqs_unq; // unique sequence ids in the ubatch
|
||||
uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
|
||||
uint32_t n_seq_tokens; // tokens per sequence
|
||||
uint32_t n_seqs;
|
||||
|
||||
// seq_id_unq: unique sequence ids in the ubatch
|
||||
// seq_idx: indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
|
||||
// used for extracting sequence pooled embeddings
|
||||
|
||||
// // size | idx | val
|
||||
llama_token * token; // [n_tokens] | i | id, token
|
||||
float * embd; // [n_embd, n_tokens] | i | embd
|
||||
llama_pos * pos; // [n_tokens] | i | pos
|
||||
int32_t * n_seq_id; // [n_tokens] | i | -
|
||||
llama_seq_id ** seq_id; // [n_tokens] | s | s0, s1, seq_id
|
||||
llama_seq_id * seq_id_unq; // [n_seqs_unq] | s | seq_id
|
||||
int32_t * seq_idx; // [LLAMA_MAX_SEQ] | - | seq_idx
|
||||
int8_t * output; // [n_tokens] | i | -
|
||||
llama_token * token; // [n_tokens]
|
||||
float * embd; // [n_embd, n_tokens]
|
||||
llama_pos * pos; // [n_tokens]
|
||||
int32_t * n_seq_id; // [n_seqs]
|
||||
llama_seq_id ** seq_id; // [n_seqs]
|
||||
int8_t * output; // [n_tokens]
|
||||
};
|
||||
|
||||
// a helper for sanitizing, fulfilling and splitting a batch
|
||||
class llama_batch_allocr {
|
||||
public:
|
||||
llama_batch_allocr(uint32_t n_pos_per_embd);
|
||||
struct llama_sbatch_seq {
|
||||
int32_t n_seq_id;
|
||||
|
||||
// sanitize and auto-gen missing data in the input batch
|
||||
// memory is optional. if provided will be used to check for sequence continuity and to determine the positions
|
||||
bool init(
|
||||
const llama_batch & batch_inp,
|
||||
const llama_vocab & vocab,
|
||||
const llama_memory_i * memory,
|
||||
uint32_t n_embd,
|
||||
bool output_all);
|
||||
llama_seq_id * seq_id;
|
||||
|
||||
const llama_batch & get_batch() const;
|
||||
size_t offset;
|
||||
size_t length;
|
||||
};
|
||||
|
||||
uint32_t get_n_tokens() const;
|
||||
uint32_t get_n_outputs() const;
|
||||
// sequence-length-aware batch splitting
|
||||
struct llama_sbatch {
|
||||
// tokens left in this batch
|
||||
size_t n_tokens;
|
||||
|
||||
// the array of output indices in the order they were encountered during the ubatch splitting
|
||||
std::vector<int32_t> & get_out_ids();
|
||||
size_t n_embd;
|
||||
|
||||
// min/max positions of each sequence in the current ubatch
|
||||
llama_pos seq_pos_min(llama_seq_id seq_id) const;
|
||||
llama_pos seq_pos_max(llama_seq_id seq_id) const;
|
||||
bool logits_all; // TODO: remove once lctx.logits_all is removed too
|
||||
|
||||
// call once before splitting the batch to reset the internal state
|
||||
void split_reset();
|
||||
// sorted indices into the batch
|
||||
std::vector<int64_t> ids;
|
||||
// batch indices of the output
|
||||
std::vector<int64_t> out_ids;
|
||||
std::vector<llama_sbatch_seq> seq;
|
||||
|
||||
// simple split, unknown number of sequence sets of unequal lengths
|
||||
llama_ubatch split_simple(uint32_t n_ubatch);
|
||||
const llama_batch * batch = nullptr;
|
||||
|
||||
// make ubatches of equal-length sequences sets
|
||||
llama_ubatch split_equal(uint32_t n_ubatch);
|
||||
// buffers for the ubatch
|
||||
std::vector<llama_token> ubatch_token;
|
||||
std::vector<float> ubatch_embd;
|
||||
std::vector<llama_pos> ubatch_pos;
|
||||
std::vector<int32_t> ubatch_n_seq_id;
|
||||
std::vector<llama_seq_id *> ubatch_seq_id;
|
||||
std::vector<int8_t> ubatch_output;
|
||||
|
||||
// sequence-set-wise split - each ubatch contains a single sequence-set
|
||||
llama_ubatch split_seq(uint32_t n_ubatch);
|
||||
llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false);
|
||||
|
||||
// a helper method for creating a well-defined ubatch of tokens
|
||||
// TODO: support embeddings if needed in the future
|
||||
llama_ubatch ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs);
|
||||
void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length);
|
||||
|
||||
private:
|
||||
void clear();
|
||||
// simple split, unknown number of sequences of unequal lengths
|
||||
llama_ubatch split_simple(size_t n_ubatch);
|
||||
|
||||
// create the next ubatch based on the provided batch indices (idxs) and the number of sequence sets (n_seqs)
|
||||
// return llama_ubatch.n_tokens == 0 if the entire batch was consumed
|
||||
llama_ubatch ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs);
|
||||
// make batches of equal-length sequences
|
||||
llama_ubatch split_equal(size_t n_ubatch);
|
||||
|
||||
// for debugging, start with LLAMA_BATCH_DEBUG=2
|
||||
void ubatch_print(const llama_ubatch & ubatch, int debug);
|
||||
// sequence-wise split
|
||||
llama_ubatch split_seq(size_t n_ubatch);
|
||||
|
||||
llama_batch batch;
|
||||
void from_batch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
|
||||
};
|
||||
|
||||
// only for debugging purposes
|
||||
const llama_vocab * vocab;
|
||||
|
||||
// TODO: this is more of a temporary solution until we have a better way to handle multiple positions per token/embd
|
||||
// ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
|
||||
const uint32_t n_pos_per_embd;
|
||||
|
||||
uint32_t n_embd;
|
||||
uint32_t n_outputs;
|
||||
// temporary allocate memory for the input batch if needed
|
||||
struct llama_batch_allocr {
|
||||
struct llama_batch batch;
|
||||
|
||||
std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
|
||||
|
||||
std::vector<llama_pos> pos;
|
||||
std::vector<int32_t> n_seq_id;
|
||||
std::vector<llama_seq_id *> seq_id;
|
||||
std::vector<llama_seq_id> seq_id_unq;
|
||||
std::vector<int32_t> seq_idx;
|
||||
std::vector<int8_t> output;
|
||||
std::vector<int8_t> logits;
|
||||
|
||||
using pos_set_t = std::set<llama_pos>;
|
||||
using seq_cpl_t = std::vector<bool>;
|
||||
|
||||
std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
|
||||
std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
|
||||
|
||||
using idx_vec_t = std::vector<int32_t>;
|
||||
using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
|
||||
|
||||
std::vector<seq_set_t> seq_set; // seq_set[i]: the sequence set of token i
|
||||
|
||||
std::unordered_map<seq_set_t, idx_vec_t> seq_set_map; // the indices at which the sequence set appears
|
||||
|
||||
// batch indices of the output
|
||||
std::vector<int32_t> out_ids;
|
||||
|
||||
// used[i] indicates if token i has already been used in a previous ubatch
|
||||
std::vector<bool> used;
|
||||
|
||||
// llama_ubatch points to this data:
|
||||
struct ubatch {
|
||||
std::vector<llama_token> token;
|
||||
std::vector<float> embd;
|
||||
std::vector<llama_pos> pos;
|
||||
std::vector<int32_t> n_seq_id;
|
||||
std::vector<llama_seq_id *> seq_id;
|
||||
std::vector<llama_seq_id> seq_id_unq;
|
||||
std::vector<int32_t> seq_idx;
|
||||
std::vector<int8_t> output;
|
||||
};
|
||||
|
||||
// current splitting state:
|
||||
std::vector<ubatch> ubatches;
|
||||
|
||||
int debug;
|
||||
// optionally fulfill the batch returned by llama_batch_get_one
|
||||
llama_batch_allocr(struct llama_batch in_batch, llama_pos p0);
|
||||
};
|
||||
|
@ -35,7 +35,6 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
|
||||
{ "mistral-v3", LLM_CHAT_TEMPLATE_MISTRAL_V3 },
|
||||
{ "mistral-v3-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN },
|
||||
{ "mistral-v7", LLM_CHAT_TEMPLATE_MISTRAL_V7 },
|
||||
{ "mistral-v7-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN },
|
||||
{ "phi3", LLM_CHAT_TEMPLATE_PHI_3 },
|
||||
{ "phi4", LLM_CHAT_TEMPLATE_PHI_4 },
|
||||
{ "falcon3", LLM_CHAT_TEMPLATE_FALCON_3 },
|
||||
@ -183,8 +182,6 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
|
||||
return LLM_CHAT_TEMPLATE_BAILING;
|
||||
} else if (tmpl_contains("<|header_start|>") && tmpl_contains("<|header_end|>")) {
|
||||
return LLM_CHAT_TEMPLATE_LLAMA4;
|
||||
} else if (tmpl_contains("<|endofuserprompt|>")) {
|
||||
return LLM_CHAT_TEMPLATE_DOTS1;
|
||||
}
|
||||
return LLM_CHAT_TEMPLATE_UNKNOWN;
|
||||
}
|
||||
@ -205,20 +202,19 @@ int32_t llm_chat_apply_template(
|
||||
if (add_ass) {
|
||||
ss << "<|im_start|>assistant\n";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7 || tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN) {
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7) {
|
||||
// Official mistral 'v7' template
|
||||
// See: https://huggingface.co/mistralai/Mistral-Large-Instruct-2411#basic-instruct-template-v7
|
||||
// https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503#basic-instruct-template-v7-tekken
|
||||
const char * trailing_space = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7 ? " " : "";
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
std::string content(message->content);
|
||||
if (role == "system") {
|
||||
ss << "[SYSTEM_PROMPT]" << trailing_space << content << "[/SYSTEM_PROMPT]";
|
||||
ss << "[SYSTEM_PROMPT] " << content << "[/SYSTEM_PROMPT]";
|
||||
} else if (role == "user") {
|
||||
ss << "[INST]" << trailing_space << content << "[/INST]";
|
||||
} else {
|
||||
ss << trailing_space << content << "</s>";
|
||||
ss << "[INST] " << content << "[/INST]";
|
||||
}
|
||||
else {
|
||||
ss << " " << content << "</s>";
|
||||
}
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V1
|
||||
@ -333,7 +329,7 @@ int32_t llm_chat_apply_template(
|
||||
std::string role(message->role);
|
||||
if (role == "system") {
|
||||
// there is no system message for gemma, but we will merge it with user prompt, so nothing is broken
|
||||
system_prompt += trim(message->content);
|
||||
system_prompt = trim(message->content);
|
||||
continue;
|
||||
}
|
||||
// in gemma, "assistant" is "model"
|
||||
@ -355,7 +351,7 @@ int32_t llm_chat_apply_template(
|
||||
std::string role(message->role);
|
||||
if (role == "system") {
|
||||
// there is no system message support, we will merge it with user prompt
|
||||
system_prompt += message->content;
|
||||
system_prompt = message->content;
|
||||
continue;
|
||||
} else if (role == "user") {
|
||||
ss << "Human: ";
|
||||
@ -451,16 +447,8 @@ int32_t llm_chat_apply_template(
|
||||
if (add_ass) {
|
||||
ss << "<|assistant|>";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_CHATGLM_4) {
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_CHATGLM_4 || tmpl == LLM_CHAT_TEMPLATE_GLMEDGE) {
|
||||
ss << "[gMASK]" << "<sop>";
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
ss << "<|" << role << "|>" << "\n" << message->content;
|
||||
}
|
||||
if (add_ass) {
|
||||
ss << "<|assistant|>\n";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_GLMEDGE) {
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
ss << "<|" << role << "|>" << "\n" << message->content;
|
||||
@ -528,17 +516,12 @@ int32_t llm_chat_apply_template(
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) {
|
||||
// this template requires the model to have "\n\n" as EOT token
|
||||
for (size_t i = 0; i < chat.size(); i++) {
|
||||
std::string role(chat[i]->role);
|
||||
if (role == "system") {
|
||||
ss << "System: " << trim(chat[i]->content) << "\n\n";
|
||||
} else if (role == "user") {
|
||||
ss << "User: " << trim(chat[i]->content) << "\n\n";
|
||||
if (i == chat.size() - 1) {
|
||||
ss << "Assistant:";
|
||||
}
|
||||
} else if (role == "assistant") {
|
||||
ss << "Assistant: " << trim(chat[i]->content) << "\n\n";
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
if (role == "user") {
|
||||
ss << "User: " << message->content << "\n\nAssistant:";
|
||||
} else {
|
||||
ss << message->content << "\n\n";
|
||||
}
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE) {
|
||||
@ -650,21 +633,6 @@ int32_t llm_chat_apply_template(
|
||||
if (add_ass) {
|
||||
ss << "Assistant:";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_DOTS1) {
|
||||
// dots.llm1.inst (DOTS1)
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
if (role == "system") {
|
||||
ss << "<|system|>" << message->content << "<|endofsystem|>";
|
||||
} else if (role == "user") {
|
||||
ss << "<|userprompt|>" << message->content << "<|endofuserprompt|>";
|
||||
} else {
|
||||
ss << "<|response|>" << message->content << "<|endofresponse|>";
|
||||
}
|
||||
}
|
||||
if (add_ass) {
|
||||
ss << "<|response|>";
|
||||
}
|
||||
} else {
|
||||
// template not supported
|
||||
return -1;
|
||||
|
@ -14,7 +14,6 @@ enum llm_chat_template {
|
||||
LLM_CHAT_TEMPLATE_MISTRAL_V3,
|
||||
LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN,
|
||||
LLM_CHAT_TEMPLATE_MISTRAL_V7,
|
||||
LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN,
|
||||
LLM_CHAT_TEMPLATE_PHI_3,
|
||||
LLM_CHAT_TEMPLATE_PHI_4,
|
||||
LLM_CHAT_TEMPLATE_FALCON_3,
|
||||
@ -43,7 +42,6 @@ enum llm_chat_template {
|
||||
LLM_CHAT_TEMPLATE_BAILING,
|
||||
LLM_CHAT_TEMPLATE_LLAMA4,
|
||||
LLM_CHAT_TEMPLATE_SMOLVLM,
|
||||
LLM_CHAT_TEMPLATE_DOTS1,
|
||||
LLM_CHAT_TEMPLATE_UNKNOWN,
|
||||
};
|
||||
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -1,25 +1,22 @@
|
||||
#pragma once
|
||||
|
||||
#include "llama.h"
|
||||
#include "llama-batch.h"
|
||||
#include "llama-cparams.h"
|
||||
#include "llama-graph.h"
|
||||
#include "llama-adapter.h"
|
||||
|
||||
#include "ggml-cpp.h"
|
||||
#include "ggml-opt.h"
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
|
||||
struct llama_model;
|
||||
class llama_batch_allocr;
|
||||
struct llama_kv_cache;
|
||||
|
||||
class llama_io_read_i;
|
||||
class llama_io_write_i;
|
||||
|
||||
struct llama_memory_i;
|
||||
struct llama_memory_context_i;
|
||||
|
||||
struct llama_context {
|
||||
// init scheduler and compute buffers, reserve worst-case graphs
|
||||
llama_context(
|
||||
@ -30,12 +27,7 @@ struct llama_context {
|
||||
|
||||
void synchronize();
|
||||
|
||||
const llama_model & get_model() const;
|
||||
const llama_cparams & get_cparams() const;
|
||||
|
||||
ggml_backend_sched_t get_sched() const;
|
||||
|
||||
ggml_context * get_ctx_compute() const;
|
||||
const llama_model & get_model() const;
|
||||
|
||||
uint32_t n_ctx() const;
|
||||
uint32_t n_ctx_per_seq() const;
|
||||
@ -46,12 +38,10 @@ struct llama_context {
|
||||
uint32_t n_threads() const;
|
||||
uint32_t n_threads_batch() const;
|
||||
|
||||
llama_memory_t get_memory() const;
|
||||
llama_kv_cache * get_kv_self();
|
||||
const llama_kv_cache * get_kv_self() const;
|
||||
|
||||
// return true of the KV cache was updated
|
||||
// TODO: remove
|
||||
bool kv_self_update(bool optimize);
|
||||
void kv_self_defrag_sched();
|
||||
void kv_self_update();
|
||||
|
||||
enum llama_pooling_type pooling_type() const;
|
||||
|
||||
@ -92,18 +82,8 @@ struct llama_context {
|
||||
int32_t il_start,
|
||||
int32_t il_end);
|
||||
|
||||
// process a single ubatch with a specific graph type
|
||||
// if memory_context is provided, it will be applied first to the context's memory
|
||||
// ret contains the status of the graph computation
|
||||
// returns nullptr only if ret != GGML_STATUS_SUCCESS
|
||||
llm_graph_result_ptr process_ubatch(
|
||||
const llama_ubatch & ubatch,
|
||||
llm_graph_type gtype,
|
||||
llama_memory_context_i * mctx,
|
||||
ggml_status & ret);
|
||||
|
||||
int encode(const llama_batch & batch_inp);
|
||||
int decode(const llama_batch & batch_inp);
|
||||
int encode(llama_batch & inp_batch);
|
||||
int decode(llama_batch & inp_batch);
|
||||
|
||||
//
|
||||
// state save/load
|
||||
@ -148,32 +128,6 @@ struct llama_context {
|
||||
llama_perf_context_data perf_get_data() const;
|
||||
void perf_reset();
|
||||
|
||||
//
|
||||
// training
|
||||
//
|
||||
|
||||
void opt_init(struct llama_model * model, struct llama_opt_params lopt_params);
|
||||
|
||||
void opt_epoch(
|
||||
ggml_opt_dataset_t dataset,
|
||||
ggml_opt_result_t result_train,
|
||||
ggml_opt_result_t result_eval,
|
||||
int64_t idata_split,
|
||||
ggml_opt_epoch_callback callback_train,
|
||||
ggml_opt_epoch_callback callback_eval);
|
||||
|
||||
void opt_epoch_iter(
|
||||
ggml_opt_dataset_t dataset,
|
||||
ggml_opt_result_t result,
|
||||
const std::vector<llama_token> & tokens,
|
||||
const std::vector<llama_token> & labels_sparse,
|
||||
llama_batch & batch,
|
||||
ggml_opt_epoch_callback callback,
|
||||
bool train,
|
||||
int64_t idata_in_loop,
|
||||
int64_t ndata_in_loop,
|
||||
int64_t t_loop_start);
|
||||
|
||||
private:
|
||||
//
|
||||
// output
|
||||
@ -181,34 +135,51 @@ private:
|
||||
|
||||
// Make sure enough space is available for outputs.
|
||||
// Returns max number of outputs for which space was reserved.
|
||||
uint32_t output_reserve(int32_t n_outputs);
|
||||
int32_t output_reserve(int32_t n_outputs);
|
||||
|
||||
// make the outputs have the same order they had in the user-provided batch
|
||||
// TODO: maybe remove this
|
||||
void output_reorder();
|
||||
|
||||
//
|
||||
// graph
|
||||
//
|
||||
|
||||
public:
|
||||
int32_t graph_max_nodes() const;
|
||||
|
||||
// zero-out inputs and create the ctx_compute for the compute graph
|
||||
ggml_cgraph * graph_init();
|
||||
|
||||
// returns the result of ggml_backend_sched_graph_compute_async execution
|
||||
ggml_status graph_compute(ggml_cgraph * gf, bool batched);
|
||||
|
||||
// reserve a graph with a dummy ubatch of the specified size
|
||||
ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
|
||||
|
||||
private:
|
||||
llm_graph_result_ptr graph_build(
|
||||
ggml_context * ctx,
|
||||
ggml_cgraph * gf,
|
||||
const llama_ubatch & ubatch,
|
||||
llm_graph_type gtype,
|
||||
const llama_memory_context_i * mctx);
|
||||
ggml_context * ctx,
|
||||
ggml_cgraph * gf,
|
||||
const llama_ubatch & ubatch,
|
||||
llm_graph_type gtype);
|
||||
|
||||
// returns the result of ggml_backend_sched_graph_compute_async execution
|
||||
ggml_status graph_compute(
|
||||
ggml_cgraph * gf,
|
||||
bool batched);
|
||||
|
||||
llm_graph_cb graph_get_cb() const;
|
||||
|
||||
// used by kv_self_update()
|
||||
ggml_tensor * build_rope_shift(
|
||||
ggml_context * ctx0,
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * shift,
|
||||
ggml_tensor * factors,
|
||||
float freq_base,
|
||||
float freq_scale) const;
|
||||
|
||||
llm_graph_result_ptr build_kv_self_shift(
|
||||
ggml_context * ctx0,
|
||||
ggml_cgraph * gf) const;
|
||||
|
||||
llm_graph_result_ptr build_kv_self_defrag(
|
||||
ggml_context * ctx0,
|
||||
ggml_cgraph * gf) const;
|
||||
|
||||
// TODO: read/write lora adapters and cvec
|
||||
size_t state_write_data(llama_io_write_i & io);
|
||||
size_t state_read_data (llama_io_read_i & io);
|
||||
@ -225,13 +196,14 @@ private:
|
||||
llama_cparams cparams;
|
||||
llama_adapter_cvec cvec;
|
||||
llama_adapter_loras loras;
|
||||
llama_sbatch sbatch;
|
||||
|
||||
llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably
|
||||
|
||||
std::unique_ptr<llama_memory_i> memory;
|
||||
std::unique_ptr<llama_kv_cache_unified> kv_self;
|
||||
|
||||
// TODO: temporary, until the llama_kv_self_defrag() API is removed
|
||||
bool memory_force_optimize = false;
|
||||
// TODO: remove
|
||||
bool logits_all = false;
|
||||
|
||||
// decode output (2-dimensional array: [n_outputs][n_vocab])
|
||||
size_t logits_size = 0; // capacity (of floats) for logits
|
||||
@ -246,10 +218,8 @@ private:
|
||||
// populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
|
||||
std::map<llama_seq_id, std::vector<float>> embd_seq;
|
||||
|
||||
// reuse the batch_allocr to avoid unnecessary memory allocations
|
||||
std::unique_ptr<llama_batch_allocr> balloc;
|
||||
|
||||
uint32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
|
||||
int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
|
||||
int32_t n_outputs_max = 0; // capacity (of tokens positions) for the output buffers
|
||||
|
||||
std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
|
||||
|
||||
@ -260,9 +230,6 @@ private:
|
||||
|
||||
ggml_context_ptr ctx_compute;
|
||||
|
||||
// training
|
||||
ggml_opt_context_t opt_ctx = nullptr;
|
||||
|
||||
ggml_threadpool_t threadpool = nullptr;
|
||||
ggml_threadpool_t threadpool_batch = nullptr;
|
||||
|
||||
|
@ -1,5 +1 @@
|
||||
#include "llama-cparams.h"
|
||||
|
||||
size_t llama_max_parallel_sequences(void) {
|
||||
return LLAMA_MAX_SEQ;
|
||||
}
|
||||
|
@ -4,8 +4,6 @@
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#define LLAMA_MAX_SEQ 64
|
||||
|
||||
struct llama_cparams {
|
||||
uint32_t n_ctx; // context size used during inference
|
||||
uint32_t n_batch;
|
||||
@ -32,7 +30,6 @@ struct llama_cparams {
|
||||
bool flash_attn;
|
||||
bool no_perf;
|
||||
bool warmup;
|
||||
bool op_offload;
|
||||
|
||||
enum llama_pooling_type pooling_type;
|
||||
|
||||
|
@ -1177,18 +1177,8 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
|
||||
for (const auto & trigger_pattern : grammar.trigger_patterns) {
|
||||
if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) {
|
||||
grammar.awaiting_trigger = false;
|
||||
// get from the first matched capturing group to the end of the string
|
||||
size_t start = std::string::npos;
|
||||
for (auto i = 1u; i < match.size(); i++) {
|
||||
if (match.length(i) > 0) {
|
||||
start = match.position(i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (start == std::string::npos) {
|
||||
start = match.position(0);
|
||||
}
|
||||
auto constrained_str = grammar.trigger_buffer.substr(start);
|
||||
// get from the first match to the end of the string
|
||||
auto constrained_str = grammar.trigger_buffer.substr(match.position(1));
|
||||
// std::string constrained_str(match[1].first, grammar.trigger_buffer.end());
|
||||
grammar.trigger_buffer.clear();
|
||||
llama_grammar_accept_str(grammar, constrained_str);
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -17,12 +17,8 @@ struct ggml_tensor;
|
||||
struct llama_ubatch;
|
||||
struct llama_cparams;
|
||||
|
||||
struct llama_memory_context_i;
|
||||
|
||||
class llama_kv_cache_unified_context;
|
||||
class llama_kv_cache_unified_iswa_context;
|
||||
class llama_memory_recurrent_context;
|
||||
class llama_memory_hybrid_context;
|
||||
class llama_memory_i;
|
||||
class llama_kv_cache_unified;
|
||||
|
||||
// certain models (typically multi-modal) can produce different types of graphs
|
||||
enum llm_graph_type {
|
||||
@ -37,8 +33,6 @@ enum llm_ffn_op_type {
|
||||
LLM_FFN_RELU,
|
||||
LLM_FFN_RELU_SQR,
|
||||
LLM_FFN_SWIGLU,
|
||||
LLM_FFN_GEGLU,
|
||||
LLM_FFN_REGLU,
|
||||
};
|
||||
|
||||
enum llm_ffn_gate_type {
|
||||
@ -96,14 +90,14 @@ public:
|
||||
|
||||
class llm_graph_input_pos : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_pos(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
|
||||
llm_graph_input_pos(int64_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
|
||||
virtual ~llm_graph_input_pos() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
ggml_tensor * pos = nullptr; // I32 [n_batch]
|
||||
|
||||
const uint32_t n_pos_per_embd = 1;
|
||||
const int64_t n_pos_per_embd = 1;
|
||||
};
|
||||
|
||||
// temperature tuning, used by llama4
|
||||
@ -137,7 +131,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_pos_bucket_kv(
|
||||
const llama_hparams & hparams,
|
||||
const llama_kv_cache_unified_context * mctx) : hparams(hparams), mctx(mctx) {}
|
||||
const llama_kv_cache_unified * kv_self) : hparams(hparams), kv_self(kv_self) {}
|
||||
virtual ~llm_graph_input_pos_bucket_kv() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
@ -145,8 +139,7 @@ public:
|
||||
ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
|
||||
|
||||
const llama_hparams & hparams;
|
||||
|
||||
const llama_kv_cache_unified_context * mctx;
|
||||
const llama_kv_cache_unified * kv_self;
|
||||
};
|
||||
|
||||
class llm_graph_input_out_ids : public llm_graph_input_i {
|
||||
@ -191,16 +184,28 @@ public:
|
||||
const llama_cparams & cparams;
|
||||
};
|
||||
|
||||
class llm_graph_input_rs : public llm_graph_input_i {
|
||||
class llm_graph_input_s_copy : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
|
||||
virtual ~llm_graph_input_rs() = default;
|
||||
llm_graph_input_s_copy(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
|
||||
virtual ~llm_graph_input_s_copy() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
ggml_tensor * s_copy; // I32 [kv_size]
|
||||
|
||||
const llama_memory_recurrent_context * mctx;
|
||||
const llama_kv_cache_unified * kv_self;
|
||||
};
|
||||
|
||||
class llm_graph_input_s_mask : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_s_mask(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
|
||||
virtual ~llm_graph_input_s_mask() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
ggml_tensor * s_mask; // F32 [1, n_kv]
|
||||
|
||||
const llama_kv_cache_unified * kv_self;
|
||||
};
|
||||
|
||||
class llm_graph_input_cross_embd : public llm_graph_input_i {
|
||||
@ -240,40 +245,15 @@ public:
|
||||
llm_graph_input_attn_kv_unified(
|
||||
const llama_hparams & hparams,
|
||||
const llama_cparams & cparams,
|
||||
const llama_kv_cache_unified_context * mctx) :
|
||||
const llama_kv_cache_unified * kv_self) :
|
||||
hparams(hparams),
|
||||
cparams(cparams),
|
||||
mctx(mctx) {
|
||||
kv_self(kv_self) {
|
||||
}
|
||||
~llm_graph_input_attn_kv_unified() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
|
||||
|
||||
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
|
||||
|
||||
const llama_hparams & hparams;
|
||||
const llama_cparams & cparams;
|
||||
|
||||
const llama_kv_cache_unified_context * mctx;
|
||||
};
|
||||
|
||||
class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_attn_kv_unified_iswa(
|
||||
const llama_hparams & hparams,
|
||||
const llama_cparams & cparams,
|
||||
const llama_kv_cache_unified_iswa_context * mctx) :
|
||||
hparams(hparams),
|
||||
cparams(cparams),
|
||||
mctx(mctx) {
|
||||
}
|
||||
~llm_graph_input_attn_kv_unified_iswa() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
|
||||
ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
|
||||
|
||||
@ -285,7 +265,7 @@ public:
|
||||
const llama_hparams & hparams;
|
||||
const llama_cparams & cparams;
|
||||
|
||||
const llama_kv_cache_unified_iswa_context * mctx;
|
||||
const llama_kv_cache_unified * kv_self;
|
||||
};
|
||||
|
||||
class llm_graph_input_attn_cross : public llm_graph_input_i {
|
||||
@ -303,44 +283,6 @@ public:
|
||||
const llama_cross * cross = nullptr;
|
||||
};
|
||||
|
||||
class llm_graph_input_mem_hybrid : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_mem_hybrid(
|
||||
const llama_hparams & hparams,
|
||||
const llama_cparams & cparams,
|
||||
const llama_memory_hybrid_context * mctx) :
|
||||
hparams(hparams),
|
||||
cparams(cparams),
|
||||
mctx(mctx) {
|
||||
}
|
||||
virtual ~llm_graph_input_mem_hybrid() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
ggml_tensor * s_copy; // I32 [kv_size]
|
||||
|
||||
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
|
||||
|
||||
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
|
||||
|
||||
const llama_hparams & hparams;
|
||||
const llama_cparams & cparams;
|
||||
|
||||
const llama_memory_hybrid_context * mctx;
|
||||
};
|
||||
|
||||
// TODO: remove this when ggml_scale_add is implemented
|
||||
class llm_graph_input_one : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_one() {}
|
||||
virtual ~llm_graph_input_one() = default;
|
||||
|
||||
void set_input(const llama_ubatch *) override;
|
||||
|
||||
ggml_tensor * one = nullptr; // F32
|
||||
};
|
||||
|
||||
//
|
||||
// llm_graph_result
|
||||
//
|
||||
@ -355,7 +297,6 @@ class llm_graph_result_i {
|
||||
public:
|
||||
virtual ~llm_graph_result_i() = default;
|
||||
|
||||
virtual ggml_tensor * get_tokens() = 0;
|
||||
virtual ggml_tensor * get_logits() = 0;
|
||||
virtual ggml_tensor * get_embd() = 0;
|
||||
virtual ggml_tensor * get_embd_pooled() = 0;
|
||||
@ -370,7 +311,6 @@ class llm_graph_result : public llm_graph_result_i {
|
||||
public:
|
||||
virtual ~llm_graph_result() = default;
|
||||
|
||||
ggml_tensor * get_tokens() override { return t_tokens; }
|
||||
ggml_tensor * get_logits() override { return t_logits; }
|
||||
ggml_tensor * get_embd() override { return t_embd; }
|
||||
ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
|
||||
@ -387,7 +327,6 @@ public:
|
||||
}
|
||||
|
||||
// important graph nodes
|
||||
ggml_tensor * t_tokens = nullptr;
|
||||
ggml_tensor * t_logits = nullptr;
|
||||
ggml_tensor * t_embd = nullptr;
|
||||
ggml_tensor * t_embd_pooled = nullptr;
|
||||
@ -411,15 +350,15 @@ struct llm_graph_params {
|
||||
const llama_cparams & cparams;
|
||||
const llama_ubatch & ubatch;
|
||||
|
||||
ggml_backend_sched_t sched;
|
||||
ggml_backend_t backend_cpu;
|
||||
ggml_backend_sched * sched;
|
||||
ggml_backend * backend_cpu;
|
||||
|
||||
const llama_adapter_cvec * cvec;
|
||||
const llama_adapter_loras * loras;
|
||||
const llama_memory_context_i * mctx;
|
||||
const llama_cross * cross;
|
||||
const llama_adapter_cvec * cvec;
|
||||
const llama_adapter_loras * loras;
|
||||
const llama_memory_i * memory;
|
||||
const llama_cross * cross;
|
||||
|
||||
uint32_t n_outputs;
|
||||
int32_t n_outputs;
|
||||
|
||||
const llm_graph_cb & cb;
|
||||
};
|
||||
@ -435,6 +374,7 @@ struct llm_graph_context {
|
||||
const int64_t n_layer;
|
||||
const int64_t n_rot;
|
||||
const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train)
|
||||
const int64_t n_ctx_per_seq;
|
||||
const int64_t n_head;
|
||||
const int64_t n_head_kv;
|
||||
const int64_t n_embd_head_k;
|
||||
@ -453,8 +393,8 @@ struct llm_graph_context {
|
||||
const float norm_eps;
|
||||
const float norm_rms_eps;
|
||||
|
||||
const int64_t n_tokens;
|
||||
const int64_t n_outputs;
|
||||
const int32_t n_tokens;
|
||||
const int32_t n_outputs;
|
||||
const int32_t n_ctx_orig; // yarn
|
||||
|
||||
const enum llama_pooling_type pooling_type;
|
||||
@ -462,21 +402,22 @@ struct llm_graph_context {
|
||||
|
||||
ggml_context * ctx0 = nullptr;
|
||||
|
||||
ggml_backend_sched_t sched;
|
||||
ggml_backend_sched * sched;
|
||||
|
||||
ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
|
||||
ggml_backend * backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
|
||||
|
||||
const llama_adapter_cvec * cvec;
|
||||
const llama_adapter_loras * loras;
|
||||
const llama_memory_context_i * mctx;
|
||||
const llama_cross * cross;
|
||||
const llama_adapter_cvec * cvec;
|
||||
const llama_adapter_loras * loras;
|
||||
const llama_memory_i * memory;
|
||||
const llama_cross * cross;
|
||||
|
||||
const llm_graph_cb & cb_func;
|
||||
|
||||
std::unique_ptr<llm_graph_result> res;
|
||||
|
||||
llm_graph_context(const llm_graph_params & params);
|
||||
virtual ~llm_graph_context() = default;
|
||||
|
||||
int64_t n_pos_per_embd() const;
|
||||
|
||||
void cb(ggml_tensor * cur, const char * name, int il) const;
|
||||
|
||||
@ -548,26 +489,27 @@ struct llm_graph_context {
|
||||
ggml_tensor * build_inp_out_ids() const;
|
||||
ggml_tensor * build_inp_mean() const;
|
||||
ggml_tensor * build_inp_cls() const;
|
||||
ggml_tensor * build_inp_s_copy() const;
|
||||
ggml_tensor * build_inp_s_mask() const;
|
||||
|
||||
ggml_tensor * build_inp_cross_embd() const;
|
||||
ggml_tensor * build_inp_pos_bucket_enc() const;
|
||||
ggml_tensor * build_inp_pos_bucket_dec() const;
|
||||
ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
|
||||
|
||||
llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
|
||||
|
||||
//
|
||||
// attention
|
||||
//
|
||||
|
||||
ggml_tensor * build_attn_mha(
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens]
|
||||
ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens]
|
||||
ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
|
||||
ggml_tensor * q, // [n_embd_head_q, n_tokens, n_head_q]
|
||||
ggml_tensor * k, // [n_embd_head_k, n_tokens, n_head_k]
|
||||
ggml_tensor * v, // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false)
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * kq_mask,
|
||||
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
|
||||
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
|
||||
bool v_trans,
|
||||
float kq_scale) const;
|
||||
|
||||
llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
|
||||
@ -600,22 +542,6 @@ struct llm_graph_context {
|
||||
float kq_scale,
|
||||
int il) const;
|
||||
|
||||
llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
|
||||
|
||||
// note: if k_cur or v_cur are not provided, they will not be stored in the memory
|
||||
ggml_tensor * build_attn(
|
||||
llm_graph_input_attn_kv_unified_iswa * inp,
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * wo,
|
||||
ggml_tensor * wo_b,
|
||||
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
|
||||
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
|
||||
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
|
||||
float kq_scale,
|
||||
int il) const;
|
||||
|
||||
llm_graph_input_attn_cross * build_attn_inp_cross() const;
|
||||
|
||||
ggml_tensor * build_attn(
|
||||
@ -631,62 +557,23 @@ struct llm_graph_context {
|
||||
float kq_scale,
|
||||
int il) const;
|
||||
|
||||
ggml_tensor * build_attn(
|
||||
llm_graph_input_mem_hybrid * inp,
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * wo,
|
||||
ggml_tensor * wo_b,
|
||||
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
|
||||
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
|
||||
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
|
||||
float kq_scale,
|
||||
int il) const;
|
||||
//
|
||||
// recurrent
|
||||
//
|
||||
|
||||
// TODO: avoid notion of "kv"
|
||||
// TODO: move this implementation to llama_memory_recurrent.
|
||||
// this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
|
||||
// when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
|
||||
// implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
|
||||
// `llama_memory_recurrent`
|
||||
ggml_tensor * build_rs(
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * s,
|
||||
ggml_tensor * state_copy,
|
||||
int32_t state_size,
|
||||
int32_t n_seqs,
|
||||
uint32_t n_kv,
|
||||
uint32_t kv_head,
|
||||
uint32_t kv_size,
|
||||
int32_t rs_zero,
|
||||
bool avoid_copies = false) const;
|
||||
|
||||
llm_graph_input_rs * build_rs_inp() const;
|
||||
|
||||
ggml_tensor * build_rs(
|
||||
llm_graph_input_rs * inp,
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * s,
|
||||
int32_t state_size,
|
||||
int32_t n_seqs,
|
||||
bool avoid_copies = false) const;
|
||||
|
||||
ggml_tensor * build_rs(
|
||||
llm_graph_input_mem_hybrid * inp,
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * s,
|
||||
int32_t state_size,
|
||||
int32_t n_seqs,
|
||||
bool avoid_copies = false) const;
|
||||
ggml_tensor * build_copy_mask_state(
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * s,
|
||||
ggml_tensor * state_copy,
|
||||
ggml_tensor * state_mask,
|
||||
int32_t n_state,
|
||||
int32_t n_seqs) const;
|
||||
|
||||
ggml_tensor * build_rwkv_token_shift_load(
|
||||
llm_graph_input_rs * inp,
|
||||
ggml_cgraph * gf,
|
||||
const llama_ubatch & ubatch,
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * state_copy,
|
||||
ggml_tensor * state_mask,
|
||||
const llama_ubatch & ubatch,
|
||||
int il) const;
|
||||
|
||||
ggml_tensor * build_rwkv_token_shift_store(
|
||||
@ -705,6 +592,3 @@ struct llm_graph_context {
|
||||
ggml_tensor * cls_out,
|
||||
ggml_tensor * cls_out_b) const;
|
||||
};
|
||||
|
||||
// TODO: better name
|
||||
int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional);
|
||||
|
@ -2,22 +2,6 @@
|
||||
|
||||
#include "ggml.h"
|
||||
|
||||
void llama_hparams::set_swa_pattern(uint32_t n_pattern) {
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1));
|
||||
}
|
||||
}
|
||||
|
||||
bool llama_hparams::is_swa_any() const {
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
if (swa_layers[il]) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_head(uint32_t il) const {
|
||||
if (il < n_layer) {
|
||||
return n_head_arr[il];
|
||||
@ -65,7 +49,7 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
|
||||
return n_embd_head_v * n_head_kv;
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_embd_r() const {
|
||||
uint32_t llama_hparams::n_embd_k_s() const {
|
||||
if (wkv_head_size != 0) {
|
||||
// for RWKV models
|
||||
return token_shift_count * n_embd;
|
||||
@ -76,7 +60,7 @@ uint32_t llama_hparams::n_embd_r() const {
|
||||
return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_embd_s() const {
|
||||
uint32_t llama_hparams::n_embd_v_s() const {
|
||||
if (wkv_head_size != 0) {
|
||||
// corresponds to RWKV's wkv_states size
|
||||
return n_embd * wkv_head_size;
|
||||
@ -86,17 +70,9 @@ uint32_t llama_hparams::n_embd_s() const {
|
||||
return ssm_d_state * ssm_d_inner;
|
||||
}
|
||||
|
||||
bool llama_hparams::is_recurrent(uint32_t il) const {
|
||||
return recurrent_layer_arr[il];
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_pos_per_embd() const {
|
||||
return rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1;
|
||||
}
|
||||
|
||||
bool llama_hparams::is_swa(uint32_t il) const {
|
||||
if (il < n_layer) {
|
||||
return swa_layers[il];
|
||||
return n_swa > 0 && n_swa_pattern > 0 && il % n_swa_pattern < (n_swa_pattern - 1);
|
||||
}
|
||||
|
||||
GGML_ABORT("fatal error");
|
||||
|
@ -14,12 +14,6 @@ enum llama_expert_gating_func_type {
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2,
|
||||
};
|
||||
|
||||
enum llama_swa_type {
|
||||
LLAMA_SWA_TYPE_NONE = 0,
|
||||
LLAMA_SWA_TYPE_STANDARD = 1,
|
||||
LLAMA_SWA_TYPE_CHUNKED = 2,
|
||||
};
|
||||
|
||||
struct llama_hparams_posnet {
|
||||
uint32_t n_embd;
|
||||
uint32_t n_layer;
|
||||
@ -41,6 +35,8 @@ struct llama_hparams {
|
||||
uint32_t n_embd_features = 0;
|
||||
uint32_t n_layer;
|
||||
uint32_t n_rot;
|
||||
uint32_t n_swa = 0; // sliding window attention (SWA)
|
||||
uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention
|
||||
uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
|
||||
uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
|
||||
uint32_t n_expert = 0;
|
||||
@ -100,24 +96,12 @@ struct llama_hparams {
|
||||
|
||||
std::array<int, 4> rope_sections;
|
||||
|
||||
// Sliding Window Attention (SWA)
|
||||
llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
|
||||
// the size of the sliding window (0 - no SWA)
|
||||
uint32_t n_swa = 0;
|
||||
// if swa_layers[il] == true, then layer il is SWA
|
||||
// if swa_layers[il] == false, then layer il is dense (i.e. non-SWA)
|
||||
// by default, all layers are dense
|
||||
std::array<bool, LLAMA_MAX_LAYERS> swa_layers;
|
||||
|
||||
// for State Space Models
|
||||
uint32_t ssm_d_conv = 0;
|
||||
uint32_t ssm_d_inner = 0;
|
||||
uint32_t ssm_d_state = 0;
|
||||
uint32_t ssm_dt_rank = 0;
|
||||
|
||||
// for hybrid state space models
|
||||
std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;
|
||||
|
||||
bool ssm_dt_b_c_rms = false;
|
||||
|
||||
float f_clamp_kqv = 0.0f;
|
||||
@ -132,23 +116,15 @@ struct llama_hparams {
|
||||
bool causal_attn = true;
|
||||
bool use_alibi = false;
|
||||
bool attn_soft_cap = false;
|
||||
bool use_kq_norm = true;
|
||||
|
||||
// for Classifiers
|
||||
uint32_t n_cls_out = 1;
|
||||
|
||||
// llama4
|
||||
uint32_t n_moe_layer_step = 0;
|
||||
bool use_kq_norm = true;
|
||||
uint32_t n_attn_chunk = 0;
|
||||
// values below seems to be fixed on llama4
|
||||
uint32_t n_no_rope_layer_step = 4;
|
||||
uint32_t n_attn_temp_floor_scale = 8192;
|
||||
float f_attn_temp_scale = 0.1;
|
||||
|
||||
// gemma3n altup
|
||||
uint32_t n_altup = 4; // altup_num_inputs
|
||||
uint32_t i_altup_act = 0; // altup_active_idx
|
||||
uint32_t laurel_rank = 64;
|
||||
uint32_t n_embd_altup = 256;
|
||||
|
||||
// needed by encoder-decoder models (e.g. T5, FLAN-T5)
|
||||
// ref: https://github.com/ggerganov/llama.cpp/pull/8141
|
||||
llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
|
||||
@ -157,23 +133,6 @@ struct llama_hparams {
|
||||
enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
|
||||
enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
|
||||
|
||||
// this value n_pattern means that every nth layer is dense (i.e. non-SWA)
|
||||
// note that if n_pattern == 0, all layers are SWA
|
||||
// if n_pattern == 1, all layers are dense
|
||||
// example: n_pattern = 3
|
||||
// il == 0: swa
|
||||
// il == 1: swa
|
||||
// il == 2: dense
|
||||
// il == 3: swa
|
||||
// il == 4: swa
|
||||
// il == 5: dense
|
||||
// il == 6: swa
|
||||
// etc ...
|
||||
void set_swa_pattern(uint32_t n_pattern);
|
||||
|
||||
// return true if one of the layers is SWA
|
||||
bool is_swa_any() const;
|
||||
|
||||
uint32_t n_head(uint32_t il = 0) const;
|
||||
|
||||
uint32_t n_head_kv(uint32_t il = 0) const;
|
||||
@ -190,15 +149,10 @@ struct llama_hparams {
|
||||
|
||||
// dimension of the rolling state embeddings
|
||||
// corresponds to Mamba's conv_states size or RWKV's token_shift states size
|
||||
uint32_t n_embd_r() const;
|
||||
uint32_t n_embd_k_s() const;
|
||||
|
||||
// dimension of the recurrent state embeddings
|
||||
uint32_t n_embd_s() const;
|
||||
|
||||
// whether or not the given layer is recurrent (for hybrid models)
|
||||
bool is_recurrent(uint32_t il) const;
|
||||
|
||||
uint32_t n_pos_per_embd() const;
|
||||
uint32_t n_embd_v_s() const;
|
||||
|
||||
bool is_swa(uint32_t il) const;
|
||||
};
|
||||
|
@ -1,279 +0,0 @@
|
||||
#include "llama-kv-cache-unified-iswa.h"
|
||||
|
||||
#include "llama-impl.h"
|
||||
#include "llama-batch.h"
|
||||
#include "llama-model.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
|
||||
//
|
||||
// llama_kv_cache_unified_iswa
|
||||
//
|
||||
|
||||
llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
|
||||
const llama_model & model,
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
bool v_trans,
|
||||
bool offload,
|
||||
bool swa_full,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_ubatch,
|
||||
uint32_t n_pad) : hparams(model.hparams) {
|
||||
llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
|
||||
llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
|
||||
|
||||
const uint32_t size_base = kv_size;
|
||||
|
||||
uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad));
|
||||
|
||||
// when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
|
||||
if (swa_full) {
|
||||
LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n",
|
||||
__func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
|
||||
|
||||
size_swa = size_base;
|
||||
}
|
||||
|
||||
LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
|
||||
|
||||
kv_base = std::make_unique<llama_kv_cache_unified>(
|
||||
model, std::move(filter_base), type_k, type_v,
|
||||
v_trans, offload, size_base, n_seq_max, n_pad,
|
||||
0, LLAMA_SWA_TYPE_NONE);
|
||||
|
||||
LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
|
||||
|
||||
kv_swa = std::make_unique<llama_kv_cache_unified>(
|
||||
model, std::move(filter_swa), type_k, type_v,
|
||||
v_trans, offload, size_swa, n_seq_max, n_pad,
|
||||
hparams.n_swa, hparams.swa_type);
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified_iswa::clear(bool data) {
|
||||
kv_base->clear(data);
|
||||
kv_swa ->clear(data);
|
||||
}
|
||||
|
||||
bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
||||
bool res = true;
|
||||
|
||||
res = res & kv_base->seq_rm(seq_id, p0, p1);
|
||||
res = res & kv_swa ->seq_rm(seq_id, p0, p1);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
|
||||
kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1);
|
||||
kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) {
|
||||
kv_base->seq_keep(seq_id);
|
||||
kv_swa ->seq_keep(seq_id);
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
|
||||
kv_base->seq_add(seq_id, p0, p1, shift);
|
||||
kv_swa ->seq_add(seq_id, p0, p1, shift);
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
||||
kv_base->seq_div(seq_id, p0, p1, d);
|
||||
kv_swa ->seq_div(seq_id, p0, p1, d);
|
||||
}
|
||||
|
||||
llama_pos llama_kv_cache_unified_iswa::seq_pos_min(llama_seq_id seq_id) const {
|
||||
// the base cache is a superset of the SWA cache, so we can just check the SWA cache
|
||||
return kv_swa->seq_pos_min(seq_id);
|
||||
}
|
||||
|
||||
llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
|
||||
return kv_swa->seq_pos_max(seq_id);
|
||||
}
|
||||
|
||||
llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
|
||||
GGML_UNUSED(embd_all);
|
||||
|
||||
// first try simple split
|
||||
do {
|
||||
balloc.split_reset();
|
||||
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
while (true) {
|
||||
auto ubatch = balloc.split_simple(n_ubatch);
|
||||
|
||||
if (ubatch.n_tokens == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
ubatches.push_back(std::move(ubatch)); // NOLINT
|
||||
}
|
||||
|
||||
auto heads_base = kv_base->prepare(ubatches);
|
||||
if (heads_base.empty()) {
|
||||
break;
|
||||
}
|
||||
|
||||
auto heads_swa = kv_swa->prepare(ubatches);
|
||||
if (heads_swa.empty()) {
|
||||
break;
|
||||
}
|
||||
|
||||
assert(heads_base.size() == heads_swa.size());
|
||||
|
||||
return std::make_unique<llama_kv_cache_unified_iswa_context>(
|
||||
this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
|
||||
} while (false);
|
||||
|
||||
// if it fails, try equal split
|
||||
do {
|
||||
balloc.split_reset();
|
||||
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
while (true) {
|
||||
auto ubatch = balloc.split_equal(n_ubatch);
|
||||
|
||||
if (ubatch.n_tokens == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
ubatches.push_back(std::move(ubatch)); // NOLINT
|
||||
}
|
||||
|
||||
auto heads_base = kv_base->prepare(ubatches);
|
||||
if (heads_base.empty()) {
|
||||
break;
|
||||
}
|
||||
|
||||
auto heads_swa = kv_swa->prepare(ubatches);
|
||||
if (heads_swa.empty()) {
|
||||
break;
|
||||
}
|
||||
|
||||
assert(heads_base.size() == heads_swa.size());
|
||||
|
||||
return std::make_unique<llama_kv_cache_unified_iswa_context>(
|
||||
this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
|
||||
} while (false);
|
||||
|
||||
// TODO: if we fail again, we should attempt different splitting strategies
|
||||
// but to do that properly, we first have to refactor the batches to be more flexible
|
||||
|
||||
return std::make_unique<llama_kv_cache_unified_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||
}
|
||||
|
||||
llama_memory_context_ptr llama_kv_cache_unified_iswa::init_full() {
|
||||
return std::make_unique<llama_kv_cache_unified_iswa_context>(this);
|
||||
}
|
||||
|
||||
llama_memory_context_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
|
||||
return std::make_unique<llama_kv_cache_unified_iswa_context>(this, lctx, optimize);
|
||||
}
|
||||
|
||||
bool llama_kv_cache_unified_iswa::get_can_shift() const {
|
||||
return kv_base->get_size() == kv_swa->get_size();
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
|
||||
kv_base->state_write(io, seq_id);
|
||||
kv_swa ->state_write(io, seq_id);
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
|
||||
kv_base->state_read(io, seq_id);
|
||||
kv_swa ->state_read(io, seq_id);
|
||||
}
|
||||
|
||||
llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const {
|
||||
return kv_base.get();
|
||||
}
|
||||
|
||||
llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
|
||||
return kv_swa.get();
|
||||
}
|
||||
|
||||
//
|
||||
// llama_kv_cache_unified_iswa_context
|
||||
//
|
||||
|
||||
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(llama_memory_status status) : status(status) {}
|
||||
|
||||
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
|
||||
llama_kv_cache_unified_iswa * kv) :
|
||||
ctx_base(kv->get_base()->init_full()),
|
||||
ctx_swa (kv->get_swa ()->init_full()),
|
||||
status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
|
||||
}
|
||||
|
||||
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
|
||||
llama_kv_cache_unified_iswa * kv,
|
||||
llama_context * lctx,
|
||||
bool optimize) :
|
||||
ctx_base(kv->get_base()->init_update(lctx, optimize)),
|
||||
ctx_swa (kv->get_swa ()->init_update(lctx, optimize)),
|
||||
status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
|
||||
}
|
||||
|
||||
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
|
||||
llama_kv_cache_unified_iswa * kv,
|
||||
std::vector<uint32_t> heads_base,
|
||||
std::vector<uint32_t> heads_swa,
|
||||
std::vector<llama_ubatch> ubatches) :
|
||||
ubatches(std::move(ubatches)),
|
||||
// note: here we copy the ubatches. not sure if this is ideal
|
||||
ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(heads_base), this->ubatches)),
|
||||
ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(heads_swa), this->ubatches)),
|
||||
status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
|
||||
}
|
||||
|
||||
llama_kv_cache_unified_iswa_context:: ~llama_kv_cache_unified_iswa_context() = default;
|
||||
|
||||
bool llama_kv_cache_unified_iswa_context::next() {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
|
||||
ctx_base->next();
|
||||
ctx_swa ->next();
|
||||
|
||||
if (++i_next >= ubatches.size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool llama_kv_cache_unified_iswa_context::apply() {
|
||||
assert(!llama_memory_status_is_fail(status));
|
||||
|
||||
bool res = true;
|
||||
|
||||
res = res & ctx_base->apply();
|
||||
res = res & ctx_swa ->apply();
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
llama_memory_status llama_kv_cache_unified_iswa_context::get_status() const {
|
||||
return status;
|
||||
}
|
||||
|
||||
const llama_ubatch & llama_kv_cache_unified_iswa_context::get_ubatch() const {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
|
||||
return ubatches[i_next];
|
||||
}
|
||||
|
||||
const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_base() const {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
|
||||
return static_cast<const llama_kv_cache_unified_context *>(ctx_base.get());
|
||||
}
|
||||
|
||||
const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_swa() const {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
|
||||
return static_cast<const llama_kv_cache_unified_context *>(ctx_swa.get());
|
||||
}
|
@ -1,128 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "llama-kv-cache-unified.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
//
|
||||
// llama_kv_cache_unified_iswa
|
||||
//
|
||||
|
||||
// utilizes two instances of llama_kv_cache_unified
|
||||
// the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
|
||||
|
||||
class llama_kv_cache_unified_iswa : public llama_memory_i {
|
||||
public:
|
||||
llama_kv_cache_unified_iswa(
|
||||
const llama_model & model,
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
bool v_trans,
|
||||
bool offload,
|
||||
bool swa_full,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_ubatch,
|
||||
uint32_t n_pad);
|
||||
|
||||
~llama_kv_cache_unified_iswa() = default;
|
||||
|
||||
//
|
||||
// llama_memory_i
|
||||
//
|
||||
|
||||
llama_memory_context_ptr init_batch(
|
||||
llama_batch_allocr & balloc,
|
||||
uint32_t n_ubatch,
|
||||
bool embd_all) override;
|
||||
|
||||
llama_memory_context_ptr init_full() override;
|
||||
|
||||
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
|
||||
|
||||
bool get_can_shift() const override;
|
||||
|
||||
void clear(bool data) override;
|
||||
|
||||
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
||||
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
||||
void seq_keep(llama_seq_id seq_id) override;
|
||||
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
|
||||
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
||||
|
||||
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
|
||||
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
|
||||
|
||||
// state write/load
|
||||
|
||||
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
||||
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
|
||||
|
||||
//
|
||||
// llama_kv_cache_unified_iswa specific API
|
||||
//
|
||||
|
||||
llama_kv_cache_unified * get_base() const;
|
||||
llama_kv_cache_unified * get_swa () const;
|
||||
|
||||
private:
|
||||
const llama_hparams & hparams;
|
||||
|
||||
std::unique_ptr<llama_kv_cache_unified> kv_base;
|
||||
std::unique_ptr<llama_kv_cache_unified> kv_swa;
|
||||
};
|
||||
|
||||
class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
|
||||
public:
|
||||
// used for errors
|
||||
llama_kv_cache_unified_iswa_context(llama_memory_status status);
|
||||
|
||||
// used to create a full-cache context
|
||||
llama_kv_cache_unified_iswa_context(
|
||||
llama_kv_cache_unified_iswa * kv);
|
||||
|
||||
// used to create an update context
|
||||
llama_kv_cache_unified_iswa_context(
|
||||
llama_kv_cache_unified_iswa * kv,
|
||||
llama_context * lctx,
|
||||
bool optimize);
|
||||
|
||||
// used to create a batch processing context from a batch
|
||||
llama_kv_cache_unified_iswa_context(
|
||||
llama_kv_cache_unified_iswa * kv,
|
||||
std::vector<uint32_t> heads_base,
|
||||
std::vector<uint32_t> heads_swa,
|
||||
std::vector<llama_ubatch> ubatches);
|
||||
|
||||
virtual ~llama_kv_cache_unified_iswa_context();
|
||||
|
||||
//
|
||||
// llama_memory_context_i
|
||||
//
|
||||
|
||||
bool next() override;
|
||||
bool apply() override;
|
||||
|
||||
llama_memory_status get_status() const override;
|
||||
const llama_ubatch & get_ubatch() const override;
|
||||
|
||||
//
|
||||
// llama_kv_cache_unified_iswa_context specific API
|
||||
//
|
||||
|
||||
const llama_kv_cache_unified_context * get_base() const;
|
||||
const llama_kv_cache_unified_context * get_swa() const;
|
||||
|
||||
private:
|
||||
//llama_kv_cache_unified_iswa * kv;
|
||||
|
||||
// the index of the next ubatch to process
|
||||
size_t i_next = 0;
|
||||
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
|
||||
const llama_memory_context_ptr ctx_base;
|
||||
const llama_memory_context_ptr ctx_swa;
|
||||
|
||||
const llama_memory_status status;
|
||||
};
|
File diff suppressed because it is too large
Load Diff
@ -1,303 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "llama-batch.h"
|
||||
#include "llama-graph.h"
|
||||
#include "llama-kv-cells.h"
|
||||
#include "llama-memory.h"
|
||||
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
struct llama_cparams;
|
||||
struct llama_hparams;
|
||||
struct llama_model;
|
||||
struct llama_context;
|
||||
|
||||
//
|
||||
// llama_kv_cache_unified
|
||||
//
|
||||
|
||||
class llama_kv_cache_unified : public llama_memory_i {
|
||||
public:
|
||||
static uint32_t get_padding(const llama_cparams & cparams);
|
||||
|
||||
// this callback is used to filter out layers that should not be included in the cache
|
||||
using layer_filter_cb = std::function<bool(int32_t il)>;
|
||||
|
||||
using ubatch_heads = std::vector<uint32_t>;
|
||||
|
||||
struct defrag_info {
|
||||
bool empty() const {
|
||||
return ids.empty();
|
||||
}
|
||||
|
||||
// contains information about which cell moves where:
|
||||
// - cell i moves to ids[i]
|
||||
// - if ids[i] == i || ids[i] == ids.size(), then cell i is not moved
|
||||
std::vector<uint32_t> ids;
|
||||
};
|
||||
|
||||
llama_kv_cache_unified(
|
||||
const llama_model & model,
|
||||
layer_filter_cb && filter,
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
bool v_trans,
|
||||
bool offload,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_pad,
|
||||
uint32_t n_swa,
|
||||
llama_swa_type swa_type);
|
||||
|
||||
~llama_kv_cache_unified() = default;
|
||||
|
||||
//
|
||||
// llama_memory_i
|
||||
//
|
||||
|
||||
llama_memory_context_ptr init_batch(
|
||||
llama_batch_allocr & balloc,
|
||||
uint32_t n_ubatch,
|
||||
bool embd_all) override;
|
||||
|
||||
llama_memory_context_ptr init_full() override;
|
||||
|
||||
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
|
||||
|
||||
bool get_can_shift() const override;
|
||||
|
||||
void clear(bool data) override;
|
||||
|
||||
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
||||
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
||||
void seq_keep(llama_seq_id seq_id) override;
|
||||
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
|
||||
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
||||
|
||||
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
|
||||
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
|
||||
|
||||
// state write/load
|
||||
|
||||
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
||||
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
|
||||
|
||||
//
|
||||
// llama_kv_cache_unified specific API
|
||||
//
|
||||
|
||||
uint32_t get_size() const;
|
||||
|
||||
bool get_has_shift() const;
|
||||
|
||||
//
|
||||
// graph_build API
|
||||
//
|
||||
|
||||
uint32_t get_n_kv() const;
|
||||
|
||||
// get views of the current state of the cache
|
||||
ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
|
||||
ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
|
||||
|
||||
// store k_cur and v_cur in the cache based on the provided head location
|
||||
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const;
|
||||
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const;
|
||||
|
||||
//
|
||||
// preparation API
|
||||
//
|
||||
|
||||
// find places for the provided ubatches in the cache, returns the head locations
|
||||
// return empty vector on failure
|
||||
ubatch_heads prepare(const std::vector<llama_ubatch> & ubatches);
|
||||
|
||||
bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
|
||||
|
||||
// return the cell position where we can insert the ubatch
|
||||
// return -1 on failure to find a contiguous slot of kv cells
|
||||
int32_t find_slot(const llama_ubatch & ubatch) const;
|
||||
|
||||
// emplace the ubatch context into slot: [head_cur, head_cur + ubatch.n_tokens)
|
||||
void apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch);
|
||||
|
||||
//
|
||||
// set_input API
|
||||
//
|
||||
|
||||
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
|
||||
void set_input_k_shift (ggml_tensor * dst) const;
|
||||
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
|
||||
|
||||
private:
|
||||
const llama_model & model;
|
||||
const llama_hparams & hparams;
|
||||
|
||||
struct kv_layer {
|
||||
// layer index in the model
|
||||
// note: can be different from the layer index in the KV cache
|
||||
uint32_t il;
|
||||
|
||||
ggml_tensor * k;
|
||||
ggml_tensor * v;
|
||||
};
|
||||
|
||||
bool v_trans = true; // the value tensor is transposed
|
||||
|
||||
// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
|
||||
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
|
||||
uint32_t head = 0;
|
||||
|
||||
const uint32_t n_seq_max = 1;
|
||||
|
||||
// required padding
|
||||
const uint32_t n_pad = 1;
|
||||
|
||||
// SWA
|
||||
const uint32_t n_swa = 0;
|
||||
|
||||
int debug = 0;
|
||||
|
||||
const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
|
||||
|
||||
std::vector<ggml_context_ptr> ctxs;
|
||||
std::vector<ggml_backend_buffer_ptr> bufs;
|
||||
|
||||
llama_kv_cells_unified cells;
|
||||
|
||||
std::vector<kv_layer> layers;
|
||||
|
||||
// model layer id -> KV cache layer id
|
||||
std::unordered_map<int32_t, int32_t> map_layer_ids;
|
||||
|
||||
// return non-empty vector if cells have been moved
|
||||
defrag_info defrag_prepare(int32_t n_max_nodes) const;
|
||||
|
||||
size_t total_size() const;
|
||||
|
||||
size_t size_k_bytes() const;
|
||||
size_t size_v_bytes() const;
|
||||
|
||||
bool is_masked_swa(llama_pos p0, llama_pos p1) const;
|
||||
|
||||
ggml_tensor * build_rope_shift(
|
||||
const llama_cparams & cparams,
|
||||
ggml_context * ctx,
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * shift,
|
||||
ggml_tensor * factors,
|
||||
float freq_base,
|
||||
float freq_scale) const;
|
||||
|
||||
llm_graph_result_ptr build_graph_shift(
|
||||
const llama_cparams & cparams,
|
||||
ggml_context * ctx,
|
||||
ggml_cgraph * gf) const;
|
||||
|
||||
llm_graph_result_ptr build_graph_defrag(
|
||||
const llama_cparams & cparams,
|
||||
ggml_context * ctx,
|
||||
ggml_cgraph * gf,
|
||||
const defrag_info & dinfo) const;
|
||||
|
||||
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
|
||||
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
|
||||
|
||||
bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
|
||||
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
|
||||
};
|
||||
|
||||
class llama_kv_cache_unified_context : public llama_memory_context_i {
|
||||
public:
|
||||
// some shorthands
|
||||
using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
|
||||
using defrag_info = llama_kv_cache_unified::defrag_info;
|
||||
|
||||
// used for errors
|
||||
llama_kv_cache_unified_context(llama_memory_status status);
|
||||
|
||||
// used to create a full-cache context
|
||||
llama_kv_cache_unified_context(
|
||||
llama_kv_cache_unified * kv);
|
||||
|
||||
// used to create an update context
|
||||
llama_kv_cache_unified_context(
|
||||
llama_kv_cache_unified * kv,
|
||||
llama_context * lctx,
|
||||
bool do_shift,
|
||||
defrag_info dinfo);
|
||||
|
||||
// used to create a batch procesing context from a batch
|
||||
llama_kv_cache_unified_context(
|
||||
llama_kv_cache_unified * kv,
|
||||
ubatch_heads heads,
|
||||
std::vector<llama_ubatch> ubatches);
|
||||
|
||||
virtual ~llama_kv_cache_unified_context();
|
||||
|
||||
//
|
||||
// llama_memory_context_i
|
||||
//
|
||||
|
||||
bool next() override;
|
||||
bool apply() override;
|
||||
|
||||
llama_memory_status get_status() const override;
|
||||
const llama_ubatch & get_ubatch() const override;
|
||||
|
||||
//
|
||||
// llama_kv_cache_unified_context specific API
|
||||
//
|
||||
|
||||
uint32_t get_n_kv() const;
|
||||
|
||||
// get views of the current state of the cache
|
||||
ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
|
||||
ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
|
||||
|
||||
// store k_cur and v_cur in the cache based on the provided head location
|
||||
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
|
||||
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
|
||||
|
||||
void set_input_k_shift(ggml_tensor * dst) const;
|
||||
|
||||
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
|
||||
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
|
||||
|
||||
private:
|
||||
llama_memory_status status;
|
||||
|
||||
llama_kv_cache_unified * kv;
|
||||
llama_context * lctx;
|
||||
|
||||
//
|
||||
// update context
|
||||
//
|
||||
|
||||
bool do_shift = false;
|
||||
|
||||
defrag_info dinfo;
|
||||
|
||||
//
|
||||
// batch processing context
|
||||
//
|
||||
|
||||
// the index of the next ubatch to process
|
||||
size_t i_next = 0;
|
||||
|
||||
ubatch_heads heads;
|
||||
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
|
||||
//
|
||||
// data needed for building the compute graph for the current ubatch:
|
||||
//
|
||||
|
||||
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
||||
// as the cache gets filled, the benefit from this heuristic disappears
|
||||
int32_t n_kv;
|
||||
|
||||
// the beginning of the current slot in which the ubatch will be inserted
|
||||
int32_t head;
|
||||
};
|
1380
examples/talk-llama/llama-kv-cache.cpp
Normal file
1380
examples/talk-llama/llama-kv-cache.cpp
Normal file
File diff suppressed because it is too large
Load Diff
@ -4,41 +4,210 @@
|
||||
#include "llama-io.h"
|
||||
#include "llama-memory.h"
|
||||
|
||||
#include "ggml-cpp.h"
|
||||
|
||||
#include <functional>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
|
||||
struct llama_cparams;
|
||||
struct llama_hparams;
|
||||
struct llama_ubatch;
|
||||
|
||||
struct llama_kv_cache : public llama_memory_i {
|
||||
virtual ~llama_kv_cache() = default;
|
||||
using llama_memory_i::llama_memory_i;
|
||||
|
||||
// split the input batch into a set of ubatches and verify that they can fit into the cache
|
||||
// return a state object containing the ubatches and KV cache state required to process them
|
||||
// check the llama_memory_state_i::get_status() for the result
|
||||
virtual llama_memory_state_ptr init_batch(
|
||||
const llama_batch & batch,
|
||||
uint32_t n_ubatch,
|
||||
bool embd_pooled,
|
||||
bool logits_all) = 0;
|
||||
virtual void restore() = 0; // call if batch processing fails - restores the cache state
|
||||
virtual void commit() = 0; // call after successful batch processing - clears any pending state
|
||||
|
||||
// simulate full cache, used for allocating worst-case compute buffers
|
||||
virtual llama_memory_state_ptr init_full() = 0;
|
||||
virtual int32_t get_n_tokens() const = 0;
|
||||
virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
|
||||
|
||||
// process any pending defrag/shift/etc. operations
|
||||
// optionally call once before processing a new batch
|
||||
// return true if any operations were performed
|
||||
virtual bool update(llama_context & lctx) = 0;
|
||||
|
||||
// schedule a defrag if the fragmentation threshold is exceeded. otherwise, do nothing
|
||||
// TODO: change to
|
||||
// llama_memory_state_ptr init_defrag(float thold) = 0;
|
||||
//
|
||||
virtual void defrag_sched(float thold) = 0;
|
||||
|
||||
// getters
|
||||
virtual bool get_can_shift() const = 0;
|
||||
|
||||
bool get_can_edit() const override { return get_can_shift(); }
|
||||
|
||||
//
|
||||
// state write/read
|
||||
//
|
||||
|
||||
virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
|
||||
virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0;
|
||||
};
|
||||
|
||||
struct llama_kv_cache_guard {
|
||||
llama_kv_cache_guard(llama_kv_cache * kv) : kv(kv) {}
|
||||
|
||||
~llama_kv_cache_guard() {
|
||||
kv->restore();
|
||||
}
|
||||
|
||||
void commit() {
|
||||
kv->commit();
|
||||
}
|
||||
|
||||
private:
|
||||
llama_kv_cache * kv;
|
||||
};
|
||||
|
||||
struct llama_kv_cell {
|
||||
llama_pos pos = -1;
|
||||
llama_pos delta = 0;
|
||||
int32_t src = -1; // used by recurrent state models to copy states
|
||||
int32_t tail = -1;
|
||||
|
||||
std::set<llama_seq_id> seq_id;
|
||||
|
||||
bool has_seq_id(const llama_seq_id & id) const {
|
||||
return seq_id.find(id) != seq_id.end();
|
||||
}
|
||||
|
||||
bool is_empty() const {
|
||||
return seq_id.empty();
|
||||
}
|
||||
|
||||
bool is_same_seq(const llama_kv_cell & other) const {
|
||||
return seq_id == other.seq_id;
|
||||
}
|
||||
};
|
||||
|
||||
// ring-buffer of cached KV data
|
||||
// TODO: pimpl
|
||||
// TODO: add notion of max sequences
|
||||
class llama_kv_cache_unified : public llama_kv_cache {
|
||||
public:
|
||||
// can be used to query data from the model if needed
|
||||
struct callbacks {
|
||||
std::function<ggml_tensor * (uint32_t n_ctx_per_seq, int il)> get_rope_factors;
|
||||
};
|
||||
|
||||
llama_kv_cache_unified(
|
||||
const llama_hparams & hparams,
|
||||
callbacks cbs);
|
||||
|
||||
virtual ~llama_kv_cache_unified() = default;
|
||||
|
||||
// TODO: become constructor
|
||||
bool init(
|
||||
const llama_model & model, // TODO: do not reference the model
|
||||
const llama_cparams & cparams,
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
uint32_t kv_size,
|
||||
bool offload);
|
||||
|
||||
int32_t get_n_tokens() const override;
|
||||
int32_t get_used_cells() const override;
|
||||
|
||||
size_t total_size() const;
|
||||
|
||||
// TODO: better data structures to reduce the cost of this operation
|
||||
llama_pos pos_max() const;
|
||||
|
||||
void clear() override;
|
||||
void defrag() override;
|
||||
|
||||
virtual void restore() override;
|
||||
virtual void commit() override;
|
||||
|
||||
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
||||
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
||||
void seq_keep(llama_seq_id seq_id) override;
|
||||
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
|
||||
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
||||
|
||||
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
|
||||
|
||||
bool get_can_shift() const override;
|
||||
|
||||
// find an empty slot of size "n_tokens" in the cache
|
||||
// updates the cache head
|
||||
// Note: On success, it's important that cache.head points
|
||||
// to the first cell of the slot.
|
||||
bool find_slot(const llama_ubatch & batch);
|
||||
|
||||
// TODO: maybe not needed
|
||||
uint32_t get_padding(const llama_cparams & cparams) const;
|
||||
|
||||
// find how many cells are currently in use
|
||||
uint32_t cell_max() const;
|
||||
|
||||
size_t size_k_bytes() const;
|
||||
size_t size_v_bytes() const;
|
||||
|
||||
// defrag
|
||||
|
||||
struct {
|
||||
std::vector<uint32_t> ids;
|
||||
} defrag_info;
|
||||
|
||||
// return true if cells have been moved
|
||||
bool defrag_prepare(int32_t n_max_nodes);
|
||||
|
||||
// commit/restore cache
|
||||
|
||||
struct slot_range {
|
||||
uint32_t c0 = 0; // note: these are cell indices, not sequence positions
|
||||
uint32_t c1 = 0;
|
||||
};
|
||||
|
||||
// pending cell updates that are not yet committed
|
||||
struct {
|
||||
std::vector<slot_range> ranges;
|
||||
} pending;
|
||||
|
||||
// state write/load
|
||||
|
||||
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const;
|
||||
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1);
|
||||
|
||||
// members
|
||||
|
||||
const llama_hparams & hparams;
|
||||
|
||||
callbacks cbs;
|
||||
|
||||
bool has_shift = false;
|
||||
bool do_defrag = false;
|
||||
|
||||
// TODO: remove this and implement llama_kv_cache_recurrent instead
|
||||
bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
|
||||
|
||||
bool v_trans = true; // the value tensor is transposed
|
||||
bool can_shift = false;
|
||||
|
||||
// Note: The value of head isn't only used to optimize searching
|
||||
// for a free KV slot. llama_decode_impl also uses it, so it
|
||||
// cannot be freely changed after a slot has been allocated.
|
||||
uint32_t head = 0;
|
||||
uint32_t size = 0;
|
||||
uint32_t used = 0; // used cells (i.e. at least one seq_id)
|
||||
|
||||
// computed before each graph build
|
||||
uint32_t n = 0;
|
||||
|
||||
std::vector<llama_kv_cell> cells;
|
||||
|
||||
std::vector<ggml_tensor *> k_l; // per layer
|
||||
std::vector<ggml_tensor *> v_l;
|
||||
|
||||
private:
|
||||
ggml_type type_k = GGML_TYPE_F16;
|
||||
ggml_type type_v = GGML_TYPE_F16;
|
||||
|
||||
std::vector<ggml_context_ptr> ctxs;
|
||||
std::vector<ggml_backend_buffer_ptr> bufs;
|
||||
|
||||
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
|
||||
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
|
||||
|
||||
bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
|
||||
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
|
||||
};
|
||||
|
||||
// TODO: temporary reusing llama_kv_cache_unified -- implement recurrent cache and simplify llama_kv_cache_unified
|
||||
//class llama_kv_cache_recurrent : public llama_kv_cache_unified {
|
||||
//public:
|
||||
// using llama_kv_cache_unified::llama_kv_cache_unified;
|
||||
//};
|
||||
|
||||
//
|
||||
// kv cache view
|
||||
//
|
||||
|
||||
llama_kv_cache_view llama_kv_cache_view_init(const llama_kv_cache & kv, int32_t n_seq_max);
|
||||
|
||||
void llama_kv_cache_view_update(llama_kv_cache_view * view, const llama_kv_cache * kv);
|
||||
|
@ -1,439 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "llama.h"
|
||||
#include "llama-cparams.h"
|
||||
|
||||
#include <bitset>
|
||||
#include <cassert>
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <map>
|
||||
|
||||
// meta information about KV cells that can be part of multiple sequences at the same time
|
||||
// TODO: add unit tests
|
||||
class llama_kv_cells_unified {
|
||||
public:
|
||||
void reset() {
|
||||
for (uint32_t i = 0; i < pos.size(); ++i) {
|
||||
pos[i] = -1;
|
||||
shift[i] = 0;
|
||||
seq[i].reset();
|
||||
}
|
||||
|
||||
has_shift = false;
|
||||
|
||||
used.clear();
|
||||
|
||||
for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
seq_pos[s].clear();
|
||||
}
|
||||
}
|
||||
|
||||
void reset_shift() {
|
||||
has_shift = false;
|
||||
|
||||
for (uint32_t i = 0; i < shift.size(); ++i) {
|
||||
shift[i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t size() const {
|
||||
return pos.size();
|
||||
}
|
||||
|
||||
void resize(uint32_t n) {
|
||||
pos.resize(n);
|
||||
shift.resize(n);
|
||||
seq.resize(n);
|
||||
|
||||
reset();
|
||||
}
|
||||
|
||||
bool is_empty(uint32_t i) const {
|
||||
assert(i < pos.size());
|
||||
assert((pos[i] < 0 && pos[i] == -1) || pos[i] >= 0);
|
||||
|
||||
return pos[i] == -1;
|
||||
}
|
||||
|
||||
uint32_t get_used() const {
|
||||
return used.size();
|
||||
}
|
||||
|
||||
// the index of the first cell that is used
|
||||
// return 0 if no cells are used
|
||||
uint32_t used_min() const {
|
||||
return used.empty() ? 0 : *used.begin();
|
||||
}
|
||||
|
||||
// the index of the last cell that is used + 1
|
||||
// return 0 if no cells are used
|
||||
uint32_t used_max_p1() const {
|
||||
return used.empty() ? 0 : *used.rbegin() + 1;
|
||||
}
|
||||
|
||||
bool get_has_shift() const {
|
||||
return has_shift;
|
||||
}
|
||||
|
||||
// move cell isrc to idst (used during defrag)
|
||||
void mv(uint32_t isrc, uint32_t idst) {
|
||||
assert(isrc < pos.size());
|
||||
assert(idst < pos.size());
|
||||
|
||||
assert(pos[idst] == -1);
|
||||
assert(pos[isrc] != -1);
|
||||
|
||||
pos [idst] = pos [isrc];
|
||||
shift[idst] = shift[isrc];
|
||||
seq [idst] = seq [isrc];
|
||||
|
||||
pos [isrc] = -1;
|
||||
shift[isrc] = 0;
|
||||
seq [isrc].reset();
|
||||
|
||||
used.erase (isrc);
|
||||
used.insert(idst);
|
||||
}
|
||||
|
||||
// copy the state of cells [i, i + n) (used for save/restore the state of the cells)
|
||||
llama_kv_cells_unified cp(uint32_t i, uint32_t n) const {
|
||||
assert(i + n <= pos.size());
|
||||
|
||||
llama_kv_cells_unified res;
|
||||
|
||||
res.resize(n);
|
||||
|
||||
for (uint32_t j = 0; j < n; ++j) {
|
||||
res.pos[j] = pos[i + j];
|
||||
res.seq[j] = seq[i + j];
|
||||
|
||||
assert(shift[i + j] == 0);
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
// set the state of cells [i, i + other.pos.size()) (used for save/restore the state of the cells)
|
||||
void set(uint32_t i, const llama_kv_cells_unified & other) {
|
||||
assert(i + other.pos.size() <= pos.size());
|
||||
|
||||
for (uint32_t j = 0; j < other.pos.size(); ++j) {
|
||||
if (pos[i + j] == -1 && other.pos[j] != -1) {
|
||||
used.insert(i + j);
|
||||
}
|
||||
|
||||
if (pos[i + j] != -1 && other.pos[j] == -1) {
|
||||
used.erase(i + j);
|
||||
}
|
||||
|
||||
if (pos[i + j] != -1) {
|
||||
seq_pos_rm(i + j);
|
||||
}
|
||||
|
||||
pos[i + j] = other.pos[j];
|
||||
seq[i + j] = other.seq[j];
|
||||
|
||||
if (pos[i + j] != -1) {
|
||||
seq_pos_add(i + j);
|
||||
}
|
||||
|
||||
assert(shift[i + j] == 0);
|
||||
}
|
||||
}
|
||||
|
||||
// clear a non-empty cell
|
||||
void rm(uint32_t i) {
|
||||
assert(i < pos.size());
|
||||
assert(pos[i] != -1);
|
||||
|
||||
seq_pos_rm(i);
|
||||
seq[i].reset();
|
||||
|
||||
pos[i] = -1;
|
||||
shift[i] = 0;
|
||||
|
||||
used.erase(i);
|
||||
}
|
||||
|
||||
// note: call only if the cell has seq_id
|
||||
// return true if the cell becomes empty
|
||||
bool seq_rm(uint32_t i, llama_seq_id seq_id) {
|
||||
assert(i < pos.size());
|
||||
assert(seq[i].test(seq_id));
|
||||
assert(pos[i] != -1);
|
||||
assert(seq_id >= 0);
|
||||
|
||||
seq[i].reset(seq_id);
|
||||
seq_pos_dec(seq_id, pos[i]);
|
||||
|
||||
if (seq[i].none()) {
|
||||
pos[i] = -1;
|
||||
shift[i] = 0;
|
||||
|
||||
used.erase(i);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
// return true if the cell becomes empty (i.e. it did not contain seq_id before the call)
|
||||
bool seq_keep(uint32_t i, llama_seq_id seq_id) {
|
||||
assert(i < pos.size());
|
||||
|
||||
if (seq[i].test(seq_id)) {
|
||||
seq_pos_rm(i);
|
||||
seq[i].reset();
|
||||
|
||||
seq[i].set(seq_id);
|
||||
seq_pos_inc(seq_id, pos[i]);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
if (seq[i].any()) {
|
||||
seq_pos_rm(i);
|
||||
seq[i].reset();
|
||||
|
||||
pos[i] = -1;
|
||||
shift[i] = 0;
|
||||
|
||||
used.erase(i);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
assert(pos[i] == -1);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
// number of different sequences in the cell
|
||||
int seq_count(uint32_t i) const {
|
||||
assert(i < pos.size());
|
||||
assert(pos[i] != -1);
|
||||
|
||||
return seq[i].count();
|
||||
}
|
||||
|
||||
// check if the cell contains seq_id
|
||||
bool seq_has(uint32_t i, llama_seq_id seq_id) const {
|
||||
assert(i < pos.size());
|
||||
assert(seq_id >= 0);
|
||||
|
||||
return seq[i].test(seq_id);
|
||||
}
|
||||
|
||||
// note: call only if the cell is not empty and the seq_id is not in the cell
|
||||
void seq_add(uint32_t i, llama_seq_id seq_id) {
|
||||
assert(i < pos.size());
|
||||
assert(pos[i] != -1);
|
||||
assert(!seq[i].test(seq_id));
|
||||
|
||||
seq[i].set(seq_id);
|
||||
seq_pos_inc(seq_id, pos[i]);
|
||||
}
|
||||
|
||||
// return the sequence id of this cell
|
||||
// note: call only for cells with exactly one sequence
|
||||
llama_seq_id seq_get(uint32_t i) const {
|
||||
assert(seq[i].count() == 1);
|
||||
|
||||
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
if (seq[i].test(s)) {
|
||||
return s;
|
||||
}
|
||||
}
|
||||
|
||||
return -1;
|
||||
}
|
||||
|
||||
// the minimum position of sequence seq_id currently present in any of the cells
|
||||
// return -1 if the sequence is not present
|
||||
llama_pos seq_pos_min(llama_seq_id seq_id) const {
|
||||
assert(seq_id >= 0);
|
||||
assert(seq_id < LLAMA_MAX_SEQ);
|
||||
|
||||
if (seq_pos[seq_id].empty()) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
assert(seq_pos[seq_id].begin()->second > 0);
|
||||
|
||||
return seq_pos[seq_id].begin()->first;
|
||||
}
|
||||
|
||||
// the maximum position of sequence seq_id currently present in any of the cells
|
||||
// return -1 if the sequence is not present
|
||||
llama_pos seq_pos_max(llama_seq_id seq_id) const {
|
||||
assert(seq_id >= 0);
|
||||
assert(seq_id < LLAMA_MAX_SEQ);
|
||||
|
||||
if (seq_pos[seq_id].empty()) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
assert(seq_pos[seq_id].rbegin()->second > 0);
|
||||
|
||||
return seq_pos[seq_id].rbegin()->first;
|
||||
}
|
||||
|
||||
// note: call only if the cell is not empty
|
||||
llama_pos pos_get(uint32_t i) const {
|
||||
assert(i < pos.size());
|
||||
assert(pos[i] != -1);
|
||||
|
||||
return pos[i];
|
||||
}
|
||||
|
||||
// note: call only if the cell is not empty
|
||||
llama_pos get_shift(uint32_t i) const {
|
||||
assert(i < pos.size());
|
||||
assert(pos[i] != -1);
|
||||
|
||||
return shift[i];
|
||||
}
|
||||
|
||||
// check if a cell is not empty and its position is within [p0, p1)
|
||||
bool pos_in(uint32_t i, llama_pos p0, llama_pos p1) const {
|
||||
assert(i < pos.size());
|
||||
|
||||
return pos[i] >= p0 && pos[i] < p1;
|
||||
}
|
||||
|
||||
// set the position of an empty cell
|
||||
// does not modify "has_shift"
|
||||
// note: call only if the cell is empty
|
||||
void pos_set(uint32_t i, llama_pos p) {
|
||||
assert(i < pos.size());
|
||||
assert(pos[i] == -1);
|
||||
assert(seq[i].none());
|
||||
|
||||
pos[i] = p;
|
||||
|
||||
used.insert(i);
|
||||
}
|
||||
|
||||
// pos[i] = pos[i] + d
|
||||
// sets "has_shift" to true
|
||||
// note: call only if the cell is not empty
|
||||
bool pos_add(uint32_t i, llama_pos d) {
|
||||
assert(i < pos.size());
|
||||
assert(pos[i] != -1);
|
||||
|
||||
seq_pos_rm(i);
|
||||
|
||||
pos[i] += d;
|
||||
shift[i] += d;
|
||||
|
||||
has_shift = true;
|
||||
|
||||
if (pos[i] < 0) {
|
||||
seq[i].reset();
|
||||
pos[i] = -1;
|
||||
shift[i] = 0;
|
||||
|
||||
used.erase(i);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
seq_pos_add(i);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
// pos[i] = pos[i] / d
|
||||
// sets "has_shift" to true
|
||||
// note: call only if the cell is not empty
|
||||
void pos_div(uint32_t i, int d) {
|
||||
assert(i < pos.size());
|
||||
assert(pos[i] != -1);
|
||||
|
||||
const llama_pos p_old = pos[i];
|
||||
|
||||
seq_pos_rm(i);
|
||||
|
||||
pos[i] /= d;
|
||||
shift[i] += p_old - pos[i];
|
||||
|
||||
seq_pos_add(i);
|
||||
|
||||
has_shift = true;
|
||||
}
|
||||
|
||||
private:
|
||||
bool has_shift = false;
|
||||
|
||||
// set of indices of used cells (i.e. pos[i] != -1, allowed to not have any seq_id)
|
||||
std::set<uint32_t> used;
|
||||
|
||||
std::vector<llama_pos> pos;
|
||||
|
||||
// this array accumulates any applied shifts to the pos array since the last reset_shift() call
|
||||
// this is used to queue multiple updates to the pos array, which in the end can be applied in one go:
|
||||
//
|
||||
// cells.pos_add(x, shift_x);
|
||||
// cells.pos_div(y, shift_y);
|
||||
// ...
|
||||
//
|
||||
// if (cells.has_shift()) {
|
||||
// for (int i = 0; i < n; ++i) {
|
||||
// auto shift_i = cells.get_shift(i);
|
||||
// ...
|
||||
// }
|
||||
// cells.reset_shift();
|
||||
// }
|
||||
//
|
||||
std::vector<llama_pos> shift;
|
||||
|
||||
using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
|
||||
|
||||
// the bitset seq[i] tells us which sequences are currently occupying the i-th cell
|
||||
std::vector<seq_set_t> seq;
|
||||
|
||||
// the set seq_pos[s][p] tells us how many times the position p is currently present for sequence s
|
||||
// if the position p is not present, seq_pos[s][p] is not set
|
||||
// this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
|
||||
//
|
||||
// note that we cannot a use an std::set because in some cases a position can occur more than once for the same seq:
|
||||
// - during performing a cache reuse via (rm + add)
|
||||
// - some vision models have input embeddings with repeating positions
|
||||
//
|
||||
std::map<llama_pos, int> seq_pos[LLAMA_MAX_SEQ];
|
||||
|
||||
// helper functions for updating `seq_pos`, once cell at a time:
|
||||
|
||||
void seq_pos_dec(llama_seq_id s, llama_pos p) {
|
||||
auto it = seq_pos[s].find(p);
|
||||
assert(it != seq_pos[s].end());
|
||||
|
||||
if (--it->second == 0) {
|
||||
seq_pos[s].erase(it);
|
||||
}
|
||||
}
|
||||
|
||||
void seq_pos_inc(llama_seq_id s, llama_pos p) {
|
||||
seq_pos[s][p]++;
|
||||
}
|
||||
|
||||
// remove cell i
|
||||
void seq_pos_rm(uint32_t i) {
|
||||
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
if (seq[i].test(s)) {
|
||||
seq_pos_dec(s, pos[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// add cell i
|
||||
void seq_pos_add(uint32_t i) {
|
||||
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
if (seq[i].test(s)) {
|
||||
seq_pos_inc(s, pos[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
@ -1,246 +0,0 @@
|
||||
#include "llama-memory-hybrid.h"
|
||||
|
||||
#include "llama-impl.h"
|
||||
#include "llama-model.h"
|
||||
#include "llama-context.h"
|
||||
|
||||
//
|
||||
// llama_memory_hybrid
|
||||
//
|
||||
|
||||
llama_memory_hybrid::llama_memory_hybrid(
|
||||
const llama_model & model,
|
||||
/* attn */
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
bool v_trans,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_pad,
|
||||
uint32_t n_swa,
|
||||
llama_swa_type swa_type,
|
||||
/* recurrent */
|
||||
ggml_type type_r,
|
||||
ggml_type type_s,
|
||||
uint32_t rs_size,
|
||||
/* common */
|
||||
uint32_t n_seq_max,
|
||||
bool offload,
|
||||
/* layer filters */
|
||||
layer_filter_cb && filter_attn,
|
||||
layer_filter_cb && filter_recr) :
|
||||
hparams(model.hparams),
|
||||
mem_attn(new llama_kv_cache_unified(
|
||||
model,
|
||||
filter_attn == nullptr ?
|
||||
[&](int32_t il) { return !hparams.is_recurrent(il); }
|
||||
: filter_attn,
|
||||
type_k,
|
||||
type_v,
|
||||
v_trans,
|
||||
offload,
|
||||
kv_size,
|
||||
n_seq_max,
|
||||
n_pad,
|
||||
n_swa,
|
||||
swa_type
|
||||
)),
|
||||
mem_recr(new llama_memory_recurrent(
|
||||
model,
|
||||
filter_recr == nullptr ?
|
||||
[&](int32_t il) { return hparams.is_recurrent(il); }
|
||||
: filter_recr,
|
||||
type_r,
|
||||
type_s,
|
||||
offload,
|
||||
rs_size,
|
||||
n_seq_max
|
||||
)) {}
|
||||
|
||||
llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
|
||||
do {
|
||||
balloc.split_reset();
|
||||
|
||||
// follow the recurrent pattern for creating the ubatch splits
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
|
||||
while (true) {
|
||||
llama_ubatch ubatch;
|
||||
|
||||
if (embd_all) {
|
||||
// if all tokens are output, split by sequence
|
||||
ubatch = balloc.split_seq(n_ubatch);
|
||||
} else {
|
||||
ubatch = balloc.split_equal(n_ubatch);
|
||||
}
|
||||
|
||||
if (ubatch.n_tokens == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
ubatches.push_back(std::move(ubatch)); // NOLINT
|
||||
}
|
||||
|
||||
// prepare the recurrent batches first
|
||||
if (!mem_recr->prepare(ubatches)) {
|
||||
// TODO: will the recurrent cache be in an undefined context at this point?
|
||||
LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
|
||||
return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||
}
|
||||
|
||||
// prepare the attention cache
|
||||
auto heads_attn = mem_attn->prepare(ubatches);
|
||||
if (heads_attn.empty()) {
|
||||
LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
|
||||
return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||
}
|
||||
|
||||
return std::make_unique<llama_memory_hybrid_context>(
|
||||
this, std::move(heads_attn), std::move(ubatches));
|
||||
} while(false);
|
||||
|
||||
return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||
}
|
||||
|
||||
llama_memory_context_ptr llama_memory_hybrid::init_full() {
|
||||
return std::make_unique<llama_memory_hybrid_context>(this);
|
||||
}
|
||||
|
||||
llama_memory_context_ptr llama_memory_hybrid::init_update(llama_context * lctx, bool optimize) {
|
||||
return std::make_unique<llama_memory_hybrid_context>(this, lctx, optimize);
|
||||
}
|
||||
|
||||
bool llama_memory_hybrid::get_can_shift() const {
|
||||
// Shifting is trivially supported for recurrent
|
||||
return mem_attn->get_can_shift();
|
||||
}
|
||||
|
||||
void llama_memory_hybrid::clear(bool data) {
|
||||
mem_attn->clear(data);
|
||||
mem_recr->clear(data);
|
||||
}
|
||||
|
||||
bool llama_memory_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
||||
// Try removing from the recurrent cache first since it may fail. If it does
|
||||
// fail, the cache will not have been mutated.
|
||||
if (!mem_recr->seq_rm(seq_id, p0, p1)) {
|
||||
return false;
|
||||
}
|
||||
return mem_attn->seq_rm(seq_id, p0, p1);
|
||||
}
|
||||
|
||||
void llama_memory_hybrid::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
|
||||
mem_attn->seq_cp(seq_id_src, seq_id_dst, p0, p1);
|
||||
mem_recr->seq_cp(seq_id_src, seq_id_dst, p0, p1);
|
||||
}
|
||||
|
||||
void llama_memory_hybrid::seq_keep(llama_seq_id seq_id) {
|
||||
mem_attn->seq_keep(seq_id);
|
||||
mem_recr->seq_keep(seq_id);
|
||||
}
|
||||
|
||||
void llama_memory_hybrid::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
|
||||
mem_attn->seq_add(seq_id, p0, p1, shift);
|
||||
mem_recr->seq_add(seq_id, p0, p1, shift);
|
||||
}
|
||||
|
||||
void llama_memory_hybrid::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
||||
mem_attn->seq_div(seq_id, p0, p1, d);
|
||||
mem_recr->seq_div(seq_id, p0, p1, d);
|
||||
}
|
||||
|
||||
llama_pos llama_memory_hybrid::seq_pos_min(llama_seq_id seq_id) const {
|
||||
// the min of the total cache is the max of the two caches' min values
|
||||
return std::max(mem_attn->seq_pos_min(seq_id), mem_recr->seq_pos_min(seq_id));
|
||||
}
|
||||
|
||||
llama_pos llama_memory_hybrid::seq_pos_max(llama_seq_id seq_id) const {
|
||||
// the max of the total cache is the min of the two caches' max values
|
||||
return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id));
|
||||
}
|
||||
|
||||
void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
|
||||
mem_attn->state_write(io, seq_id);
|
||||
mem_recr->state_write(io, seq_id);
|
||||
}
|
||||
|
||||
void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
|
||||
mem_attn->state_read(io, seq_id);
|
||||
mem_recr->state_read(io, seq_id);
|
||||
}
|
||||
|
||||
llama_kv_cache_unified * llama_memory_hybrid::get_mem_attn() const {
|
||||
return mem_attn.get();
|
||||
}
|
||||
|
||||
llama_memory_recurrent * llama_memory_hybrid::get_mem_recr() const {
|
||||
return mem_recr.get();
|
||||
}
|
||||
|
||||
llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_status status) : status(status) {}
|
||||
|
||||
llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_hybrid * mem) :
|
||||
ctx_attn(mem->get_mem_attn()->init_full()),
|
||||
ctx_recr(mem->get_mem_recr()->init_full()),
|
||||
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
|
||||
}
|
||||
|
||||
llama_memory_hybrid_context::llama_memory_hybrid_context(
|
||||
llama_memory_hybrid * mem,
|
||||
llama_context * lctx,
|
||||
bool optimize) :
|
||||
ctx_attn(mem->get_mem_attn()->init_update(lctx, optimize)),
|
||||
ctx_recr(mem->get_mem_recr()->init_update(lctx, optimize)),
|
||||
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
|
||||
}
|
||||
|
||||
llama_memory_hybrid_context::llama_memory_hybrid_context(
|
||||
llama_memory_hybrid * mem,
|
||||
std::vector<uint32_t> heads_attn,
|
||||
std::vector<llama_ubatch> ubatches) :
|
||||
ubatches(std::move(ubatches)),
|
||||
// note: here we copy the ubatches. not sure if this is ideal
|
||||
ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)),
|
||||
ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
|
||||
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
|
||||
}
|
||||
|
||||
bool llama_memory_hybrid_context::next() {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
|
||||
ctx_attn->next();
|
||||
ctx_recr->next();
|
||||
|
||||
if (++i_next >= ubatches.size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool llama_memory_hybrid_context::apply() {
|
||||
assert(!llama_memory_status_is_fail(status));
|
||||
|
||||
bool res = true;
|
||||
|
||||
res = res & ctx_attn->apply();
|
||||
res = res & ctx_recr->apply();
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
llama_memory_status llama_memory_hybrid_context::get_status() const {
|
||||
return status;
|
||||
}
|
||||
|
||||
const llama_ubatch & llama_memory_hybrid_context::get_ubatch() const {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
return ubatches[i_next];
|
||||
}
|
||||
|
||||
const llama_kv_cache_unified_context * llama_memory_hybrid_context::get_attn() const {
|
||||
return static_cast<const llama_kv_cache_unified_context *>(ctx_attn.get());
|
||||
}
|
||||
|
||||
const llama_memory_recurrent_context * llama_memory_hybrid_context::get_recr() const {
|
||||
return static_cast<const llama_memory_recurrent_context *>(ctx_recr.get());
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "llama-batch.h"
|
||||
#include "llama-graph.h"
|
||||
#include "llama-kv-cache-unified.h"
|
||||
#include "llama-memory.h"
|
||||
#include "llama-memory-recurrent.h"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
//
|
||||
// llama_memory_hybrid
|
||||
//
|
||||
|
||||
// utilizes instances of llama_memory_recurrent and llama_kv_cache_unified to
|
||||
// support models where each layer may be either attention-based or recurrent
|
||||
|
||||
class llama_memory_hybrid : public llama_memory_i {
|
||||
public:
|
||||
|
||||
// this callback is used to filter out layers that should not be included in the cache
|
||||
using layer_filter_cb = std::function<bool(int32_t il)>;
|
||||
|
||||
llama_memory_hybrid(
|
||||
const llama_model & model,
|
||||
/* attn */
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
bool v_trans,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_pad,
|
||||
uint32_t n_swa,
|
||||
llama_swa_type swa_type,
|
||||
/* recurrent */
|
||||
ggml_type type_r,
|
||||
ggml_type type_s,
|
||||
uint32_t rs_size,
|
||||
/* common */
|
||||
uint32_t n_seq_max,
|
||||
bool offload,
|
||||
/* layer filters */
|
||||
layer_filter_cb && filter_attn = nullptr,
|
||||
layer_filter_cb && filter_recr = nullptr);
|
||||
|
||||
~llama_memory_hybrid() = default;
|
||||
|
||||
//
|
||||
// llama_memory_i
|
||||
//
|
||||
|
||||
llama_memory_context_ptr init_batch(
|
||||
llama_batch_allocr & balloc,
|
||||
uint32_t n_ubatch,
|
||||
bool embd_all) override;
|
||||
|
||||
llama_memory_context_ptr init_full() override;
|
||||
|
||||
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
|
||||
|
||||
bool get_can_shift() const override;
|
||||
|
||||
void clear(bool data) override;
|
||||
|
||||
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
||||
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
||||
void seq_keep(llama_seq_id seq_id) override;
|
||||
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
|
||||
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
||||
|
||||
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
|
||||
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
|
||||
|
||||
// state write/load
|
||||
|
||||
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
||||
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
|
||||
|
||||
//
|
||||
// llama_memory_hybrid specific API
|
||||
//
|
||||
|
||||
llama_kv_cache_unified * get_mem_attn() const;
|
||||
llama_memory_recurrent * get_mem_recr() const;
|
||||
|
||||
private:
|
||||
const llama_hparams & hparams;
|
||||
|
||||
const std::unique_ptr<llama_kv_cache_unified> mem_attn;
|
||||
const std::unique_ptr<llama_memory_recurrent> mem_recr;
|
||||
};
|
||||
|
||||
class llama_memory_hybrid_context : public llama_memory_context_i {
|
||||
public:
|
||||
// init failure
|
||||
explicit llama_memory_hybrid_context(llama_memory_status status);
|
||||
|
||||
// init full
|
||||
explicit llama_memory_hybrid_context(llama_memory_hybrid * mem);
|
||||
|
||||
// init update
|
||||
explicit llama_memory_hybrid_context(
|
||||
llama_memory_hybrid * mem,
|
||||
llama_context * lctx,
|
||||
bool optimize);
|
||||
|
||||
// init success
|
||||
llama_memory_hybrid_context(
|
||||
llama_memory_hybrid * mem,
|
||||
std::vector<uint32_t> heads_attn,
|
||||
std::vector<llama_ubatch> ubatches);
|
||||
|
||||
~llama_memory_hybrid_context() = default;
|
||||
|
||||
bool next() override;
|
||||
bool apply() override;
|
||||
|
||||
llama_memory_status get_status() const override;
|
||||
const llama_ubatch & get_ubatch() const override;
|
||||
|
||||
//
|
||||
// llama_memory_hybrid_context
|
||||
//
|
||||
|
||||
const llama_kv_cache_unified_context * get_attn() const;
|
||||
const llama_memory_recurrent_context * get_recr() const;
|
||||
|
||||
private:
|
||||
// the index of the next ubatch to process
|
||||
size_t i_next = 0;
|
||||
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
|
||||
const llama_memory_context_ptr ctx_attn;
|
||||
const llama_memory_context_ptr ctx_recr;
|
||||
|
||||
const llama_memory_status status;
|
||||
};
|
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user